# StatefulなRandomSamplerクラス
* Epoch途中の任意のバッチから再開する機能を持つRandomSamplerクラス．
* generatorのstateおよび最後に出力したsample indexを記憶することで実現している．
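* 任意のDatasetおよびDataLoaderと組み合わせて使用できる．ただし，DataLoaderの `num_workers > 0` ではインデックスが先読みされるため，取得した resume_index が実際に処理済みのバッチより先に進んでいる可能性がある点に注意．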

In [1]:
import torch

In [1]:
from typing import Iterator, Iterable, Optional, Sequence, List, TypeVar, Generic, Sized, Union
from torch.utils.data.sampler import RandomSampler

class ContinuableRandomSampler(RandomSampler):
    """RandomSampler that can resume an epoch from the middle.

    Resumption works by remembering two values:

    * the state of the sampler's ``torch.Generator`` captured at the start of
      the current epoch (``get_generator_state()``), and
    * the last sample index handed out (``get_resume_index()``).

    Passing both values back to a fresh instance regenerates the same sample
    sequence and skips everything up to and including ``resume_index``.

    Limitations:

    * ``resume_index`` is located by *value*. This is only unambiguous when
      each index occurs at most once per epoch (``replacement=False`` and
      ``num_samples <= len(data_source)``). Otherwise the first occurrence wins.
    * NOTE(review): ``get_resume_index()`` reports what the sampler has
      yielded, not what the training loop has consumed. With a DataLoader
      using ``num_workers > 0``, indices are prefetched, so the reported index
      is likely ahead of the last processed batch. Confirm before relying on it
      in that setting.

    Args:
        data_source: Dataset to sample from.
        replacement: Forwarded to ``RandomSampler``.
        num_samples: Forwarded to ``RandomSampler``.
        resume_index: Last sample index yielded before the interruption, as
            returned by ``get_resume_index()``. ``None`` starts at the
            beginning of the epoch.
        generator_state: RNG state returned by ``get_generator_state()``.
        seed: Seed for the internal generator. It is used only when
            ``generator_state`` is ``None``. If ``seed`` is also ``None``, a
            seed is drawn from torch's global RNG, so ``torch.manual_seed``
            is respected.
    """

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None,
                 resume_index: Optional[int] = None,
                 generator_state: Optional[torch.Tensor] = None,
                 seed: Optional[int] = None):
        generator = torch.Generator()
        if generator_state is not None:
            generator.set_state(generator_state)
        else:
            # A fresh torch.Generator() always starts from the same fixed
            # default seed, which would make every run shuffle identically.
            # Draw the seed from the global RNG instead, the same way
            # RandomSampler does when it is not given a generator.
            if seed is None:
                seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator.manual_seed(seed)
        super().__init__(data_source, replacement, num_samples, generator)
        self.resume_index = resume_index
        # State before any epoch has run. It is overwritten at the start of
        # every epoch; initialising it here means get_generator_state() works
        # even before iteration starts.
        self.generator_state = generator.get_state()

    def __iter__(self) -> Iterator[int]:
        """Yield sample indices, skipping past ``resume_index`` if one is set.

        Raises:
            ValueError: ``resume_index`` does not occur in this epoch's
                sequence (e.g. ``generator_state`` or the dataset does not
                match the interrupted run). Previously this case silently
                produced an empty epoch.
        """
        # Snapshot the RNG *before* the parent draws its permutation. The
        # parent's __iter__ is a generator, so no random numbers are consumed
        # until the first next() below.
        self.generator_state = self.generator.get_state()
        indices = super().__iter__()

        if self.resume_index is not None:
            # Fast-forward past everything already handed out, including
            # resume_index itself.
            for sample_idx in indices:
                if sample_idx == self.resume_index:
                    break
            else:
                raise ValueError(
                    f"resume_index {self.resume_index} was not found in the "
                    "sampled sequence; check that generator_state and the "
                    "dataset match the interrupted run."
                )

        for sample_idx in indices:
            # Remember the most recently yielded index so the caller can
            # checkpoint it via get_resume_index().
            self.resume_index = sample_idx
            yield sample_idx

        # The epoch is complete; the next epoch must start from its beginning.
        self.resume_index = None

    def get_resume_index(self) -> Optional[int]:
        """Return the last sample index yielded in the current epoch.

        Returns ``None`` after an epoch has completed. Before iteration, it
        returns whatever ``resume_index`` was passed to the constructor.
        """
        return self.resume_index

    def get_generator_state(self) -> torch.Tensor:
        """Return the generator state captured at the start of the current epoch."""
        return self.generator_state

## 動作デモ

* ToyDataset: low以上high未満の整数（`range(low, high)`）を返すDatasetクラス

In [5]:
from torch.utils.data import Dataset

class ToyDataset(Dataset):
    """Toy dataset whose items are the integers in ``range(low, high)``.

    Item ``i`` is ``low + i``; ``high`` itself is excluded.
    """

    def __init__(self, low, high):
        values = range(low, high)
        self.samples = list(values)

    def __getitem__(self, idx):
        samples = self.samples
        return samples[idx]

    def __len__(self):
        return len(self.samples)

### 初回の実行
* ContinuableRandomSampler() をインスタンス化して DataLoader() の sampler に引き渡す

In [6]:
from torch.utils.data import DataLoader

In [7]:
# Build the dataset (integers 10..29), a fresh resumable sampler, and a
# DataLoader that draws batches of 4 indices from that sampler.
dataset = ToyDataset(low=10, high=30)
sampler = ContinuableRandomSampler(data_source=dataset)
data_loader = DataLoader(dataset, sampler=sampler, batch_size=4)

* 仮に batch index=1 まで処理したところで中断し，その続き（batch index=2）から再開したいとしよう．
* 中断したいところで resume_index と generator_state を保存する

In [8]:
# Run one full epoch. Right after batch 1 is printed, checkpoint the two
# values the sampler needs to continue from batch 2 onwards.
for idx, batch in enumerate(data_loader):
    print(batch)
    if idx != 1:
        continue
    resume_index, generator_state = (
        data_loader.sampler.get_resume_index(),
        data_loader.sampler.get_generator_state(),
    )
    print(f"batch idx: {idx}, ここから再開できます")

tensor([20, 10, 23, 28])
tensor([27, 29, 13, 11])
batch idx: 1, ここから再開できます
tensor([25, 16, 26, 18])
tensor([15, 21, 24, 17])
tensor([19, 12, 14, 22])


In [9]:
# Discard the first run's objects so the resumed run below is rebuilt from
# scratch; only resume_index and generator_state carry over.
del sampler, dataset, data_loader

### やめたところから再開
* ContinuableRandomSampler() に，保存しておいた resume_index と generator_state を渡す．

In [10]:
# Rebuild the pipeline, handing the checkpointed values to the new sampler.
dataset = ToyDataset(low=10, high=30)
sampler = ContinuableRandomSampler(
    dataset,
    resume_index=resume_index,
    generator_state=generator_state,
)
data_loader = DataLoader(dataset, sampler=sampler, batch_size=4)

In [11]:
# Iterate the resumed loader. In the recorded run it prints exactly the
# batches that followed the checkpoint in the first run.
for idx, batch in enumerate(data_loader):
    print(batch)

tensor([25, 16, 26, 18])
tensor([15, 21, 24, 17])
tensor([19, 12, 14, 22])
