[RFC] Support per-RPC timeouts in RPC layer. #32686
Labels
high priority
module: rpc
Related to RPC, distributed autograd, RRef, and distributed optimizer
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
🚀 Feature: Implement per-RPC timeouts.
This RFC is similar to the proposal by @xush6528 in #29402. We'd like to get rid of the overall global timeout that currently applies to all RPCs in PyTorch's RPC layer, and replace it with a per-RPC timeout that users can configure individually. This will improve the reliability of our RPC framework, avoid spurious errors being reported back to users due to misconfigured timeouts, and increase user-friendliness by allowing users to tune the RPC framework to the needs of their application.
Motivation
RPC timeouts are currently implemented in our RPC layer, via the PRs #28392 and #29601. Previous issues such as #29018 and #29402 have discussed ideas for extending these timeouts to a per-RPC basis.
This would be useful since it gives users more granular control over their RPCs and the timeouts set for them. For example, if a user is creating an `RRef` to a module on another node and running an expensive forward pass or data-processing operation on that node, they may want to specify a different timeout than when running a simple user-defined function on a different node. As users build more customized applications on top of our RPC/model-parallel primitives, having the same RPC timeout for every RPC call will not work. Supporting per-RPC timeouts will also help internal RPC messages, since they won't have to be bound to a timeout that can be set from user land. For example, internal messages such as those associated with the shutdown procedure are currently bound to the user-set timeout.
Pitch
The API design is the same as the one proposed by @xush6528 in #29402. We will expose an optional timeout parameter in all RPC APIs.
We can then propagate this timeout to the C++ layer, where it can be passed into `RpcAgent::send()`, which creates the `torch::utils::Future` corresponding to the RPC. We can then use the existing scaffolding (such as what is implemented in #29601) to associate the future with a timeout, and mark the future completed with an exception if it does time out. By default, if a timeout is not passed in to `RpcAgent::send()`, we will assume that no timeout is intended, so the future will never be marked as timed out. We can use this default for internal messages, and look into whether we want timeouts for internal messages as well; this will require coordination with the retry work going on in https://github.com/pytorch/pytorch/pull/32602/files.

We can also get rid of the existing API `rpc.set_rpc_timeout(timedelta)` that sets the timeout for all RPCs. EDIT: we probably have to keep that API for backward-compatibility reasons.
cc @ezyang @gchanan @zou3519 @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @rohan-varma @xush6528 @jjlilley @osalpekar