-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
comm.cpp
138 lines (114 loc) · 4.41 KB
/
comm.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#include <torch/csrc/distributed/c10d/comm.h>
#include <deque>
#include <ATen/core/functional.h>
#include <torch/csrc/distributed/c10d/reducer.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/tensor_flatten.h>
namespace c10d {
namespace {
class BroadcastWork {
public:
BroadcastWork(
const std::shared_ptr<c10d::ProcessGroup>& process_group,
std::vector<at::Tensor> bucket_tensors,
int root_rank = 0)
: bucket_tensors_(std::move(bucket_tensors)),
flat_tensor_({torch::utils::flatten_dense_tensors(bucket_tensors_)}) {
BroadcastOptions broadcastOptions;
broadcastOptions.rootRank = root_rank;
work_ = process_group->broadcast(flat_tensor_, broadcastOptions);
}
void finish() {
work_->wait();
// Copy the output of the broadcast operation back.
auto output_tensors = torch::utils::unflatten_dense_tensors(
flat_tensor_.front(), bucket_tensors_);
TORCH_INTERNAL_ASSERT(output_tensors.size() == bucket_tensors_.size());
for (size_t i = 0; i < output_tensors.size(); i++) {
bucket_tensors_[i].copy_(output_tensors[i], /*non_blocking=*/true);
}
}
protected:
// The list of tensors to broadcast. They are guaranteed to be
// placed on the same device and have the same dtype.
std::vector<at::Tensor> bucket_tensors_;
// The vector with a single flattened tensor containing the contents
// of the tensors in bucket_tensors_. It must be stored in a vector
// because c10d::ProcessGroup::broadcast takes a vector argument.
std::vector<at::Tensor> flat_tensor_;
// The broadcast work that is kicked off upon construction.
std::shared_ptr<c10d::ProcessGroup::Work> work_;
};
} // namespace
// Broadcast many tensors to all processes in the process group.
//
// The tensors are grouped into buckets limited by `buffer_size`; each bucket
// is broadcast as one unit, with `rank` passed on as the root rank of every
// broadcast.
void broadcast_coalesced(
    std::shared_ptr<c10d::ProcessGroup> process_group,
    at::TensorList tensors,
    size_t buffer_size,
    int rank) {
  // Coalesce tensors into buckets taking into account the maximum buffer size.
  // This routine is multi-device aware, so the tensors can be split across
  // multiple devices and can contain a mix of CPU and CUDA tensors.
  const auto bucket_indices =
      compute_bucket_assignment_by_size(tensors.vec(), {buffer_size});

  // At most 2 broadcasts are kept in flight to avoid allocating too much
  // memory (in case the specified tensors are very large).
  constexpr auto kMaxInFlight = 2;
  std::deque<BroadcastWork> pending;

  // Completes the oldest outstanding broadcast and removes it from the queue.
  const auto retire_oldest = [&pending]() {
    pending.front().finish();
    pending.pop_front();
  };

  for (const auto& indices : bucket_indices) {
    if (pending.size() >= kMaxInFlight) {
      retire_oldest();
    }
    // Resolve each index in the bucket to the matching input tensor and start
    // broadcasting the bucket right away.
    pending.emplace_back(
        process_group,
        c10::fmap(
            indices, [&tensors](size_t index) { return tensors[index]; }),
        rank);
  }

  // Drain whatever is still outstanding.
  while (!pending.empty()) {
    retire_oldest();
  }
}
// Destructor: releases the references this hook holds on its Python state and
// hook objects.
PythonCommHook::~PythonCommHook() {
// Hold the GIL for the rest of the destructor, since Python reference counts
// are modified below.
py::gil_scoped_acquire ag;
// Manually drop one reference on each of the two Python objects.
state_.dec_ref();
hook_.dec_ref();
// Explicitly set state_ and hook_ to nullptr to prevent py::object's dtor
// to decref on the PyObject again.
// See Note [Destructing py::object] in python_ivalue.h
state_.ptr() = nullptr;
hook_.ptr() = nullptr;
}
// Calls the Python hook as hook_(state_, bucket) while holding the GIL and
// returns the future wrapped by the object the hook returned.
//
// Throws std::runtime_error if the returned object cannot be cast to a
// torch::jit::PythonFutureWrapper; the message names the actual Python type.
c10::intrusive_ptr<torch::jit::Future> PythonCommHook::runHook(
    GradBucket& bucket) {
  py::gil_scoped_acquire gil;

  py::object hook_result = hook_(state_, bucket);
  try {
    const auto wrapper =
        hook_result.cast<std::shared_ptr<torch::jit::PythonFutureWrapper>>();
    return wrapper->fut;
  } catch (const py::cast_error& cast_failure) {
    // Report the fully qualified Python type that was returned instead.
    const auto returned_type = hook_result.get_type();
    throw std::runtime_error(c10::str(
        cast_failure.what(),
        ". DDP communication hook's callback must return a "
        "torch.futures.Future or torch._C.Future object, but got ",
        returned_type.attr("__module__").cast<std::string>(),
        ".",
        returned_type.attr("__qualname__").cast<std::string>()));
  }
}
// Converts the value produced by a communication hook into a vector of
// tensors. The value must be either a PyObject or a TensorList; a PyObject is
// first converted (under the GIL) to an IValue holding a list of tensors.
std::vector<at::Tensor> PythonCommHook::parseHookResult(
    const c10::IValue& result) {
  TORCH_INTERNAL_ASSERT(
      result.isPyObject() || result.isTensorList(),
      "expected the hook result is either a PyObject or TensorList");

  // Already a tensor list: no Python interaction needed.
  if (!result.isPyObject()) {
    return result.toTensorVector();
  }

  // Python object: the conversion goes through Python objects, so the GIL is
  // held while it runs.
  py::gil_scoped_acquire gil;
  py::object py_result = torch::jit::toPyObject(result);
  const auto tensor_list_type = c10::ListType::create(c10::TensorType::get());
  auto converted = torch::jit::toIValue(py_result, tensor_list_type);
  return converted.toTensorVector();
}
} // namespace c10d