### Elastic Sampler

##### Example 1

In [14]:
data = torch.Tensor([1, 2, 3, 4, 5])

In [58]:
data

tensor([1., 2., 3., 4., 5.])

In [59]:
from torch.utils.data import Sampler

Create a new sampler that returns only the dataset indices with even values.

**Hint**: `range(0, len(self.data), 2)`

In [60]:
class EvenSampler(Sampler):
    def __init__(self, data):
        self.data = data
        
    def __iter__(self):
        return iter([i for i in range(0, len(self.data), 2)])

    def __len__(self):
        return len(self.data)

In [61]:
sampler = EvenSampler(data=data)

In [62]:
samples = list(sampler) 

In [63]:
indices

[0, 2, 4]

##### Example 2

In [19]:
import torch

In [10]:
import math
from torch.utils.data import Sampler

In [16]:
def get_num_replicas():
    return 4

In [17]:
def get_rank():
    1

In [92]:
class ElasticSampler(Sampler):
    def __init__(self, dataset, seed):
        self.dataset = dataset
        self.seed = seed
        
        self.epoch = 0
        
        self.num_replicas = 0
        self.rank = 0
        self.remaining_indicies = []
        self.processed_indicies = set()
        self.num_processed = 0
    
    def set_epoch(self, epoch):
        self.epoch = epoch
        self.num_processed = 0
        self.reset()
        
    def record_batch(self, batch_size):
        self.num_processed += batch_size * self.num_replicas
    
    def reset(self):
        self.num_replicas = get_num_replicas()
        self.rank = get_rank()
        
        all_indicies = [idx for idx in range(len(self.dataset))]
        self.remaining_indicies = all_indicies[self.num_processed:]
        self.num_samples = int(math.ceil(len(self.remaining_indicies)*1.0 / self.num_replicas))
        self.total_size = self.num_replicas * self.num_samples
    
    def __iter__(self):
        self.indicies = self.remaining_indicies[:]
        
        # add extra last sampels to make it evenly divisible
        self.indicies += self.indicies[:(self.total_size - len(self.indicies))]
        assert len(self.indicies) == self.total_size
        
        self.indicies = self.indicies[self.rank:self.total_size:self.num_replicas]
        assert len(self.indicies) == self.num_samples
        
        return iter(self.indicies)

    def __len__(self):
        return self.num_samples

In [93]:
dataset = torch.arange(0, 16)

In [94]:
dataset

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])

In [95]:
sampler = ElasticSampler(dataset, seed=69)

n_epochs * batch_size * num_replicas

In [96]:
n_epochs = 2
n_batches = 4
batch_size = 4
num_replicas = 4

In [97]:
for epoch in range(n_epochs):
    sampler.set_epoch(epoch)
    for batch_idx in range(n_batches):
        indicies = list(iter(sampler))
        print(indicies)
        sampler.record_batch(batch_size)
        print(sampler.num_processed)
        # assert sampler.num_processed == batch_size * (batch_idx + 1)

[0, 4, 8, 12]
16
[0, 4, 8, 12]
32
[0, 4, 8, 12]
48
[0, 4, 8, 12]
64
[0, 4, 8, 12]
16
[0, 4, 8, 12]
32
[0, 4, 8, 12]
48
[0, 4, 8, 12]
64


In [98]:
batch_size * (batch_idx + 1)

16