Add size info to collective logs (pytorch#100413)
The previous timeout log did not print size info, making it hard to debug hangs caused by message size mismatches.

(The reason is that when copying the `WorkNCCL` object during work enqueue, we do not copy `outputs_` to avoid holding extra references, so `(*outputs_)[0].sizes()` is never available for the copied work.)

This PR records the sizes in separate fields (`numelIn_`, `numelOut_`), so the log no longer relies on `outputs_`.
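
For illustration, here is a minimal, self-contained sketch (a hypothetical stand-in, not the actual `WorkNCCL` class) of why the copied work object cannot report tensor shapes, while plain scalar size fields survive the copy:

```cpp
// Hypothetical stand-in for WorkNCCL, only to illustrate the copy behavior.
#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

struct FakeWork {
  // Shared output tensors; deliberately dropped on copy (like outputs_).
  std::shared_ptr<std::vector<float>> outputs;
  // Plain scalar size fields; trivially copied (like numelIn_/numelOut_).
  size_t numelIn = 0;
  size_t numelOut = 0;

  FakeWork() = default;
  FakeWork(const FakeWork& w) : numelIn(w.numelIn), numelOut(w.numelOut) {
    // outputs intentionally not copied, to avoid holding extra references.
  }
};

int main() {
  FakeWork w;
  w.outputs = std::make_shared<std::vector<float>>(1024);
  w.numelIn = w.outputs->size();
  w.numelOut = w.outputs->size();

  FakeWork enqueued = w; // conceptually what happens at work enqueue
  // Shape info is gone on the copy, but the numel fields are still printable.
  std::cout << "has outputs: " << std::boolalpha << (enqueued.outputs != nullptr)
            << ", numelIn: " << enqueued.numelIn
            << ", numelOut: " << enqueued.numelOut << "\n";
  return 0;
}
```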

New timeout log:
```
[Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=1, OpType=_ALLGATHER_BASE, NumelIn=209715200, NumelOut=1677721600, Timeout(ms)=10000) ran for 10957 milliseconds before timing out.
```
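
As a usage note (not part of the PR), the two fields can be cross-checked against the collective's semantics; for the `_ALLGATHER_BASE` example above, the output should hold world-size times the input elements, which here implies a world size of 8. A minimal sketch of such a check, with the world size assumed:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  // Values taken from the example watchdog log line above.
  const std::uint64_t numelIn = 209715200;
  const std::uint64_t numelOut = 1677721600;
  // Assumed process-group size; not printed in the log itself.
  const std::uint64_t worldSize = 8;

  // For _ALLGATHER_BASE every rank contributes numelIn elements, so the
  // gathered output should contain worldSize * numelIn elements.
  if (numelOut == worldSize * numelIn) {
    std::cout << "NumelIn/NumelOut consistent for world size " << worldSize << "\n";
  } else {
    std::cout << "size mismatch: expected " << worldSize * numelIn
              << " elements, log reports " << numelOut << "\n";
  }
  return 0;
}
```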
Pull Request resolved: pytorch#100413
Approved by: https://github.com/kumpera
kwen2501 authored and shaoyf42 committed Jun 1, 2023
1 parent d47d054 commit 8c42365
Showing 2 changed files with 28 additions and 23 deletions.
torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp (23 additions, 23 deletions)

```diff
@@ -291,29 +291,19 @@ std::ostream& operator<<(
     std::ostream& output,
     const ProcessGroupNCCL::WorkNCCL& workNCCL) {
   std::string workInfo;
-  if (workNCCL.outputs_) {
-    workInfo = c10::str(
-        "WorkNCCL(",
-        "SeqNum=",
-        workNCCL.seq_,
-        ", OpType=",
-        opTypeToString(workNCCL.opType_),
-        ", TensorShape=",
-        (*workNCCL.outputs_)[0].sizes(),
-        ", Timeout(ms)=",
-        workNCCL.opTimeout_.count(),
-        ")");
-  } else {
-    workInfo = c10::str(
-        "WorkNCCL(",
-        "SeqNum=",
-        workNCCL.seq_,
-        ", OpType=",
-        opTypeToString(workNCCL.opType_),
-        ", Timeout(ms)=",
-        workNCCL.opTimeout_.count(),
-        ")");
-  }
+  workInfo = c10::str(
+      "WorkNCCL(",
+      "SeqNum=",
+      workNCCL.seq_,
+      ", OpType=",
+      opTypeToString(workNCCL.opType_),
+      ", NumelIn=",
+      workNCCL.numelIn_,
+      ", NumelOut=",
+      workNCCL.numelOut_,
+      ", Timeout(ms)=",
+      workNCCL.opTimeout_.count(),
+      ")");
   return output << workInfo;
 }
 
@@ -353,6 +343,8 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w)
       workStartTime_(w.workStartTime_),
       seq_(w.seq_),
       startTraceUpdated_(w.startTraceUpdated_),
+      numelIn_(w.numelIn_),
+      numelOut_(w.numelOut_),
       store_(w.store_) {
   exception_ = w.exception_;
 }
@@ -1604,6 +1596,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
   work->avoidRecordStreams_ = avoidRecordStreams_;
   work->opTimeout_ = options_->timeout;
   work->store_ = store_;
+  // Record size info for debug. We only record the size on the first device as
+  // multi-device per process is deprecated
+  work->numelIn_ = inputs[0].numel();
+  work->numelOut_ = outputs[0].numel();
 
   if (!coalescing_state_) {
     workEnqueue(work);
@@ -1750,6 +1746,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
     work->store_ = store_;
   }
 
+  // Record size info for debug. We only record the size on the first device as
+  // multi-device per process is deprecated
+  work->numelIn_ = work->numelOut_ = tensors[0].numel();
+
   // Future only needs to be created and marked completed with outputs for
   // recv(), but still create future for use cases such as profiling even for
   // send().
```
torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp (5 additions, 0 deletions)

```diff
@@ -193,6 +193,11 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // This will be used by desync debug.
   bool startTraceUpdated_{false};
 
+  // Record collective sizes for debug. We only record the size on the first
+  // device as multi-device per process is deprecated
+  size_t numelIn_ = -1;
+  size_t numelOut_ = -1;
+
   // Wrapper method for the static checkForNCCLErrors which can be overridden
   // for tests.
   virtual std::exception_ptr checkForNCCLErrors(
```
