# Zero Redundancy Optimizer

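> A ZeRO-style optimizer wrapper: greedily partitions a model's parameters across `world_size` ranks so that each rank can keep optimizer state for only its own shard.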

In [7]:
#| default_exp optim.zero

In [8]:
#| hide
from nbdev.showdoc import *

In [9]:
#| hide
# Regenerate the library module(s) from this notebook's `#| export` cells (nbdev).
import nbdev; nbdev.nbdev_export()

In [15]:
#| export
import torch
from torch import nn
from torch.optim import Optimizer, SGD

In [25]:
#| export
class ZeRO(Optimizer):
    """Zero Redundancy Optimizer (ZeRO) wrapper that shards parameters across ranks.

    Parameters are registered with the base ``torch.optim.Optimizer`` as usual.
    :meth:`param_to_ranks` then splits them into ``world_size`` shards of roughly
    equal total element count. Building the per-rank local optimizer from
    ``optim`` is not implemented yet (see :meth:`_init_local_optimizer`).

    Args:
        params: iterable of tensors or of param-group dicts, as accepted by
            ``torch.optim.Optimizer``.
        optim: optimizer *class* intended for each rank's local shard; it is
            stored, not instantiated.
        world_size: number of ranks to partition the parameters across (>= 1).
        defaults: default hyper-parameters forwarded to ``Optimizer.__init__``.
            ``None`` (the default) means a fresh empty dict.

    Raises:
        ValueError: if ``world_size`` is less than 1.
    """

    def __init__(
        self,
        params,
        optim: "type[Optimizer]" = SGD,
        world_size: int = 1,
        defaults: "dict | None" = None,
    ):
        # Fail fast: with zero ranks the partitioner would later die on an
        # opaque `min()` of an empty list.
        if world_size < 1:
            raise ValueError(f"world_size must be >= 1, got {world_size}")
        # A fresh dict per instance; a shared mutable default would alias
        # `self.defaults` across every ZeRO instance.
        super().__init__(params, defaults={} if defaults is None else defaults)
        self.optim = optim
        self.world_size = world_size

        self._init_local_optimizer()

    def _init_local_optimizer(self):
        """Build the optimizer for this rank's shard (not implemented yet)."""
        # TODO: instantiate `self.optim` over this rank's partition.
        pass

    def param_to_ranks(self):
        """Partition the parameters of *all* param groups across ranks.

        Returns:
            A list of ``world_size`` lists; entry ``r`` holds the parameters
            assigned to rank ``r``.
        """
        # Gather every group first so no group is dropped and the element
        # counts are balanced across groups, not just within one.
        all_params = [p for group in self.param_groups for p in group['params']]
        return self._partition_parameters_per_rank(all_params)

    def _partition_parameters_per_rank(self, param_list):
        """Greedily assign parameters to ranks, balancing total element count.

        Parameters are visited largest-first (by ``numel()``; the sort is
        stable, so equal sizes keep their input order). Each parameter goes to
        the rank with the smallest running total. Ties go to the lowest rank.

        Args:
            param_list: iterable of tensors.

        Returns:
            A list of ``self.world_size`` lists of tensors.
        """
        numel_per_rank = [0] * self.world_size
        params_per_rank = [[] for _ in range(self.world_size)]
        for param in sorted(param_list, key=lambda p: p.numel(), reverse=True):
            # `index(min(...))` picks the first (lowest) rank among ties.
            rank = numel_per_rank.index(min(numel_per_rank))
            numel_per_rank[rank] += param.numel()
            params_per_rank[rank].append(param)
        return params_per_rank

    # Backward-compatible alias for the original (misspelled) method name.
    _partrition_paramaters_per_rank = _partition_parameters_per_rank

In [26]:
# Three stacked 10 -> 10 linear layers, built left to right.
model = nn.Sequential(*[nn.Linear(10, 10) for _ in range(3)])

In [27]:
# Wrap the model's parameters in a two-rank ZeRO optimizer backed by SGD.
zero_optimizer = ZeRO(params=model.parameters(), optim=SGD, world_size=2)

In [30]:
# Display how the parameters are split across the ranks (world_size=2 above).
zero_optimizer.param_to_ranks()

[[Parameter containing:
  tensor([[ 2.3553e-01, -3.1116e-01, -2.9379e-01, -2.5907e-01, -7.1481e-02,
           -2.6121e-01,  1.5681e-01, -2.5376e-01,  1.7853e-01,  3.7766e-02],
          [ 2.0894e-01, -2.3549e-01, -1.9249e-01,  2.4569e-01, -1.1339e-01,
            5.3849e-02,  1.3157e-01, -1.0871e-01, -5.2743e-02,  1.1673e-01],
          [ 2.8707e-01,  7.6804e-02,  2.2762e-02,  1.3525e-01,  6.1164e-02,
            1.1460e-01,  1.9687e-01,  3.0281e-01,  1.2798e-01,  1.5961e-01],
          [ 3.0327e-01, -4.9095e-02,  2.5742e-01,  6.4522e-02, -1.4047e-01,
           -1.3231e-01, -1.4638e-01,  1.5712e-01,  1.2123e-01,  1.5444e-01],
          [ 2.2958e-01, -9.4924e-02, -1.0883e-01, -1.7234e-01,  3.0239e-01,
            3.0569e-01,  1.0597e-03,  1.9708e-01,  2.1390e-01,  1.8922e-01],
          [-1.2181e-01,  1.0332e-02,  1.2949e-01,  2.9991e-01, -1.4892e-01,
            5.1659e-02,  2.2322e-01, -7.5470e-02,  1.3945e-01, -3.4895e-03],
          [-2.8630e-01, -1.1013e-01, -1.0076e-02,  2.3633e