-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
process_group_agent.h
287 lines (250 loc) · 10.8 KB
/
process_group_agent.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
#pragma once
#include <c10/core/thread_pool.h>
#include <c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/rpc/request_callback.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <atomic>
#include <thread>
namespace torch {
namespace distributed {
namespace rpc {
constexpr auto kDefaultNumSendRecvThreads = 4;
// Backend options for the ProcessGroup-based RPC agent.
//
// num_send_recv_threads: size of the send/recv thread pool; must be > 0.
// rpc_timeout: default RPC timeout (seconds), forwarded to RpcBackendOptions.
// init_method: rendezvous init-method URL, forwarded to RpcBackendOptions.
struct ProcessGroupRpcBackendOptions : public RpcBackendOptions {
  ProcessGroupRpcBackendOptions(
      int num_send_recv_threads,
      float rpc_timeout,
      std::string init_method)
      // init_method is taken by value; move it into the base class instead of
      // making a second copy of the string.
      : RpcBackendOptions(rpc_timeout, std::move(init_method)),
        numSendRecvThreads(num_send_recv_threads) {
    // Validate after member init; a zero/negative pool size is unusable.
    TORCH_CHECK(
        num_send_recv_threads > 0,
        "Cannot create ProcessGroup RPC backend with ",
        num_send_recv_threads,
        " threads in the thread-pool.");
  }

  int numSendRecvThreads;
};
// SendWork and RecvWork will be put into a task queue, and later picked up by
// worker threads from the same ThreadPool.
struct SendWork {
  // Takes the Message by rvalue reference and moves it into the member;
  // the original code copied it here, defeating the purpose of the &&
  // parameter.
  SendWork(const WorkerInfo& to, Message&& message)
      : to_(to), message_(std::move(message)) {}

  // Destination worker. Non-owning reference; the WorkerInfo must outlive
  // this SendWork (it lives in the agent's allWorkerInfo_).
  const WorkerInfo& to_;
  Message message_;
};
// SendWork wraps a Message and RecvWork wraps a Tensor. The difference here is
// to allow us to run serialization/deserialization in the worker threads.
struct RecvWork {
  // Moves the payload tensor into the member instead of copying it; the
  // original initializer copied, which for a Tensor is a refcount bump the
  // && parameter was meant to avoid.
  RecvWork(
      const WorkerInfo& from,
      MessageType type,
      int64_t id,
      torch::Tensor&& payload)
      : from_(from), type_(type), id_(id), payload_(std::move(payload)) {}

  // Sender worker. Non-owning reference; must outlive this RecvWork.
  const WorkerInfo& from_;
  const MessageType type_;
  const int64_t id_;
  torch::Tensor payload_;
};
// RpcAgent implementation backed by a c10d::ProcessGroup. Outbound (SendWork)
// and inbound (RecvWork) messages are queued and processed by a shared
// ThreadPool; see Note [Termination Detection] below for how shutdown is
// coordinated across ranks.
class TORCH_API ProcessGroupAgent : public RpcAgent {
 public:
  ProcessGroupAgent(
      std::string workerName,
      std::shared_ptr<c10d::ProcessGroup> pg,
      int numSendRecvThreads,
      std::chrono::milliseconds rpcTimeout,
      std::unique_ptr<RequestCallback> cb);

  // WorkerInfo lookup by name or by rank id (RpcAgent overrides).
  const WorkerInfo& getWorkerInfo(const std::string& workerName) const override;
  const WorkerInfo& getWorkerInfo(worker_id_t id) const override;
  std::vector<WorkerInfo> getWorkerInfos() const override;

  void join() override;
  void sync() override;
  void startImpl() override;
  void shutdownImpl() override;

  ~ProcessGroupAgent() override;

  std::unordered_map<std::string, std::string> getMetrics() override;

 protected:
  // This method wraps the destination information and the message into a
  // SendWork object, and put the SendWork into a queue. Another thread will
  // consume SendWork from the queue and send it out.
  std::shared_ptr<FutureMessage> send(
      const WorkerInfo& to,
      Message&& message,
      const float rpcTimeoutSeconds = kUnsetRpcTimeout) override;

