In [None]:
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms
from typing import Tuple

In [None]:
def create_data_loaders(rank: int, world_size: int, batch_size: int,
                        num_workers: int = 4,
                        dataset_loc: str = './mnist_data',
                        seed: int = 42) -> Tuple[DataLoader, DataLoader]:
  """Build the MNIST train/test DataLoaders for one distributed process.

  The training set is sharded across processes with a DistributedSampler so
  that each rank iterates over its own subset; the test set is not sharded and
  is loaded in full by every caller.

  Args:
    rank: Rank of the *calling* process (0 .. world_size - 1). Every process
      must pass its own rank; otherwise all replicas read the same shard.
    world_size: Total number of processes taking part in training.
    batch_size: Per-process batch size, used for both loaders.
    num_workers: Number of worker processes per DataLoader.
    dataset_loc: Directory where MNIST is stored (downloaded if missing).
    seed: Seed for the sampler's shuffling. It must be identical on every
      rank so that the shards stay disjoint.

  Returns:
    A ``(train_loader, test_loader)`` tuple.

  Note:
    To get a different shuffling order each epoch, call
    ``train_loader.sampler.set_epoch(epoch)`` at the start of every epoch;
    without it the same ordering is reused.
  """
  # Standard MNIST mean / std normalisation constants.
  transform = transforms.Compose([transforms.ToTensor(),
                                  transforms.Normalize((0.1307,), (0.3081,))])

  # NOTE(review): every rank calls this with download=True; if the data is not
  # already on disk, concurrent downloads into the same directory may race.
  # Consider downloading on one rank first -- confirm against the launcher.
  train_dataset = datasets.MNIST(dataset_loc, download=True, train=True,
                                 transform=transform)

  # Shuffling is delegated to the sampler, which also splits the dataset
  # into `world_size` shards and picks the one belonging to `rank`.
  sampler = DistributedSampler(train_dataset,
                               num_replicas=world_size,
                               rank=rank,
                               shuffle=True,
                               seed=seed)
  train_loader = DataLoader(train_dataset,
                            batch_size=batch_size,
                            shuffle=False,  # mutually exclusive with `sampler`
                            num_workers=num_workers,
                            sampler=sampler,
                            pin_memory=True)

  # Test / validation data does not need to be distributed.
  test_dataset = datasets.MNIST(dataset_loc,
                                download=True,
                                train=False,
                                transform=transform)
  # Evaluation does not depend on sample order, so keep it deterministic.
  test_loader = DataLoader(test_dataset,
                           batch_size=batch_size,
                           shuffle=False,
                           num_workers=num_workers,
                           pin_memory=True)
  return train_loader, test_loader