### StateHandler

In [2]:
import torch
from torch import nn
import torch.nn.functional as F

In [3]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()

In [25]:
model = Model()

In [26]:
states = {
    "model": model, "epoch": 1
}

In [27]:
states

{'model': Model(), 'epoch': 1}

In [28]:
import copy

Create handlers for the states:
- If an object is in the `handler_registry`, wrap a handler to handle `commit`, `restore`, and `sync`.
- If an object is not in the `handler_registry`, simply return it in `remainders`.

**Hint**: For the sync part, use `broadcast_parameters(state_dict)`, `value.load_state_dict()`

In [29]:
class ModelStateHandler:
    def __init__(self, model):
        self.value = model
        # store this as an initial commit
        self._model_state = copy.deepcopy(self.value.state_dict())

    def commit(self):
        self._model_state = copy.deepcopy(self.value.state_dict())

    def restore(self):
        self.value.load_state_dict(self._model_state)

    def sync(self):
        broadcast_parameters(self.value.state_dict())
    
    def set_value(self, value):
        self.value = value

In [30]:
def get_handler(v):
    for handler_type, handler_cls in handler_registery:
        if isinstance(v, handler_type):
            return handler_cls(v)
    return None

In [31]:
def get_handlers(states):
    handlers = {}
    remainders = {}
    for k, v in states.items():
        handler = get_handler(v)
        if handler == None:
            remainders[k] = v
        else:
            handlers[k] = handler
    return handlers, remainders

In [32]:
handler_registery = [
    (torch.nn.Module, ModelStateHandler),
]

In [33]:
handlers, remainders = get_handlers(states)

In [34]:
handlers, remainders

({'model': <__main__.ModelStateHandler at 0x7fe80926df40>}, {'epoch': 1})

### Elastic Sampler

##### Example 1

In [14]:
data = torch.Tensor([1, 2, 3, 4, 5])

In [58]:
data

tensor([1., 2., 3., 4., 5.])

In [59]:
from torch.utils.data import Sampler

Create a new sampler that returns only the dataset indices with even values.

**Hint**: `range(0, len(self.data), 2)`

In [60]:
class EvenSampler(Sampler):
    def __init__(self, data):
        self.data = data
        
    def __iter__(self):
        return iter([i for i in range(0, len(self.data), 2)])

    def __len__(self):
        return len(self.data)

In [61]:
sampler = EvenSampler(data=data)

In [62]:
samples = list(sampler) 

In [63]:
indices

[0, 2, 4]

##### Example 2

In [19]:
import torch

In [10]:
import math
from torch.utils.data import Sampler

In [16]:
def get_num_replicas():
    return 4

In [17]:
def get_rank():
    1

In [92]:
class ElasticSampler(Sampler):
    def __init__(self, dataset, seed):
        self.dataset = dataset
        self.seed = seed
        
        self.epoch = 0
        
        self.num_replicas = 0
        self.rank = 0
        self.remaining_indicies = []
        self.processed_indicies = set()
        self.num_processed = 0
    
    def set_epoch(self, epoch):
        self.epoch = epoch
        self.num_processed = 0
        self.reset()
        
    def record_batch(self, batch_size):
        self.num_processed += batch_size * self.num_replicas
    
    def reset(self):
        self.num_replicas = get_num_replicas()
        self.rank = get_rank()
        
        all_indicies = [idx for idx in range(len(self.dataset))]
        self.remaining_indicies = all_indicies[self.num_processed:]
        self.num_samples = int(math.ceil(len(self.remaining_indicies)*1.0 / self.num_replicas))
        self.total_size = self.num_replicas * self.num_samples
    
    def __iter__(self):
        self.indicies = self.remaining_indicies[:]
        
        # add extra last sampels to make it evenly divisible
        self.indicies += self.indicies[:(self.total_size - len(self.indicies))]
        assert len(self.indicies) == self.total_size
        
        self.indicies = self.indicies[self.rank:self.total_size:self.num_replicas]
        assert len(self.indicies) == self.num_samples
        
        return iter(self.indicies)

    def __len__(self):
        return self.num_samples

In [93]:
dataset = torch.arange(0, 16)

In [94]:
dataset

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])

In [95]:
sampler = ElasticSampler(dataset, seed=69)

n_epochs * batch_size * num_replicas

In [96]:
n_epochs = 2
n_batches = 4
batch_size = 4
num_replicas = 4

In [97]:
for epoch in range(n_epochs):
    sampler.set_epoch(epoch)
    for batch_idx in range(n_batches):
        indicies = list(iter(sampler))
        print(indicies)
        sampler.record_batch(batch_size)
        print(sampler.num_processed)
        # assert sampler.num_processed == batch_size * (batch_idx + 1)

[0, 4, 8, 12]
16
[0, 4, 8, 12]
32
[0, 4, 8, 12]
48
[0, 4, 8, 12]
64
[0, 4, 8, 12]
16
[0, 4, 8, 12]
32
[0, 4, 8, 12]
48
[0, 4, 8, 12]
64


In [98]:
batch_size * (batch_idx + 1)

16

### TorchState

##### Example 1

In [48]:
class State:
    def __init__(self):
        pass

In [None]:
class TorchState:
    def __init__(self, model, optimizer, **kwargs):
        kwargs.update({"model": model, "optimizer": optimizer})
        self._handlers, kwargs = get_

### `run_func`

##### Example 1

In [37]:
class HostsUpdatedInterrupt(Exception):
    def __init__(self, skip_sync):
        self.skip_sync = skip_sync

Write a custom error with attributes as bellow

In [38]:
try:
    raise HostsUpdatedInterrupt(skip_sync=False)
except HostsUpdatedInterrupt as err:
    print(err.skip_sync)

False


##### Example 2

In [40]:
class HorovodInternalError(Exception): pass
class HostsUpdatedInterrupt(Exception): pass

In [24]:
def train_one_batch(): pass

In [46]:
def _reset():
    shutdown()
    init()

In [47]:
def run_fn(func, reset):
    def wrapper(state, *args, **kwargs):
        notification_manager.init()
        notification_manager.register_listener(state)
        skip_sync = False
        
        try:
            while True:
                try:
                    if not skip_sync:
                        state.sync()
                
                    return func(state, *args, **kwargs)
                except HorovodInternalError:
                    # a worker failed
                    state.restore()
                    skip_sync = False
                except HostsUpdatedInterrupt as e:
                    skip_sync = e.skip_sync
                
                reset()
                state.on_reset()
        finally:
            notification_manager.remove_listener(state)
        return wrapper

In [None]:
skip_