Adding profiling capability to c++ ddp collective functions (#46471)
Summary:
Pull Request resolved: #46471

ghstack-source-id: 116018837

Test Plan:
Added unit tests:

 buck test mode/dev-nosan caffe2/test/distributed:distributed_gloo_fork
 buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork

Reviewed By: rohan-varma

Differential Revision: D23948397

fbshipit-source-id: 6d93a370aff26bf96c39e5d78a2492c5142a9156
mrzzd authored and facebook-github-bot committed Nov 6, 2020
1 parent 1aeefcd commit 160db3d
Showing 10 changed files with 252 additions and 66 deletions.
3 changes: 1 addition & 2 deletions torch/csrc/distributed/c10d/init.cpp
@@ -1086,8 +1086,7 @@ that adds a prefix to each key inserted to the store.
          &::c10d::ProcessGroup::Work::wait,
          py::arg("timeout") = kNoTimeout,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_future",
      .def("get_future",
          [](::c10d::ProcessGroup::Work& work)
              -> std::shared_ptr<jit::PythonFutureWrapper> {
            return std::make_shared<jit::PythonFutureWrapper>(work.getFuture());
26 changes: 23 additions & 3 deletions torch/lib/c10d/ProcessGroup.cpp
@@ -1,4 +1,6 @@
#include <c10d/ProcessGroup.hpp>
#include <ATen/ThreadLocalState.h>


#include <c10/util/Logging.h>

@@ -51,10 +53,20 @@ bool isP2POp(OpType opType) {
      opType == OpType::RECVANYSOURCE;
}

ProcessGroup::Work::Work() : rank_(-1), opType_(OpType::UNKNOWN) {}

ProcessGroup::Work::Work(int rank, OpType opType)
    : rank_(rank), opType_(opType) {}
ProcessGroup::Work::Work(int rank, OpType opType, const char* profilingTitle)
    : rank_(rank), opType_(opType) {
  if (profilingTitle != nullptr) {
    auto recordingFunction = std::make_shared<at::RecordFunction>(at::RecordScope::USER_SCOPE);
    if (recordingFunction->active) {
      recordingFunction->before(profilingTitle, {});
      std::function<void()> end_handler = [this, recordingFunction]() {
        recordingFunction->end();
      };
      recordFunctionEndCallback_ = at::wrapPropagateTLSState(end_handler);
    }
  }
}

OpType ProcessGroup::Work::retrieveOpType() {
  return opType_;
@@ -123,6 +135,10 @@ void ProcessGroup::Work::finish(std::exception_ptr exception) {
  std::unique_lock<std::mutex> lock(mutex_);
  completed_ = true;
  exception_ = exception;
  if (recordFunctionEndCallback_) {
    recordFunctionEndCallback_();
    recordFunctionEndCallback_ = nullptr;
  }
  lock.unlock();
  cv_.notify_all();
}
@@ -131,6 +147,10 @@ void ProcessGroup::Work::finishAndThrow(std::exception_ptr exception) {
  std::unique_lock<std::mutex> lock(mutex_);
  completed_ = true;
  exception_ = exception;
  if (recordFunctionEndCallback_) {
    recordFunctionEndCallback_();
    recordFunctionEndCallback_ = nullptr;
  }
  if (exception_) {
    std::rethrow_exception(exception_);
  }
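The pattern introduced above can be read in isolation. The following is a minimal standalone sketch, not taken from this commit or the PyTorch sources: a hypothetical AsyncOp opens an at::RecordFunction when constructed with a profiling title and closes it from whichever thread later marks the operation complete. The RecordFunction and wrapPropagateTLSState calls mirror the ones in the hunk above; the class and method names are illustrative only.

#include <ATen/ThreadLocalState.h>
#include <ATen/record_function.h>

#include <functional>
#include <memory>

// Hypothetical illustration of the deferred-end profiling pattern used by
// ProcessGroup::Work above; not part of this commit.
class AsyncOp {
 public:
  explicit AsyncOp(const char* profilingTitle) {
    if (profilingTitle != nullptr) {
      // Open a user-scope profiling event when the async work is created.
      auto rf = std::make_shared<at::RecordFunction>(at::RecordScope::USER_SCOPE);
      if (rf->active) {
        rf->before(profilingTitle, {});
        // Keep the RecordFunction alive inside an end handler, and wrap the
        // handler so the launching thread's thread-local state (e.g. profiler
        // configuration) is restored when the callback runs on another thread.
        std::function<void()> end_handler = [rf]() { rf->end(); };
        endCallback_ = at::wrapPropagateTLSState(end_handler);
      }
    }
  }

  // Called once by whichever thread completes the operation; this is the
  // moment the profiling event is closed.
  void markCompleted() {
    if (endCallback_) {
      endCallback_();
      endCallback_ = nullptr;
    }
  }

 private:
  std::function<void()> endCallback_;
};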
8 changes: 5 additions & 3 deletions torch/lib/c10d/ProcessGroup.hpp
@@ -77,9 +77,7 @@ class ProcessGroup {
  // this will be bound using pybind.
  class Work {
   public:
    Work();

    Work(int rank, OpType opType);
    Work(int rank = -1, OpType opType = OpType::UNKNOWN, const char* profilingTitle = nullptr);

    virtual ~Work();

@@ -156,6 +154,10 @@ class ProcessGroup {

    // Operation type that this work object refers to.
    OpType opType_;

    // When profiling, the callback to record end of operation event. This
    // callback needs to be called when collective operation is complete.
    std::function<void()> recordFunctionEndCallback_;
  };

  explicit ProcessGroup(int rank, int size);
27 changes: 18 additions & 9 deletions torch/lib/c10d/ProcessGroupGloo.cpp
@@ -677,7 +677,8 @@ class AsyncBroadcastWork : public ProcessGroupGloo::AsyncWork {
      int rootRank,
      int rootTensor,
      uint32_t tag)
      : context(context),
      : ProcessGroupGloo::AsyncWork("gloo:broadcast"),
        context(context),
        inputs(inputs),
        rootRank(rootRank),
        rootTensor(rootTensor),
@@ -823,7 +824,8 @@ class AsyncAllreduceWork : public ProcessGroupGloo::AsyncWork {
      std::vector<at::Tensor>& inputs,
      ReduceOp reduceOp,
      uint32_t tag)
      : context(context), inputs(inputs), reduceOp(reduceOp), tag(tag) {}
      : ProcessGroupGloo::AsyncWork("gloo:all_reduce"),
        context(context), inputs(inputs), reduceOp(reduceOp), tag(tag) {}

  std::shared_ptr<gloo::Context> context;
  std::vector<at::Tensor> inputs;
@@ -1431,7 +1433,8 @@ class AsyncReduceWork : public ProcessGroupGloo::AsyncWork {
      int rootTensor,
      ReduceOp reduceOp,
      uint32_t tag)
      : context(context),
      : ProcessGroupGloo::AsyncWork("gloo:reduce"),
        context(context),
        inputs(inputs),
        rootRank(rootRank),
        rootTensor(rootTensor),
@@ -1595,7 +1598,8 @@ class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork {
      std::vector<std::vector<at::Tensor>>& outputs,
      std::vector<at::Tensor>& inputs,
      uint32_t tag)
      : context(context), outputs(outputs), inputs(inputs), tag(tag) {}
      : ProcessGroupGloo::AsyncWork("gloo:all_gather"),
        context(context), outputs(outputs), inputs(inputs), tag(tag) {}

  std::shared_ptr<gloo::Context> context;
  std::vector<std::vector<at::Tensor>> outputs;
@@ -1792,7 +1796,8 @@ class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork {
      std::vector<std::vector<at::Tensor>>& output_lists,
      std::vector<at::Tensor>& input_list,
      uint32_t tag)
      : context(context),
      : ProcessGroupGloo::AsyncWork("gloo:all_gather"),
        context(context),
        output_lists(output_lists),
        input_list(input_list),
        tag(tag) {}
@@ -1921,7 +1926,8 @@ class AsyncGatherWork : public ProcessGroupGloo::AsyncWork {
      std::vector<at::Tensor>& inputs,
      int root,
      uint32_t tag)
      : context(context),
      : ProcessGroupGloo::AsyncWork("gloo:gather"),
        context(context),
        outputs(outputs),
        inputs(inputs),
        root(root),
@@ -2125,7 +2131,8 @@ class AsyncScatterWork : public ProcessGroupGloo::AsyncWork {
      std::vector<std::vector<at::Tensor>>& inputs,
      int root,
      uint32_t tag)
      : context(context),
      : ProcessGroupGloo::AsyncWork("gloo:scatter"),
        context(context),
        outputs(outputs),
        inputs(inputs),
        root(root),
@@ -2319,7 +2326,8 @@ class AsyncAlltoallWork : public ProcessGroupGloo::AsyncWork {
      std::vector<int64_t>& outputCounts,
      std::vector<int64_t>& inputCounts,
      uint32_t tag)
      : context(context),
      : ProcessGroupGloo::AsyncWork("gloo:all_to_all"),
        context(context),
        outputTensor(outputTensor),
        inputTensor(inputTensor),
        outputCounts(std::move(outputCounts)),
@@ -2576,7 +2584,8 @@ class AsyncBarrierWork : public ProcessGroupGloo::AsyncWork {
      const std::shared_ptr<gloo::Context>& context,
      std::vector<std::weak_ptr<AsyncWork>> priorWork,
      uint32_t tag)
      : context(context), priorWork(std::move(priorWork)), tag(tag) {}
      : ProcessGroupGloo::AsyncWork("gloo:barrier"),
        context(context), priorWork(std::move(priorWork)), tag(tag) {}

  std::shared_ptr<gloo::Context> context;
  std::vector<std::weak_ptr<AsyncWork>> priorWork;
2 changes: 2 additions & 0 deletions torch/lib/c10d/ProcessGroupGloo.hpp
@@ -68,6 +68,8 @@ class ProcessGroupGloo : public ProcessGroup {
  //
  class AsyncWork : public ProcessGroup::Work {
   public:
    AsyncWork(const char* profilingTitle = nullptr): ProcessGroup::Work(-1, OpType::UNKNOWN, profilingTitle) {}

    static void execute(std::shared_ptr<AsyncWork> work) {
      std::exception_ptr eptr;
      try {
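For a concrete picture of how the Gloo works adopt the new constructor, here is a hypothetical derived work, not one of the classes in this diff: the class name AsyncExampleWork and the "gloo:example" title are invented, while the AsyncWork(const char*) constructor and the run() override follow the header shown above.

#include <c10d/ProcessGroupGloo.hpp>

// Illustrative only: a Gloo-style async work that tags itself for profiling by
// forwarding a title through ProcessGroupGloo::AsyncWork to ProcessGroup::Work,
// the same way the real collectives pass "gloo:broadcast", "gloo:all_reduce",
// and so on in ProcessGroupGloo.cpp above.
class AsyncExampleWork : public c10d::ProcessGroupGloo::AsyncWork {
 public:
  AsyncExampleWork() : c10d::ProcessGroupGloo::AsyncWork("gloo:example") {}

  void run() override {
    // The gloo algorithm for the collective would execute here; once
    // AsyncWork::execute() marks the work finished, the profiling event
    // opened by the base class is ended via recordFunctionEndCallback_.
  }
};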
54 changes: 39 additions & 15 deletions torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -240,8 +240,9 @@ std::ostream& operator<<(
ProcessGroupNCCL::WorkNCCL::WorkNCCL(
    const std::vector<at::Device>& devices,
    int rank,
    OpType opType)
    : Work(rank, opType),
    OpType opType,
    const char* profilingTitle)
    : Work(rank, opType, profilingTitle),
      devices_(devices),
      workStartTime_(std::chrono::steady_clock::now()) {
  // Creates the CUDA event wrappers
@@ -986,8 +987,9 @@ std::vector<at::Tensor> flatten_for_scatter_gather(
std::shared_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
    std::vector<at::Device> devices,
    int rank,
    OpType opType) {
  return std::make_shared<ProcessGroupNCCL::WorkNCCL>(devices, rank, opType);
    OpType opType,
    const char* profilingTitle) {
  return std::make_shared<ProcessGroupNCCL::WorkNCCL>(devices, rank, opType, profilingTitle);
}

std::vector<at::Tensor> ProcessGroupNCCL::WorkNCCL::result() {
@@ -1031,7 +1033,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
    Fn fn,
    PreProcess pre,
    PostProcess post,
    OpType opType) {
    OpType opType,
    const char* profilingTitle) {
  const auto devices = getDeviceList(inputs);
  const auto key = getKeyFromDevices(devices);
  auto& ncclComms = getNCCLComm(key, devices, opType);
@@ -1040,13 +1043,25 @@
  syncStreams(devices, ncclEvents_[key], ncclStreams_[key]);

  // Work itself will create the CUDA events on all GPUs of tensors
  auto work = initWork(devices, rank_, opType);
  bool can_profile = outputs.size() == 1;
  auto work = initWork(devices, rank_, opType, can_profile ? profilingTitle : nullptr);

  // Store references to outputs and futureNCCLCallbackStream to be used by
  // WorkNCCL::getFuture.
  work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
  work->futureNCCLCallbackStreams_ = futureNCCLCallbackStreams_;

  if (work->recordFunctionEndCallback_) {
    // recordFunctionEndCallback_ is normally called in finish() by the base
    // class, but since finish() is not called by WorkNCCL, we schedule this
    // function to be run when the work is done.
    // Note that when can_profile is false, profilingTitle is not provided, so
    // recordFunctionEndCallback_ is not set.
    work->getFuture()->addCallback(std::move(work->recordFunctionEndCallback_));
  }

  at::cuda::OptionalCUDAGuard gpuGuard;

  pre(ncclStreams_[key]);
@@ -1175,14 +1190,16 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
    std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    Fn fn,
    OpType opType) {
    OpType opType,
    const char* profilingTitle) {
  return collective(
      inputs,
      outputs,
      fn,
      [](std::vector<at::cuda::CUDAStream>&) {},
      [](std::vector<at::cuda::CUDAStream>&) {},
      opType);
      opType,
      profilingTitle);
}

template <typename Fn>
@@ -1221,7 +1238,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allreduce(
            comm,
            stream.stream());
      },
      OpType::ALLREDUCE);
      OpType::ALLREDUCE,
      "nccl:all_reduce");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allreduce_coalesced(
@@ -1252,7 +1270,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::broadcast(
            comm,
            stream.stream());
      },
      OpType::BROADCAST);
      OpType::BROADCAST,
      "nccl:broadcast");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::reduce(
@@ -1278,7 +1297,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::reduce(
            comm,
            stream.stream());
      },
      OpType::REDUCE);
      OpType::REDUCE,
      "nccl:reduce");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allgather(
@@ -1322,7 +1342,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allgather(
          }
        }
      },
      OpType::ALLGATHER);
      OpType::ALLGATHER,
      "nccl:all_gather");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allgather_coalesced(
@@ -1375,7 +1396,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::reduce_scatter(
        }
      },
      [&](std::vector<at::cuda::CUDAStream>& ncclStreams) {},
      OpType::REDUCE_SCATTER);
      OpType::REDUCE_SCATTER,
      "nccl:reduce_scatter");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::barrier(
@@ -1448,7 +1470,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::alltoall_base(
              stream);
          return ncclSuccess;
        },
        OpType::ALLTOALL_BASE);
        OpType::ALLTOALL_BASE,
        "nccl:all_to_all");
  } else {
    c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);
    c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_);
@@ -1484,7 +1507,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::alltoall_base(
              comm,
              stream.stream());
        },
        OpType::ALLTOALL_BASE);
        OpType::ALLTOALL_BASE,
        "nccl:all_to_all");
  }
}

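Because WorkNCCL does not go through Work::finish(), the collective() hunk above hands recordFunctionEndCallback_ to the future returned by getFuture(). Below is a condensed sketch of that hand-off, not code from this commit: the scheduleProfilingEnd() helper is hypothetical and stands in for the surrounding collective() plumbing, while the Future::addCallback call mirrors the one in the diff.

#include <ATen/core/ivalue.h>

#include <functional>
#include <utility>

// Sketch only: defer a profiling end-callback until an ivalue::Future
// completes, as done for WorkNCCL in ProcessGroupNCCL::collective() above.
void scheduleProfilingEnd(
    const c10::intrusive_ptr<c10::ivalue::Future>& fut,
    std::function<void()> recordFunctionEndCallback) {
  if (recordFunctionEndCallback) {
    // The callback fires when the future is marked completed, i.e. when the
    // collective's outputs are ready, instead of inside Work::finish().
    fut->addCallback(std::move(recordFunctionEndCallback));
  }
}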
11 changes: 7 additions & 4 deletions torch/lib/c10d/ProcessGroupNCCL.hpp
@@ -68,7 +68,7 @@ class ProcessGroupNCCL : public ProcessGroup {
                   public std::enable_shared_from_this<WorkNCCL> {
   public:
    // Constructor takes a list of CUDA devices
    WorkNCCL(const std::vector<at::Device>& devices, int rank, OpType opType);
    WorkNCCL(const std::vector<at::Device>& devices, int rank, OpType opType, const char* profilingTitle = nullptr);
    // Copy constructor doing partial copy without outputs_. Cleanup thread
    // monitors and removes finished works. However it will deadlock when
    // destructs outputs_ tensors who are view tensors in autograd graph.
@@ -518,7 +518,8 @@ class ProcessGroupNCCL : public ProcessGroup {
  virtual std::shared_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
      std::vector<at::Device> devices,
      int rank,
      OpType opType);
      OpType opType,
      const char* profilingTitle=nullptr);

 private:
  // Helper that encapsulates work shared across all collective communication
@@ -532,15 +533,17 @@ class ProcessGroupNCCL : public ProcessGroup {
      std::vector<at::Tensor>& input,
      std::vector<at::Tensor>& output,
      Fn fn,
      OpType opType);
      OpType opType,
      const char* profilingTitle = nullptr);
  template <typename Fn, typename PreProcess, typename PostProcess>
  std::shared_ptr<ProcessGroup::Work> collective(
      std::vector<at::Tensor>& input,
      std::vector<at::Tensor>& output,
      Fn fn,
      PreProcess pre,
      PostProcess post,
      OpType opType);
      OpType opType,
      const char* profilingTitle = nullptr);

  // Helper that encapsulates work shared across point-to-point communication
  // primitives. It is the same structure as the helper used for collective
6 changes: 4 additions & 2 deletions torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp
@@ -59,7 +59,8 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
  std::shared_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
      std::vector<at::Device> devices,
      int rank,
      c10d::OpType opType) override {
      c10d::OpType opType,
      const char* profilingTitle) override {
    return std::make_shared<WorkNCCLSimulateErrors>(
        devices, simulate_error_, rank, opType);
  }
@@ -115,7 +116,8 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors {
  std::shared_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
      std::vector<at::Device> devices,
      int rank,
      c10d::OpType opType) override {
      c10d::OpType opType,
      const char* profilingTitle) override {
    return std::make_shared<WorkNCCLTimedoutErrors>(
        devices, set_timedout_error_, rank, opType);
  }
