-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
comm.h
104 lines (84 loc) · 3.43 KB
/
comm.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#pragma once
#include <memory>
#include <utility>
#include <vector>

#include <ATen/ATen.h>
#include <c10d/ProcessGroup.hpp>
#include <torch/csrc/utils/pybind.h>
namespace c10d {
// Broadcasts many tensors to all processes in the process group.
//
// Parameters:
//   process_group - the process group whose processes take part in the
//                   broadcast.
//   tensors       - the tensors to broadcast.
//   buffer_size   - NOTE(review): presumably the size limit used when
//                   coalescing tensors into buffers; confirm against the
//                   definition, which is not in this header.
//   rank          - defaults to 0. NOTE(review): presumably the rank acting as
//                   the source of the broadcast; confirm against the
//                   definition.
void broadcast_coalesced(
    std::shared_ptr<c10d::ProcessGroup> process_group,
    at::TensorList tensors,
    size_t buffer_size,
    int rank = 0);
// Carries the bucket contents tensors (one per replica) that are handed to a
// DDP communication hook.
// Optionally in the future this can be enhanced with parameter to bucket
// mappings as well.
class GradBucket {
 public:
  // Takes ownership of `tensors`. The argument is a by-value sink so that a
  // caller passing a temporary (or an explicitly moved vector) does not pay
  // for an extra copy of the vector; callers passing an lvalue still get a
  // copy, exactly as before.
  explicit GradBucket(std::vector<at::Tensor> tensors)
      : tensors_(std::move(tensors)) {}

  // Each tensor in the list that getTensors returns refers to the replica on
  // each device. There will be multiple replicas only in the case of single
  // process multiple device mode. In the single process single device mode,
  // this list would consist of only a single tensor.
  //
  // The returned reference points at the vector stored in this object.
  const std::vector<at::Tensor>& getTensors() const {
    return tensors_;
  }

 private:
  // Bucket contents, one entry per replica.
  std::vector<at::Tensor> tensors_;
};
// Common base class of `PythonCommHook` and `CppCommHookInterface`.
// Implementations have to provide two methods:
//   1) `runHook`, which communicates gradients asynchronously, and
//   2) `parseHookResult`, which converts the hook result into a tensor
//      vector.
class TORCH_PYTHON_API CommHookInterface {
 public:
  // Virtual so that hooks can be destroyed through a base-class pointer.
  virtual ~CommHookInterface() = default;

  // Hands the input grad bucket to the registered communication hook.
  // Once the tensors in the bucket are ready, the hook is kicked off
  // asynchronously and a future holding the communication results is
  // returned.
  virtual c10::intrusive_ptr<torch::jit::Future> runHook(
      GradBucket& bucket) = 0;

  // Returns the resulting tensors once the communication hook result is
  // ready. The resulting tensors will then be copied to the grads of
  // individual parameters.
  virtual std::vector<at::Tensor> parseHookResult(
      const c10::IValue& result) = 0;
};
// Communication hook whose state and hook callable are Python objects.
class TORCH_PYTHON_API PythonCommHook : public CommHookInterface {
 public:
  // Takes a state and a callable hook. The inputs are Python objects.
  // The state is passed to the hook in the runHook method, and it can be used
  // to maintain and update any state information during the execution of the
  // hook. The hook performs user-specified processing and returns a future
  // indicating asynchronous communication of gradients.
  PythonCommHook(py::object state, py::object hook)
      : state_(std::move(state)), hook_(std::move(hook)) {}

  // Declared here only; the definition lives outside this header.
  ~PythonCommHook() override;

  // See CommHookInterface::runHook. Declared here, defined elsewhere.
  c10::intrusive_ptr<torch::jit::Future> runHook(GradBucket& bucket) override;

  // See CommHookInterface::parseHookResult. Declared here, defined elsewhere.
  std::vector<at::Tensor> parseHookResult(const c10::IValue& result) override;

 private:
  // Only needed for stateful communication.
  py::object state_;
  // The callable hook supplied at construction.
  py::object hook_;
};
// Base class for communication hooks written in C++. Subclasses only need to
// implement the `runHook` method, which may make use of `state_`; the
// `parseHookResult` conversion is provided here.
template <typename T>
class TORCH_API CppCommHookInterface : public CommHookInterface {
 public:
  // Stores a copy of `state` in `state_`. The parameter is a const reference
  // so that const objects and temporaries are accepted in addition to the
  // mutable lvalues a plain `T&` would allow.
  explicit CppCommHookInterface(const T& state) : state_(state) {}

  ~CppCommHookInterface() override = default;

  // Converts the hook result into a tensor vector:
  //   - a single Tensor becomes a one-element vector;
  //   - a TensorList is returned as a vector of its tensors.
  // Any other kind of IValue fails the internal assert below.
  std::vector<at::Tensor> parseHookResult(const c10::IValue& result) override {
    TORCH_INTERNAL_ASSERT(
        result.isTensor() || result.isTensorList(),
        "expected the hook result is either a Tensor or a TensorList");
    if (result.isTensor()) {
      return {result.toTensor()};
    }
    return result.toTensorVector();
  }

 protected:
  // Held by value: this is a copy of the constructor argument. When T is a
  // non-owning handle such as a raw pointer, only the handle is copied and
  // whatever it refers to is not owned by this hook.
  T state_;
};
} // namespace c10d