# Zero Redundancy Optimizer

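> A ZeRO-style optimizer wrapper that partitions model parameters across ranks so that each rank builds a local optimizer only for its own shard.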

In [7]:
#| default_exp optim.zero

In [1]:
#| hide
from nbdev.showdoc import *

In [2]:
#| hide
import nbdev; nbdev.nbdev_export()

In [23]:
#| export
from typing import List, Optional, Type

import torch
from torch import nn
from torch.optim import Optimizer, SGD

from fastgoose.mpu.parallel import ParallelState

In [24]:
#| export
class ZeRO(Optimizer):
    """Optimizer wrapper that shards parameters across ranks (ZeRO-style).

    Every parameter handed to the wrapper is assigned to one of
    ``world_size`` buckets using a greedy size-balancing heuristic. The
    current rank then instantiates ``optim`` over its own bucket only and
    stores it as ``self.local_optim``.

    Args:
        params: Iterable of parameters, or of parameter-group dicts, in the
            form accepted by ``torch.optim.Optimizer``.
        optim: Optimizer *class* (e.g. ``SGD``); it is called as
            ``optim(local_params, **defaults)``.
        parallel_state: Object exposing ``rank`` and ``world_size``.
        defaults: Keyword arguments forwarded to ``optim``. ``None`` (the
            default) means "no extra arguments".

    NOTE(review): only ``defaults`` is forwarded to the local optimizer;
    per-group hyperparameter overrides in ``params`` are not propagated.
    NOTE(review): with fewer parameters than ranks some buckets are empty;
    torch optimizers are expected to reject an empty parameter list, so
    construction on such a rank will presumably raise — confirm the intended
    handling.
    """

    def __init__(
        self,
        params,
        optim: Type[Optimizer],
        parallel_state: "ParallelState",
        defaults: Optional[dict] = None,
    ):
        # A fresh dict per instance: a `dict()` default in the signature would
        # be one object shared by every ZeRO instance.
        defaults = {} if defaults is None else defaults
        super().__init__(params, defaults=defaults)
        self.optim = optim
        self.parallel_state = parallel_state

        self._init_local_optimizer()

    @property
    def rank(self) -> int:
        """Return the rank of the current process."""
        return self.parallel_state.rank

    @property
    def world_size(self) -> int:
        """Return the number of processes participating in the job."""
        return self.parallel_state.world_size

    def _init_local_optimizer(self) -> None:
        """Build ``self.local_optim`` over the parameters owned by this rank."""
        params_of_current_rank = self.param_to_ranks()[self.rank]
        self.local_optim = self.optim(params_of_current_rank, **self.defaults)

    def param_to_ranks(self) -> List[List[torch.Tensor]]:
        """Partition the parameters of *all* param groups across the ranks.

        Returns:
            A list with ``world_size`` entries; entry ``r`` holds the
            parameters assigned to rank ``r``.
        """
        # Flatten every group first; returning from inside the loop over
        # groups would silently drop every group after the first one.
        all_params = [
            param
            for param_group in self.param_groups
            for param in param_group["params"]
        ]
        return self._partition_parameters_per_rank(all_params)

    def _partition_parameters_per_rank(self, param_list) -> List[List[torch.Tensor]]:
        """Greedily assign parameters to ranks, balancing element counts.

        Parameters are visited from largest to smallest (by ``numel()``) and
        each one goes to the rank that currently holds the fewest elements;
        ties are resolved in favour of the lowest rank index.
        """
        numel_per_rank = [0] * self.world_size
        param_per_rank: List[List[torch.Tensor]] = [[] for _ in range(self.world_size)]

        for param in sorted(param_list, key=lambda p: p.numel(), reverse=True):
            # `index(min(...))` returns the first (lowest) least-loaded rank.
            rank = numel_per_rank.index(min(numel_per_rank))
            numel_per_rank[rank] += param.numel()
            param_per_rank[rank].append(param)

        return param_per_rank

    # Backward-compatible alias for the original (misspelled) method name.
    _partrition_paramaters_per_rank = _partition_parameters_per_rank

In [25]:
# Toy model for the demo: three stacked 10 -> 10 linear layers.
model = nn.Sequential(*(nn.Linear(10, 10) for _ in range(3)))

In [26]:
# Demo topology: 2 (tensor) x 4 (pipeline) x 2 (data) = 16 processes.
parallel_state = ParallelState(
    world_size=16,
    tensor_parallel_size=2,
    pipeline_parallel_size=4,
    data_parallel_size=2,
    backend="nccl",
    master_addr="localhost",
    master_port=1234,
)

In [27]:
# Wrap plain SGD; `defaults` is passed on as keyword arguments when the
# wrapper builds the optimizer for this rank's parameters.
zero_optimizer = ZeRO(
    model.parameters(),
    SGD,
    parallel_state=parallel_state,
    defaults=dict(lr=1e-3),
)

In [28]:
# Inspect the per-rank parameter buckets; as the last expression in the cell,
# the returned list of lists is what the notebook displays below.
zero_optimizer.param_to_ranks()

[[Parameter containing:
  tensor([[ 0.0130, -0.2725,  0.1036, -0.2668, -0.0091, -0.1629,  0.0786, -0.1901,
           -0.3070, -0.1132],
          [-0.1940, -0.1832,  0.0139, -0.2267,  0.1815,  0.0156,  0.1376, -0.1220,
           -0.1053,  0.0153],
          [-0.0230, -0.2597,  0.1408, -0.0549, -0.1920, -0.0295, -0.2258, -0.2049,
           -0.2305, -0.1301],
          [-0.1884, -0.0528,  0.2176,  0.1626, -0.1197,  0.0237,  0.2957,  0.2946,
            0.0591,  0.0059],
          [-0.0655,  0.2768, -0.2711, -0.0580, -0.0561,  0.1394, -0.2170,  0.0078,
           -0.1938,  0.2141],
          [-0.1597,  0.0672,  0.0383, -0.2570, -0.1070, -0.2008,  0.1370, -0.2971,
           -0.2826,  0.1393],
          [ 0.2414, -0.1806,  0.0412, -0.0528,  0.0788,  0.2607, -0.1962,  0.1306,
           -0.1389, -0.0235],
          [-0.2913, -0.2463, -0.1943,  0.1044, -0.1574,  0.1706, -0.1575,  0.0664,
           -0.2615,  0.1355],
          [ 0.1876,  0.1476,  0.1083, -0.1976, -0.1943, -0.0501, -0.0293