Description
Apart from the deadlock issue mentioned in #26362, there could be two other reasons that cause join() to occasionally throw timeout errors.
- There is a window of time during which the sender has finished sending and the receiver has finished receiving, but the received message is still inside listenLoop() and has not yet been enqueued into the ThreadPool. This means the following sync()/join() code could exit before all communication has finished (see the toy sketch below):
pytorch/torch/csrc/distributed/rpc/process_group_agent.cpp
Lines 162 to 173 in 7e95439
void ProcessGroupAgent::sync() {
  // Block until all processes wants to sync. This is necessary before acquiring
  // the lock below, because other processes might not enter sync() until it
  // gets some response from this RpcAgent.
  pg_->barrier()->wait();
  // Wait until the all send works are done.
  // NB: There might be additional send works inserted while waiting.
  threadPool_.waitWorkComplete();
  // Use another barrier in case different RpcAgent handles different amounts of
  // workloads.
  pg_->barrier()->wait();
}
As a result, the SHUTDOWN message could be sent before regular messages, leading to the timeout.
- #24074 (Fix termination crash when having unresolved future) and #25499 (Add Python RRef as args and return value) attempt to address the above problem by waiting for all futures to settle on each individual worker. However, this might not be sufficient. With many layers of nested rpc_async/remote calls, received messages can trigger further sends, so even if a worker sees no send tasks in the thread pool and no unsettled futures at one moment, messages it receives later may still create new sends. For example, a worker could pass its local check and then receive a nested call whose handler issues a fresh rpc_async to a third worker.
These might explain the flakiness @rohan-varma saw earlier today in #26570.
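To make the first window concrete, here is a small self-contained toy (plain standard C++; ToyPool, the lambda task, and the sleeping thread are made-up stand-ins for ThreadPool, the RPC handler, and listenLoop()'s handoff, not the real agent code). The waitWorkComplete() analogue can only observe tasks that have already been enqueued, so it returns while the "received" message is still on its way into the pool:

#include <chrono>
#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>

// Stand-in for ThreadPool: tracks only work that has already been handed to run().
struct ToyPool {
  std::mutex m;
  std::condition_variable cv;
  int inflight = 0;

  void run(std::function<void()> task) {
    {
      std::lock_guard<std::mutex> g(m);
      ++inflight;
    }
    task();  // a real pool would hand this to a worker thread
    {
      std::lock_guard<std::mutex> g(m);
      --inflight;
    }
    cv.notify_all();
  }

  // Analogue of ThreadPool::waitWorkComplete(): blind to work not yet enqueued.
  void waitWorkComplete() {
    std::unique_lock<std::mutex> lk(m);
    cv.wait(lk, [this] { return inflight == 0; });
  }
};

int main() {
  ToyPool pool;

  // Stand-in for listenLoop(): the message has already been received from the
  // wire, but there is a window before it is handed to the pool.
  std::thread listener([&pool] {
    std::this_thread::sleep_for(std::chrono::milliseconds(50));  // the window
    pool.run([] { std::cout << "processing the received message\n"; });
  });

  // Stand-in for sync(): returns immediately because nothing is enqueued yet,
  // so the agent could go on to send SHUTDOWN before the message is processed.
  pool.waitWorkComplete();
  std::cout << "sync-like wait returned first\n";

  listener.join();
}

In the real agent the window is of course much shorter than 50ms, but under load it appears to be long enough for sync()/join() to slip through.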
Proposed Solution
How about using ProcessGroup::allreduce(MAX) to globally check whether any worker still has unfinished futures, and using a while loop to make sure the SHUTDOWN message is only sent after that check passes? Something like:
while (true) {
  // Snapshot the number of unsettled futures on this worker under the lock.
  std::unique_lock<std::mutex> lock(futureMutex_);
  std::vector<torch::Tensor> ts = {
      torch::tensor({static_cast<int64_t>(futures_.size())})};
  lock.unlock();
  // Globally reduce with MAX; zero means no worker has unsettled futures.
  c10d::AllreduceOptions opts;
  opts.reduceOp = c10d::ReduceOp::MAX;
  pg_->allreduce(ts, opts)->wait();
  if (ts[0].item<int64_t>() == 0) {
    break;
  }
  // Otherwise wait for the local futures to settle before checking again.
  lock.lock();
  futureCV_.wait(lock, [this] { return futures_.empty(); });
}
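Assuming every worker runs the same loop, the allreduced maximum is identical everywhere, so all workers leave the loop in the same iteration and the collective calls stay matched; a worker whose own futures_ is already empty simply proceeds to the next allreduce and blocks there while the busy workers drain. New futures created by messages that arrive between rounds get picked up by the next allreduce, which is what the per-worker checks in #24074/#25499 cannot guarantee on their own.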
However, future p2p RpcAgent implementations will not have the luxury of using collective communication APIs. We might eventually need to build a more sophisticated termination detection algorithm on top of RpcAgent.
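One way to keep the shutdown loop above reusable would be to hide the global check behind a small hook that a collective-capable agent backs with allreduce(MAX) and a future p2p agent backs with whatever termination detection protocol it ends up using. This is purely a hypothetical sketch; none of these names exist in the codebase:

#include <cstdint>

// Hypothetical hook (not an existing class): answers "what is the largest
// number of unsettled futures on any worker right now?"
class QuiescenceDetector {
 public:
  virtual ~QuiescenceDetector() = default;
  virtual int64_t globalMaxPendingFutures(int64_t localPending) = 0;
};

// The shutdown loop would then depend only on this hook:
//   while (detector.globalMaxPendingFutures(countLocalPendingFutures()) != 0) {
//     waitForLocalFuturesToSettle();  // hypothetical helpers, shown for shape only
//   }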
@xush6528 @satgera @aazzolini @rohan-varma @pietern Does this make sense?
cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini