[RFC] Distributed optimizer with TorchScript support #46883

Closed
wanchaol opened this issue Oct 26, 2020 · 0 comments
Labels
module: rpc Related to RPC, distributed autograd, RRef, and distributed optimizer oncall: jit Add this issue/PR to JIT oncall triage queue

wanchaol commented Oct 26, 2020

Motivation

PyTorch provides a broad set of optimizers for training, and they have been widely used through the Python API. However, users often want multithreaded rather than multiprocess training, since it offers better resource utilization and efficiency in large-scale distributed training (e.g. Distributed Model Parallel) or in any RPC-based training application. Users could not do this with the distributed optimizer before, because achieving it requires getting rid of the Python Global Interpreter Lock (GIL) limitation.

New DistributedOptimizer with TorchScript support

To make the Distributed Optimizer work with TorchScript, we will refactor the existing optimizers to expose a functional API, and then have the Distributed Optimizer use that functional API to gain TorchScript support.

Functional Optimizer

We have introduced the functional optimizer concept in torch.optim, which separates the computation from the state management. This makes it easier to make optimizers TorchScript compatible, and unlocks the opportunity for the distributed optimizer to use them in order to be GIL-free.
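To illustrate the idea (this is not the actual torch.optim API, just a sketch of the pattern): a functional optimizer step is a pure function over parameters, gradients, and hyperparameters, with no hidden state, which is exactly what makes it easy to compile. Using plain Python floats instead of tensors:

```python
from typing import List

def functional_sgd_step(params: List[float],
                        grads: List[float],
                        lr: float) -> List[float]:
    # Pure computation: no optimizer object, no hidden state.
    # All state (here just the parameters) is passed in and returned,
    # so the function is trivially scriptable/compilable.
    return [p - lr * g for p, g in zip(params, grads)]
```

Because the step takes everything it needs as arguments, the caller (e.g. a per-worker local optimizer) decides where the state lives.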

DistributedOptimizer

In the Distributed Optimizer, we maintain a separate set of functional optimizers that consist of state + computation, where the computation part uses the shared functional API introduced above.
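A minimal sketch of that state + computation split, again with plain Python floats rather than tensors (the names here, such as FunctionalSGD, are illustrative, not the real torch.distributed.optim classes):

```python
from typing import Dict, List

def sgd_step_with_momentum(param: float, grad: float, buf: float,
                           lr: float, momentum: float):
    # Shared, stateless computation: returns (new_param, new_buffer).
    new_buf = momentum * buf + grad
    return param - lr * new_buf, new_buf

class FunctionalSGD:
    """Holds per-parameter state; delegates the math to the functional step."""
    def __init__(self, params: List[float], lr: float = 0.01,
                 momentum: float = 0.9):
        self.params = params
        self.lr = lr
        self.momentum = momentum
        self.buffers: Dict[int, float] = {}  # optimizer state lives here

    def step(self, grads: List[float]) -> None:
        for i, g in enumerate(grads):
            buf = self.buffers.get(i, 0.0)
            self.params[i], self.buffers[i] = sgd_step_with_momentum(
                self.params[i], g, buf, self.lr, self.momentum)
```

The class owns the state; the free function owns the computation, so only the computation needs to be made TorchScript compatible.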

It’s OK for the distributed optimizer itself to stay in Python. We will keep the API we expose to users unchanged, but use the functional optimizers under the hood. The trick is to maintain a map from each optimizer class to its functional counterpart, like the one below:

{
    torch.optim.Adagrad: torch.distributed.optim.FunctionalAdagrad,
    torch.optim.SGD: torch.distributed.optim.FunctionalSGD,
    ...
}

During DistributedOptimizer initialization, we simply swap the OSS optimizer for its functional counterpart and use that to initialize the _LocalOptimizer on each worker (and compile it). A rough sketch of the change:

class DistributedOptimizer:
    def __init__(self, optimizer_class, params_rref, *args, **kwargs):
        # Look up the functional counterpart of the passed-in optimizer.
        functional_optimizer_class = optim_table.get(optimizer_class, None)
        if functional_optimizer_class is None:
            raise ValueError("Optimizer " + str(optimizer_class) + " not supported")

        # TODO: warn/log when switching from the OSS optimizer to the
        # functional optimizer.
        per_worker_params_rref = defaultdict(list)
        for param in params_rref:
            per_worker_params_rref[param.owner()].append(param)

        remote_optim_futs = []
        for worker, param_rrefs in per_worker_params_rref.items():
            remote_optim_rref_fut = rpc.rpc_async(
                worker,
                _new_local_optimizer,
                args=(functional_optimizer_class, param_rrefs) + args,
                kwargs=kwargs,
            )
            remote_optim_futs.append(remote_optim_rref_fut)

        self.remote_optimizers = _wait_for_all(remote_optim_futs)
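The per-worker bucketing in the constructor above is plain defaultdict grouping. In isolation, with strings standing in for RRef owners (a toy sketch with no RPC involved):

```python
from collections import defaultdict

# Toy stand-ins: each (owner, name) pair plays the role of an RRef
# whose .owner() reports the worker that holds the parameter.
params = [("worker1", "p0"), ("worker1", "p1"), ("worker2", "p2")]

per_worker = defaultdict(list)
for owner, name in params:
    per_worker[owner].append(name)

# per_worker now maps each worker to the parameters it owns, so one
# local optimizer can be created per worker over its own parameters.
```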

Note that we will need to refactor all optimizers in torch.optim to follow the functional API, and then register the functional optimizers one by one to make all of them available in the Distributed Optimizer.

Usage

The new distributed optimizer exposes exactly the same interface (APIs) as before, but does the heavy lifting under the hood: upon DistributedOptimizer construction, it looks up the matching functional optimizer, constructs the local optimizer on each worker, and automatically compiles it to TorchScript to make it GIL-free. Example usage is exactly the same as with the Python API:

import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer

with dist_autograd.context() as context_id:
  # Forward pass.
  rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
  rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
  loss = rref1.to_here() + rref2.to_here()

  # Backward pass.
  dist_autograd.backward(context_id, [loss.sum()])

  # Optimizer, pass in optim.Adagrad, DistributedOptimizer will
  # automatically convert/compile it to TorchScript (GIL-free)
  dist_optim = DistributedOptimizer(
     optim.Adagrad,
     [rref1, rref2],
     lr=0.05,
  )
  dist_optim.step(context_id)

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @rohan-varma @jjlilley @osalpekar @jiayisuse @gmagogsfm @xush6528 @agolynski

@facebook-github-bot facebook-github-bot added oncall: jit Add this issue/PR to JIT oncall triage queue oncall: distributed Add this issue/PR to distributed oncall triage queue labels Oct 26, 2020
@github-actions github-actions bot added this to Need triage in JIT Triage Oct 26, 2020
@wanchaol wanchaol removed this from Need triage in JIT Triage Oct 26, 2020
@rohan-varma rohan-varma added module: rpc Related to RPC, distributed autograd, RRef, and distributed optimizer and removed oncall: distributed Add this issue/PR to distributed oncall triage queue labels Nov 14, 2020
@github-actions github-actions bot added this to Need triage in JIT Triage Nov 14, 2020
@suo suo removed this from Need triage in JIT Triage Nov 17, 2020