In [1]:
from torch.utils.data import Dataset, DataLoader
from typing import Any, Iterable
from torch.utils.data.distributed import DistributedSampler
from torch.distributed import init_process_group, destroy_process_group
import os

In [2]:
class DataUnit:
    """Minimal wrapper around a single value, used as the sample type in this notebook."""

    def __init__(self, value: int) -> None:
        # Store the payload unchanged on the instance.
        self.value = value

    def __repr__(self) -> str:
        """Return the ``DataUnit(<value>)`` text shown when batches are printed."""
        return "DataUnit({})".format(self.value)

In [3]:
# Raw integers 0..99; CustomDataset later wraps each one in a DataUnit.
data_source = range(100)

In [4]:
class CustomDataset(Dataset):
    """Map-style dataset that wraps every element of an iterable in a ``DataUnit``."""

    def __init__(self, data: Iterable) -> None:
        super().__init__()
        # Materialise eagerly into a list so len() and indexed access work below.
        self.data = list(map(DataUnit, data))

    def __len__(self) -> int:
        """Return how many wrapped items the dataset holds."""
        item_count = len(self.data)
        return item_count

    def __getitem__(self, index) -> Any:
        """Return whatever the underlying list stores at ``index``."""
        return self.data[index]
    

class CustomDataLoader(DataLoader):
    """``DataLoader`` whose default collation hands each batch back as a plain list.

    All positional and keyword arguments are forwarded to ``DataLoader``.
    When the caller gives no ``collate_fn`` keyword (or passes ``None``), the
    identity ``collate_fn`` defined on this class is used, so samples that the
    stock collation cannot stack (arbitrary Python objects) stay untouched.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Apply our collate function only as a default. Passing it
        # unconditionally next to **kwargs raised
        # "TypeError: got multiple values for keyword argument 'collate_fn'"
        # whenever a caller supplied their own.
        if kwargs.get("collate_fn") is None:
            kwargs["collate_fn"] = self.collate_fn
        super().__init__(*args, **kwargs)

    def collate_fn(self, batch) -> list:
        """Return the list of samples exactly as received (no stacking/conversion)."""
        return batch

In [None]:
# Rendezvous address/port read by the default env:// initialisation.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12345"
# world_size must equal the number of processes that actually join the group.
# This notebook kernel is the only participant (rank 0), so world_size=2 made
# the call block waiting for a rank 1 that never arrives; use a group of one.
# NOTE(review): the "nccl" backend needs a CUDA device — switch to "gloo" when
# running this demo on a CPU-only machine.
init_process_group(backend="nccl", rank=0, world_size=1)

In [None]:
# Build the dataset; CustomDataset.__init__ wraps every value of data_source in a DataUnit.
dataset = CustomDataset(data_source)

In [9]:
# Batches of 10, returned as plain lists by CustomDataLoader's collate_fn.
# Indices come from a DistributedSampler built with its defaults, so per the
# PyTorch docs it reads rank / replica count from the current process group.
dataloader = CustomDataLoader(dataset, batch_size=10, sampler=DistributedSampler(dataset))

In [10]:
# One pass over the loader; each `item` is a whole batch (a list of DataUnit objects).
for item in dataloader:
    print(item)

[DataUnit(44), DataUnit(19), DataUnit(93), DataUnit(90), DataUnit(71), DataUnit(69), DataUnit(37), DataUnit(95), DataUnit(53), DataUnit(91)]
[DataUnit(81), DataUnit(42), DataUnit(80), DataUnit(85), DataUnit(74), DataUnit(56), DataUnit(76), DataUnit(63), DataUnit(82), DataUnit(40)]
[DataUnit(26), DataUnit(92), DataUnit(57), DataUnit(10), DataUnit(16), DataUnit(66), DataUnit(89), DataUnit(41), DataUnit(97), DataUnit(8)]
[DataUnit(31), DataUnit(24), DataUnit(35), DataUnit(30), DataUnit(65), DataUnit(7), DataUnit(98), DataUnit(23), DataUnit(20), DataUnit(29)]
[DataUnit(78), DataUnit(61), DataUnit(94), DataUnit(15), DataUnit(4), DataUnit(52), DataUnit(59), DataUnit(5), DataUnit(54), DataUnit(46)]
[DataUnit(3), DataUnit(28), DataUnit(2), DataUnit(70), DataUnit(6), DataUnit(60), DataUnit(49), DataUnit(68), DataUnit(55), DataUnit(72)]
[DataUnit(79), DataUnit(77), DataUnit(45), DataUnit(1), DataUnit(32), DataUnit(34), DataUnit(11), DataUnit(0), DataUnit(22), DataUnit(12)]
[DataUnit(87), DataUni

In [49]:
# Tell the sampler which epoch this is. Per the DistributedSampler docs the epoch
# feeds the shuffle seed, so the next pass yields a different ordering.
dataloader.sampler.set_epoch(2)

In [50]:
# Second pass over the loader, after set_epoch, to show the changed ordering.
# NOTE(review): the output recorded under this cell (10 items in batches of 3)
# does not match the 100-item, batch_size=10 loader defined above — presumably
# left over from an earlier kernel state; re-run from a fresh kernel to confirm.
for item in dataloader:
    print(item)

[DataUnit(8), DataUnit(7), DataUnit(1)]
[DataUnit(5), DataUnit(6), DataUnit(9)]
[DataUnit(0), DataUnit(4), DataUnit(2)]
[DataUnit(3)]
