<a href="https://colab.research.google.com/github/KeisukeShimokawa/papers-challenge/blob/master/tips/torch/Fast_Batch_Sampler.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch

In [0]:
# Synthetic benchmark data: NUM_SAMPLES rows of 10 random features plus an
# integer class label in [0, 10). Seeded so that Restart & Run All reproduces
# the same tensors.
torch.manual_seed(0)
NUM_SAMPLES = 1_000_000
inputs = torch.randn(NUM_SAMPLES, 10)
labels = torch.randint(low=0, high=10, size=(NUM_SAMPLES,))
batch_size = 10_000

In [0]:
def run_loader(loader):
    """Iterate over every batch of ``loader`` without doing any work on it.

    Used as the body of the ``%timeit`` benchmarks, so the measured time is
    only the cost of producing the batches.

    Args:
        loader: Any iterable whose items unpack into ``(inputs, labels)``
            pairs, e.g. a ``torch.utils.data.DataLoader``.

    Returns:
        int: The number of batches consumed. Comparing it with the expected
        batch count guards against a loader that looks fast only because it
        yields fewer batches than it should.
    """
    num_batches = 0
    for _inputs, _labels in loader:  # unpack to mirror a real training loop
        num_batches += 1
    return num_batches

In [4]:
dataset = torch.utils.data.TensorDataset(inputs, labels)
loader1 = torch.utils.data.DataLoader(dataset,
                                      batch_size=batch_size,
                                      shuffle=True)

%timeit -n1 -r1 run_loader(loader1)

1 loop, best of 1: 6.69 s per loop


In [0]:
from torch.utils.data import Sampler
from torch._six import int_classes as _int_classes


class CustomBatchSampler(Sampler):
    r"""Wraps another sampler to yield mini-batches (lists) of indices.

    Same semantics as ``torch.utils.data.BatchSampler``. Batches are cut from
    the base sampler with :func:`itertools.islice` rather than by appending
    indices one at a time in Python.

    Args:
        sampler (Sampler): Base sampler whose indices are grouped into batches.
        batch_size (int): Size of mini-batch. Must be a positive ``int``;
            ``bool`` is rejected even though it subclasses ``int``.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.

    Raises:
        ValueError: If ``sampler`` is not a ``Sampler``, ``batch_size`` is not
            a positive integer, or ``drop_last`` is not a ``bool``.

    Example:
        >>> list(CustomBatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(CustomBatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        # Plain ``int`` check: ``torch._six.int_classes`` was only an alias for
        # ``int`` on Python 3, and ``torch._six`` is gone in PyTorch 2.x.
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        """Yield consecutive lists of ``batch_size`` indices from the sampler.

        The final, shorter batch is yielded only when ``drop_last`` is False.
        """
        from itertools import islice

        indices = iter(self.sampler)
        while True:
            # Take the next chunk in one call instead of a per-index append loop.
            batch = list(islice(indices, self.batch_size))
            if len(batch) < self.batch_size:
                break  # sampler exhausted; `batch` is the (possibly empty) remainder
            yield batch
        if batch and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            # Ceiling division: a trailing partial batch counts as one batch.
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size


In [0]:
# Shuffled index batches of `batch_size`; drop_last=True discards a trailing
# partial batch, if there is one.
batch_sampler = CustomBatchSampler(
    sampler=torch.utils.data.RandomSampler(dataset),
    batch_size=batch_size,
    drop_last=True,
)

In [8]:
dataset = torch.utils.data.TensorDataset(inputs, labels)
loader2 = torch.utils.data.DataLoader(dataset,
                                      batch_sampler=batch_sampler)

%timeit -n1 -r1 run_loader(loader2)

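1 loop, best of 1: 270 ms per loop

(Note: this recorded timing is not a valid comparison with the 6.69 s baseline. The `__iter__` in use when it was recorded yielded only the first `batch_size` indices, so the loader produced a single batch of 10,000 samples instead of all 100 batches. Re-run the notebook to get a meaningful figure.)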