  // put SendWork into a queue and notify the worker thread
  virtual void enqueueSend(SendWork work);

  // Bypass handleSend() logic and send a message to self rank
  virtual void sendToSelf(Message&& message);

 private:
  // Per-destination message counter with one slot per rank; increments and
  // snapshots are serialized by mutex_.
  class MessageCounter {
   public:
    explicit MessageCounter(int worldSize);
    void increment(int dst);
    // Returns a consistent copy of all counters, taken under the lock.
    std::vector<int64_t> snapshot();

   private:
    std::vector<int64_t> counters_;
    std::mutex mutex_;
  };

  // TODO: this class should inherit from a MetricsTracker, and can be extended
  // to track num_sends, recvs, average size of messages, etc.
  struct AverageMetricsTracker {
    std::string key_;
    uint64_t currentSum_;
    uint64_t currentCount_;

    explicit AverageMetricsTracker(
        std::string key,
        uint64_t currentSum = 0,
        uint64_t currentCount = 0);

    void addData(uint64_t dataPoint);
    double computeAverage();
  };

  // The FutureInfo struct stores a shared_ptr to the future, as well as
  // additional information to manage timeouts and destination information,
  // which is needed for termination detection.
  struct FutureInfo {
    std::shared_ptr<FutureMessage> future_;
    // Absolute deadline for this RPC; see futureTimeouts_ below.
    steady_clock_time_point endTime_;
    int dstRank_;
    std::chrono::milliseconds timeout_;

    FutureInfo(
        const std::shared_ptr<FutureMessage>& future,
        const steady_clock_time_point& endTime,
        int dstRank,
        const std::chrono::milliseconds timeout)
        : future_(future),
          endTime_(endTime),
          dstRank_(dstRank),
          timeout_(timeout) {}

    FutureInfo() = delete;
  };

  // Gather worker names from all ranks to build the name <-> rank mappings
  // (presumably via the process group — confirm against the .cpp).
  void collectNames();

  // handle a SendWork request. This serializes the payload inside the work
  // object, and sends the message to the receiver using the underlying
  // ProcessGroup.
  void handleSend(const SendWork& work);

  // put RecvWork into a queue and notify the worker thread
  void enqueueRecv(RecvWork work);

  // handle a RecvWork request. Return true if we should increment recvCounts,
  // false if not (i.e. if the RPC timed out and we are getting a result after
  // the timeout). This ensures that the messages accounted for in
  // hasPendingMessage() are tallied properly during a graceful shutdown.
  bool handleRecv(RecvWork& work);

  // Loop that receives and processes messages
  void listenLoopInternal();

  // Calls listenLoopInternal and handles errors such as timeouts on the
  // process group.
  void listenLoop();

  // exception_ptr corresponding to an exception raised in listenLoop (if
  // there is one), and lock to guard access.
  std::exception_ptr listenLoopException_;
  std::mutex listenLoopExceptionMutex_;

  // poll for timed out RPCs
  void pollTimedOutRPCs();

  // process timed out futures
  const std::vector<FutureInfo> processTimedOutFutures();

  // compute the remaining time for an RPC, given its end time.
  const std::chrono::milliseconds getRPCRemainingTime(
      const std::chrono::milliseconds& rpcEndTime) const;

  // a helper function to mark a future in the futures_ map with a message. The
  // future is marked with the passed in message, and then removed from the
  // futures_ map. It is also removed from the futureTimeouts_ map since these
  // maps are kept in sync.
  void markFutureWithError(Message& message);
  void markFutureWithError(int64_t id, std::string errorMsg);

  // Note [Termination Detection]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  //
  // RpcAgent implementations must properly detect termination. Otherwise, it
  // would result in message loss, RRef leak, or process hang. It is not
  // sufficient to just wait for the thread pool to finish processing all tasks
  // after all processes hit the join function. There could be nested rpc/remote
  // calls, meaning that an empty task queue in the thread pool does not mean
  // there will be no tasks added in the future. Moreover, in the listenLoop,
  // there is a period of time when the message has been received but not yet
  // inserted into the thread pool, which also suggests that the empty task
  // queue is not a good indicator for termination.
  //
  // To detect termination, each ProcessGroupAgent maintains a sent message
  // counter and a received message counter. The sent message counter is
  // incremented whenever a message is sent, and the receive message counter is
  // only incremented when a message has been processed. During termination, all
  // ProcessGroupAgent instances run an allgather to collect counters from all
  // peers, which means that all agents will have a consistent view on the
  // message count snapshot. They would only terminate if all sent/received
  // message counters match.
  bool hasPendingMessage();

