[RFC] Distributed optimizer with TorchScript support #46883
Labels: module: rpc, oncall: jit
Motivation
PyTorch provides a broad set of optimizers for training, and these have been widely used as part of the Python API. However, users often want multithreaded training instead of multiprocess training, since it provides better resource utilization and efficiency in the context of large-scale distributed training (e.g. Distributed Model Parallel) or any RPC-based training application. Users could not do this with the distributed optimizer before, because we need to get rid of the Python Global Interpreter Lock (GIL) limitation to achieve it.
New DistributedOptimizer with TorchScript support
To make the Distributed Optimizer work with TorchScript, we will refactor the existing optimizers to have a functional API, and then let the Distributed Optimizer use the functional API to gain TorchScript support.
Functional Optimizer
We have introduced the functional optimizer concept in torch.optim, which allows the computation and state management to be separated. This makes it easier to make optimizers TorchScript compatible, and unlocks the opportunity for the distributed optimizer to use them in order to be GIL-free.
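As a rough illustration of what separating computation from state means here, below is a minimal sketch of a functional SGD-style update. The class name and signature are assumptions for illustration only, not the actual torch.optim functional API.

```python
import torch
from typing import List


# Minimal sketch of a functional SGD update: the caller owns the parameters
# and gradients, and the optimizer object only holds hyperparameters, which
# keeps the step() computation straightforward to compile with TorchScript.
# The name `_FunctionalSGD` and this signature are illustrative assumptions.
class _FunctionalSGD(object):
    def __init__(self, lr: float = 0.01):
        self.lr = lr

    def step(self, params: List[torch.Tensor], grads: List[torch.Tensor]) -> None:
        with torch.no_grad():
            for i in range(len(params)):
                # Plain SGD update: param <- param - lr * grad
                params[i].add_(grads[i], alpha=-self.lr)
```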
DistributedOptimizer
In the Distributed Optimizer, we maintain a separate set of functional optimizers that consist of state + computation, where the computation part uses the shared functional API we introduced above.
It is OK for the distributed optimizer to stay in Python. In the distributed optimizer we would like to keep the API we expose to users, but use the functional optimizers under the hood. The trick here is that we maintain a map for optim_class like the one below:
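A sketch of such a map, assuming the functional counterparts follow a `_Functional*` naming scheme (the exact names and module locations are assumptions):

```python
from torch import optim

# Hypothetical mapping from the OSS optimizer class a user passes in to its
# TorchScript-compatible functional counterpart; names are illustrative.
functional_optim_map = {
    optim.SGD: _FunctionalSGD,
    optim.Adagrad: _FunctionalAdagrad,
    optim.Adam: _FunctionalAdam,
}
```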
In DistributedOptimizer initialization, we will simply swap the OSS optimizer class for the functional optimizer it maps to, and use that to initialize the _LocalOptimizer on each worker (and compile it). A rough change looks like the one below:
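A minimal sketch of what that swap could look like during construction; the helper `_new_local_optimizer` and the exact constructor arguments here are assumptions for illustration, not the final implementation.

```python
from collections import defaultdict

import torch.distributed.rpc as rpc


class DistributedOptimizer:
    def __init__(self, optimizer_class, params_rref, *args, **kwargs):
        # Swap the OSS optimizer class for its functional counterpart when a
        # match is registered; otherwise fall back to the Python optimizer.
        optim_ctor = functional_optim_map.get(optimizer_class, optimizer_class)

        # Group parameter RRefs by the worker that owns them.
        per_worker_params = defaultdict(list)
        for param in params_rref:
            per_worker_params[param.owner()].append(param)

        # On each worker, create a local optimizer that wraps the (scripted)
        # functional optimizer, so step() can run without holding the GIL.
        # `_new_local_optimizer` is a hypothetical helper for this sketch.
        self.remote_optimizers = [
            rpc.remote(worker, _new_local_optimizer,
                       args=(optim_ctor, param_rrefs) + args, kwargs=kwargs)
            for worker, param_rrefs in per_worker_params.items()
        ]
```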
Note that we will need to refactor all optimizers in torch.optim to follow the functional API, and then register the functional optimizers one by one to make all of them available in the Distributed Optimizer.
Usage
The new distributed optimizer exposes exactly the same interface (APIs) as before, but does all the heavy lifting under the hood: upon DistributedOptimizer construction, it tries to find the corresponding functional optimizer match, constructs the local optimizer on each worker, and automatically turns those optimizers into TorchScript to make them GIL-free. Example usage is exactly the same as the Python API:
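A sketch of the familiar usage pattern, mirroring the existing DistributedOptimizer example from the PyTorch RPC docs (worker names and values are placeholders):

```python
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer

with dist_autograd.context() as context_id:
    # Forward pass on a remote worker.
    rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
    rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
    loss = rref1.to_here() + rref2.to_here()

    # Distributed backward pass.
    dist_autograd.backward(context_id, [loss.sum()])

    # Same construction API as before; the functional/TorchScript swap
    # happens transparently under the hood.
    dist_optim = DistributedOptimizer(
        optim.SGD,
        [rref1, rref2],
        lr=0.05,
    )
    dist_optim.step(context_id)
```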
cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @rohan-varma @jjlilley @osalpekar @jiayisuse @gmagogsfm @xush6528 @agolynski