In [1]:
from collections import OrderedDict
from typing import Dict, List
import torch

In [2]:
class ZeRO(torch.optim.Optimizer):
    """Work-in-progress sketch of a ZeRO-style optimizer that shards
    parameters across the ranks of a ``torch.distributed`` process group.

    So far only :meth:`partition_parameters` does real work;
    :meth:`setup_local_optimizer` is still a no-op placeholder.
    """

    def __init__(
        self,
        params,
        optim,
        group,
        defaults
    ):
        """
        Args:
            params: parameters (or param-group dicts) forwarded unchanged to
                ``torch.optim.Optimizer.__init__``.
            optim: optimizer intended to run on this rank's shard. Stored on
                ``self.optim`` but not used yet (TODO: decide whether this is
                an optimizer class or an instance).
            group: process group passed to
                ``torch.distributed.get_world_size``; its size is the number
                of shards. Stored on ``self.process_group``.
            defaults: default hyperparameters forwarded to
                ``torch.optim.Optimizer.__init__``.
        """
        super().__init__(params, defaults)
        # Keep `optim` and `group` around for the local-optimizer setup
        # instead of silently discarding them.
        self.optim = optim
        self.process_group = group
        self.world_size = torch.distributed.get_world_size(group=group)
        # TODO: per-device bookkeeping, e.g.
        # Dict[torch.device, List[List[torch.nn.Parameter]]], is planned.
        self.setup_local_optimizer()

    def setup_local_optimizer(self):
        """Create the optimizer for this rank's parameter shard.

        TODO: not implemented -- currently does nothing.
        """

    def partition_parameters(self, params):
        """Split ``params`` into ``self.world_size`` shards of similar size.

        Parameters are visited in order and each one is assigned to the shard
        with the fewest elements so far (ties go to the lowest rank), which
        keeps the shards roughly balanced by ``numel()``.

        Args:
            params: iterable of tensors (e.g. ``model.parameters()``).
                Assumes plain tensors, not param-group dicts -- TODO confirm.

        Returns:
            A list of ``self.world_size`` lists; entry ``r`` holds the
            parameters assigned to rank ``r``. Also stored on
            ``self._partition_parameters``.
        """
        shards = [[] for _ in range(self.world_size)]
        shard_numels = [0] * self.world_size
        for param in params:
            # Greedy: give the next tensor to the currently lightest shard.
            rank = shard_numels.index(min(shard_numels))
            shards[rank].append(param)
            shard_numels[rank] += param.numel()
        self._partition_parameters = shards
        # Keep the old misspelled attribute name working for existing callers.
        self._partrition_parameters = shards
        return shards

##### Example 1: what a regular Adam optimizer stores in `param_groups` and `state` on a toy model

In [4]:
from torch import nn
from torch.optim import Adam

In [5]:
# Seed torch's global RNG so the initial weights (and the random batches drawn
# in the training cell) are reproducible on Restart & Run All. The recorded
# outputs in this notebook were produced without a seed, so their exact values
# will differ from a fresh run.
SEED = 0
torch.manual_seed(SEED)

# Tiny 2 -> 4 -> 2 MLP used only to populate an optimizer's state.
model = nn.Sequential(
    nn.Linear(2, 4),
    nn.ReLU(),
    nn.Linear(4, 2)
)

In [6]:
# Plain Adam with PyTorch's default hyperparameters (lr=0.001,
# betas=(0.9, 0.999), eps=1e-08), as echoed by `optim.param_groups` below.
optim = Adam(params=model.parameters())

In [7]:
# Show the initial weight and bias tensors of both Linear layers, one per entry.
print(*model.parameters(), sep="\n")

Parameter containing:
tensor([[-0.3173, -0.1710],
        [-0.2532, -0.2532],
        [ 0.3124,  0.4826],
        [ 0.3922,  0.1965]], requires_grad=True)
Parameter containing:
tensor([ 0.3268, -0.2788,  0.6120,  0.1722], requires_grad=True)
Parameter containing:
tensor([[ 0.4952, -0.0644, -0.3629,  0.4602],
        [-0.2251,  0.3302, -0.0634,  0.2421]], requires_grad=True)
Parameter containing:
tensor([0.3481, 0.3911], requires_grad=True)


In [8]:
# Take a few Adam steps on random data so the following cells have a
# populated `param_groups` / `state` to inspect. The step count is explicit
# rather than tied to the number of parameter tensors in the model.
NUM_STEPS = 4    # -> each parameter's Adam state ends with 'step' == 4
BATCH_SIZE = 5
for step_idx in range(NUM_STEPS):
    optim.zero_grad()
    # Input width taken from the first Linear layer (in_features == 2).
    output = model(torch.randn(BATCH_SIZE, model[0].in_features))
    loss = output.mean()
    loss.backward()
    optim.step()

In [9]:
# Each param group: the Parameter tensors it owns plus its hyperparameters
# (lr, betas, eps, weight_decay, ...).
optim.param_groups

[{'params': [Parameter containing:
   tensor([[-0.3156, -0.1681],
           [-0.2498, -0.2506],
           [ 0.3151,  0.4834],
           [ 0.3889,  0.1958]], requires_grad=True),
   Parameter containing:
   tensor([ 0.3229, -0.2820,  0.6159,  0.1683], requires_grad=True),
   Parameter containing:
   tensor([[ 0.4915, -0.0673, -0.3669,  0.4565],
           [-0.2288,  0.3273, -0.0673,  0.2383]], requires_grad=True),
   Parameter containing:
   tensor([0.3441, 0.3871], requires_grad=True)],
  'lr': 0.001,
  'betas': (0.9, 0.999),
  'eps': 1e-08,
  'weight_decay': 0,
  'amsgrad': False,
  'maximize': False,
  'foreach': None,
  'capturable': False,
  'differentiable': False,
  'fused': None}]

In [10]:
# Per-parameter Adam state keyed by the Parameter object: step count,
# exp_avg (first moment) and exp_avg_sq (second moment).
optim.state

defaultdict(dict,
            {Parameter containing:
             tensor([[-0.3156, -0.1681],
                     [-0.2498, -0.2506],
                     [ 0.3151,  0.4834],
                     [ 0.3889,  0.1958]], requires_grad=True): {'step': tensor(4.),
              'exp_avg': tensor([[-0.0124,  0.0008],
                      [-0.0123, -0.0048],
                      [-0.0139, -0.0166],
                      [ 0.0411,  0.0291]]),
              'exp_avg_sq': tensor([[1.5594e-05, 1.5520e-05],
                      [6.7088e-06, 3.2181e-06],
                      [5.1182e-05, 4.6074e-05],
                      [1.7123e-04, 1.0390e-04]])},
             Parameter containing:
             tensor([ 0.3229, -0.2820,  0.6159,  0.1683], requires_grad=True): {'step': tensor(4.),
              'exp_avg': tensor([ 0.0392,  0.0089, -0.0633,  0.0797]),
              'exp_avg_sq': tensor([5.3506e-05, 4.1894e-06, 1.3817e-04, 2.1951e-04])},
             Parameter containing:
             tensor([[