  // Generate a process-local, monotonically increasing request id.
  int64_t nextId() {
    return ++nextId_;
  }

  std::shared_ptr<c10d::ProcessGroup> pg_;

  // worker name -> rank
  std::unordered_map<std::string, worker_id_t> nameMap_;
  std::vector<WorkerInfo> allWorkerInfo_;

  // record the number of messages sent to and received from each peer. The recv
  // counter is only marked after the message is processed. Join uses allgather
  // to collect all counts from all peers, uses these counters to detect global
  // termination and only exit when all sent messages are processed.
  MessageCounter sendCounts_;
  MessageCounter recvCounts_;

  std::atomic<int64_t> nextId_;

  // one mutex per ProcessGroup rank, as ProcessGroup::send is not thread-safe
  // when using the same tag.
  std::vector<std::mutex> sendMutexes_;

  std::thread listenerThread_;

  // A thread to poll existing futures and check for timed out ones.
  std::thread futureTimeoutThread_;

  // Lock and shared ptr to currently pending work, set in listenLoop() and
  // interruptible in shutdown().
  std::mutex recvWorkMutex_;
  c10::intrusive_ptr<c10d::ProcessGroup::Work> recvWork_;

  // Map of dst rank to current outstanding sends that we are waiting on. In the
  // case of a call to ::shutdown() while we are still waiting on these sends,
  // the pending sends contained in this map will be aborted, allowing the
  // waiting thread to be unblocked.
  std::unordered_map<
      worker_id_t,
      std::set<c10::intrusive_ptr<c10d::ProcessGroup::Work>>>
      currentPendingSends_;

  // Lock to serialize access to the above map.
  std::mutex pendingSendMutex_;

  // A thread pool that processes both SendWork and RecvWork. There are two
  // motivations for adding a ThreadPool:
  // (1) RPC serialization/deserialization and processing can be expensive,
  // hence using multiple threads to speed it up.
  // (2) The current RPC API does not support asynchronous UDFs, e.g., UDFs can
  // not yield in the middle of execution to wait for IO, and resume when the IO
  // is done. This would result in deadlocks when we have nested RPC calls.
  // NB: Ideally, this should be addressed by supporting asynchronous UDF.
  // This is just a temporary solution for (2).
  ThreadPool threadPool_;

  // Atomic to indicate whether the timeout thread is enabled.
  std::atomic<bool> timeoutThreadEnabled_;

  // Mapping of request id to FutureInfo struct.
  std::unordered_map<int64_t, FutureInfo> futures_;

  // A map to keep track of when futures time out. The map is keyed by the time
  // (millisecond level precision) the future will expire. This is so that timed
  // out futures can be efficiently cleaned up, and we can quickly exit if we
  // find a future that has not timed out. The values correspond to an
  // unordered_set of future ids that started at that time. This map must be
  // kept in sync with the above futures_ map.
  std::map<steady_clock_time_point, std::unordered_set<int64_t>>
      futureTimeouts_;

  mutable std::mutex futureMutex_;
  mutable std::condition_variable futureCV_;

  // CV to wake up watchdog thread that watches for timed out futures.
  std::condition_variable futureTimeoutCV_;

  // Metrics tracked for ProcessGroupAgent.
  enum ProcessGroupAgentMetrics {
    GIL_WAIT_TIME = 0,
    N_METRICS,
  };

  std::mutex metricsMutex_;
  std::vector<std::unique_ptr<AverageMetricsTracker>> metrics_;
  void addGilWaitTime(const std::chrono::microseconds gilWaitTime) override;

  std::atomic<int32_t> clientActiveCalls_{0};
  std::atomic<int32_t> serverActiveCalls_{0};
  std::atomic<int32_t> serverActiveAsyncCalls_{0};
};
} // namespace rpc
} // namespace distributed
} // namespace torch