Add size info to collective logs #100413

Closed
46 changes: 23 additions & 23 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -291,29 +291,19 @@ std::ostream& operator<<(
     std::ostream& output,
     const ProcessGroupNCCL::WorkNCCL& workNCCL) {
   std::string workInfo;
-  if (workNCCL.outputs_) {
-    workInfo = c10::str(
-        "WorkNCCL(",
-        "SeqNum=",
-        workNCCL.seq_,
-        ", OpType=",
-        opTypeToString(workNCCL.opType_),
-        ", TensorShape=",
-        (*workNCCL.outputs_)[0].sizes(),
-        ", Timeout(ms)=",
-        workNCCL.opTimeout_.count(),
-        ")");
-  } else {
-    workInfo = c10::str(
-        "WorkNCCL(",
-        "SeqNum=",
-        workNCCL.seq_,
-        ", OpType=",
-        opTypeToString(workNCCL.opType_),
-        ", Timeout(ms)=",
-        workNCCL.opTimeout_.count(),
-        ")");
-  }
+  workInfo = c10::str(
+      "WorkNCCL(",
+      "SeqNum=",
+      workNCCL.seq_,
+      ", OpType=",
+      opTypeToString(workNCCL.opType_),
+      ", NumelIn=",
+      workNCCL.numelIn_,
+      ", NumelOut=",
+      workNCCL.numelOut_,
+      ", Timeout(ms)=",
+      workNCCL.opTimeout_.count(),
+      ")");
   return output << workInfo;
 }
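For illustration, here is a minimal standalone sketch (not PyTorch code: std::ostringstream stands in for c10::str, and every field value, including the OpType string, is hypothetical) of the message the rewritten operator<< produces:

    #include <iostream>
    #include <sstream>

    int main() {
      std::ostringstream oss;
      // Mirror the field order used above: SeqNum, OpType, NumelIn, NumelOut, Timeout.
      oss << "WorkNCCL(SeqNum=" << 1234 << ", OpType=" << "ALLREDUCE"
          << ", NumelIn=" << 1048576 << ", NumelOut=" << 1048576
          << ", Timeout(ms)=" << 600000 << ")";
      std::cout << oss.str() << '\n';
      // Prints:
      // WorkNCCL(SeqNum=1234, OpType=ALLREDUCE, NumelIn=1048576, NumelOut=1048576, Timeout(ms)=600000)
      return 0;
    }

Note that the unconditional version also drops the old TensorShape field, so the same format is emitted whether or not outputs_ is set.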

@@ -353,6 +343,8 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w)
       workStartTime_(w.workStartTime_),
       seq_(w.seq_),
       startTraceUpdated_(w.startTraceUpdated_),
+      numelIn_(w.numelIn_),
+      numelOut_(w.numelOut_),
       store_(w.store_) {
   exception_ = w.exception_;
 }
@@ -1604,6 +1596,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
   work->avoidRecordStreams_ = avoidRecordStreams_;
   work->opTimeout_ = options_->timeout;
   work->store_ = store_;
+  // Record size info for debugging. We only record the size on the first
+  // device, as multi-device per process is deprecated.
+  work->numelIn_ = inputs[0].numel();
+  work->numelOut_ = outputs[0].numel();

   if (!coalescing_state_) {
     workEnqueue(work);
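As a sanity check on what NumelIn and NumelOut will report for different collectives, here is a minimal standalone sketch (not PyTorch code; the world size and element counts are made up) of how the two values typically relate:

    #include <cassert>
    #include <cstddef>

    int main() {
      const std::size_t worldSize = 4;   // hypothetical number of ranks
      const std::size_t numelIn = 1024;  // elements in inputs[0] on this rank
      // all_reduce: output has the same shape as the input.
      const std::size_t allReduceOut = numelIn;
      // all_gather: output concatenates one input-sized chunk per rank.
      const std::size_t allGatherOut = numelIn * worldSize;
      // reduce_scatter: output is the input split evenly across ranks.
      const std::size_t reduceScatterOut = numelIn / worldSize;
      assert(allReduceOut == 1024);
      assert(allGatherOut == 4096);
      assert(reduceScatterOut == 256);
      return 0;
    }

Since only inputs[0] and outputs[0] are inspected, the recorded numbers describe this rank's view of the collective, not the global payload.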
@@ -1750,6 +1746,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
     work->store_ = store_;
   }

+  // Record size info for debugging. We only record the size on the first
+  // device, as multi-device per process is deprecated.
+  work->numelIn_ = work->numelOut_ = tensors[0].numel();
+
   // Future only needs to be created and marked completed with outputs for
   // recv(), but still create future for use cases such as profiling even for
   // send().
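Where this pays off is in hang and timeout reports: with sizes in the log, a mismatch in collective arguments across ranks becomes visible at a glance. For instance (hypothetical watchdog output), two ranks stuck on the same sequence number but reporting different NumelIn values would point directly at mismatched tensor sizes:

    [Rank 0] ... WorkNCCL(SeqNum=42, OpType=ALLREDUCE, NumelIn=1024, NumelOut=1024, Timeout(ms)=600000)
    [Rank 1] ... WorkNCCL(SeqNum=42, OpType=ALLREDUCE, NumelIn=2048, NumelOut=2048, Timeout(ms)=600000)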
5 changes: 5 additions & 0 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -193,6 +193,11 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // This will be used by desync debug.
   bool startTraceUpdated_{false};

+  // Record collective sizes for debugging. We only record the size on the
+  // first device, as multi-device per process is deprecated.
+  size_t numelIn_ = -1;
+  size_t numelOut_ = -1;
+
   // Wrapper method for the static checkForNCCLErrors which can be overridden
   // for tests.
   virtual std::exception_ptr checkForNCCLErrors(
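One detail worth noting in the header change: initializing the unsigned size_t members with -1 wraps around to SIZE_MAX, which presumably serves as an "unset" sentinel until collective() or pointToPoint() records real sizes. A one-line compile-time check of that wraparound:

    #include <cstddef>
    #include <cstdint>

    // -1 converted to an unsigned type yields that type's maximum value.
    static_assert(static_cast<std::size_t>(-1) == SIZE_MAX, "-1 wraps to SIZE_MAX");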