
ProcessGroupAgent failed to join with long async forward chain #26944

@mrshenli

Description


Apart from the deadlock issue mentioned in #26362, there could be two other reasons that cause join() to occasionally throw timeout errors.

  1. There is a window during which the sender has finished sending and the receiver has finished receiving, but the received message is still in listenLoop() and has not yet been enqueued into the ThreadPool. This means the following sync/join code could exit before all communication has finished.

void ProcessGroupAgent::sync() {
  // Block until all processes want to sync. This is necessary before acquiring
  // the lock below, because other processes might not enter sync() until they
  // get some response from this RpcAgent.
  pg_->barrier()->wait();
  // Wait until all the send works are done.
  // NB: There might be additional send works inserted while waiting.
  threadPool_.waitWorkComplete();
  // Use another barrier in case different RpcAgents handle different amounts of
  // workloads.
  pg_->barrier()->wait();
}

As a result, the SHUTDOWN message could be sent while regular messages are still outstanding, leading to the timeout.

  2. #24074 (Fix termination crash when having unresolved future) and #25499 (Add Python RRef as args and return value) attempt to address the above problem by waiting for all futures to settle on each individual worker. However, this might not be sufficient. When we have many layers of nested rpc_async/remote calls, received messages can trigger more sends. So even if a worker sees no send tasks in the thread pool and no unsettled futures at a given moment, messages it receives later can still create new sends (see the toy sketch below).
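To make the second point concrete, here is a standalone toy sketch (plain C++, not PyTorch code; ToyWorker and everything in it is made up) of how a worker that passes the local "no pending sends, no unresolved futures" check can still produce brand-new send work when a nested request arrives afterwards:

// Toy model only, not ProcessGroupAgent code: worker B looks idle when the
// local check runs, but a request arriving afterwards creates a new send.
#include <iostream>
#include <queue>
#include <string>

struct ToyWorker {
  std::queue<std::string> outbox;   // pending outbound messages
  int unresolvedFutures = 0;        // futures still waiting for replies

  // The kind of per-worker check that #24074 / #25499 rely on.
  bool looksIdle() const {
    return outbox.empty() && unresolvedFutures == 0;
  }

  // Handling a nested rpc_async/remote style request creates new send work
  // and a new future that the earlier idleness check never saw.
  void onRequest(const std::string& req) {
    outbox.push("nested call triggered by " + req);
    ++unresolvedFutures;
  }
};

int main() {
  ToyWorker b;
  std::cout << "idle before request arrives: " << b.looksIdle() << "\n";  // 1
  b.onRequest("rpc from worker A");  // arrives after the check has passed
  std::cout << "idle after request arrives:  " << b.looksIdle() << "\n";  // 0
}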

This might be one reason for the flakiness @rohan-varma saw earlier today in #26570.

Proposed Solution

How about using ProcessGroup::allreduce with ReduceOp::MAX to globally check whether any worker still has unfinished futures, and using a while loop to make sure the SHUTDOWN message is only sent after that check passes? Something like:

while (true) {
  // Take the MAX of the number of unresolved futures across all workers;
  // SHUTDOWN is only safe once this is zero everywhere.
  auto t = torch::tensor({static_cast<int64_t>(futures_.size())});
  pg_->allreduce(t, ReduceOp::MAX)->wait();
  if (t.item<int64_t>() == 0) {
    break;
  }
  // Some worker still has unresolved futures; wait for the local ones to
  // settle, then re-run the global check.
  std::unique_lock<std::mutex> lock(futureMutex_);
  futureCV_.wait(lock, [this] {
    return futures_.empty();
  });
}

However, future point-to-point RpcAgent implementations will not have the luxury of using collective communication APIs. We might eventually need to build a more complex termination detection algorithm on top of RpcAgent.
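For what it's worth, one classic collective-free direction is a counting-based quiescence check built purely on point-to-point RPCs. The sketch below is only an illustration of that idea under assumed helpers (the poll callback stands in for a hypothetical p2p RPC; nothing here is existing RpcAgent API): each worker counts the requests it has sent and fully processed, and a designated worker keeps polling everyone, declaring termination only when the global totals match and have not changed across two consecutive rounds.

// Illustrative sketch only; the polling RPC is hypothetical and passed in as
// a callback precisely because no such API exists yet.
#include <cstdint>
#include <functional>

struct WorkerStats {
  int64_t sent = 0;       // requests this worker has sent so far
  int64_t processed = 0;  // requests this worker has fully processed so far
};

bool globallyQuiescent(int numWorkers,
                       const std::function<WorkerStats(int)>& poll) {
  int64_t prevSent = -1;
  int64_t prevProcessed = -1;
  while (true) {
    int64_t sent = 0;
    int64_t processed = 0;
    for (int w = 0; w < numWorkers; ++w) {
      WorkerStats s = poll(w);  // hypothetical p2p RPC to worker w
      sent += s.sent;
      processed += s.processed;
    }
    // A single matching snapshot can be fooled by in-flight messages, so also
    // require the totals to be unchanged since the previous polling round.
    if (sent == processed && sent == prevSent && processed == prevProcessed) {
      return true;
    }
    prevSent = sent;
    prevProcessed = processed;
  }
}

Counting "processed" rather than merely "received" is what would cover the listenLoop() window from point 1; SHUTDOWN would go out only after this check returns true.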

@xush6528 @satgera @aazzolini @rohan-varma @pietern Does this make sense?

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini
