In [None]:
import random

In [None]:
import torch

In [None]:
# Seed for the main-process `random` and torch RNGs in the DataLoader cells;
# `fixed_seed_worker` also reuses it inside every worker.
SEED = 42

In [None]:
class TrainDataset(torch.utils.data.Dataset):
    """Toy map-style dataset over the integers ``0 .. samples - 1``.

    ``__getitem__`` deliberately ignores the requested index and picks an
    element with Python's ``random`` module instead. Every returned value
    therefore depends on the ``random`` state of whichever process does the
    loading, which is what the worker-seeding cells exercise.

    Args:
        samples: Number of elements; also the value reported by ``len``.
    """

    def __init__(self, samples=16):
        super().__init__()
        self.samples = samples
        self.data = torch.arange(samples)

    def __len__(self):
        return self.samples

    def __getitem__(self, idx):
        # `idx` is unused on purpose: the element is chosen by `random`.
        upper = len(self) - 1
        return self.data[random.randint(0, upper)]

In [None]:
def fixed_seed_worker(worker_id):
    """Reseed ``random`` with the same constant in every DataLoader worker.

    This is the *incorrect* pattern the first demo cell uses. Every worker
    calls ``random.seed(SEED)`` with the identical value, so their
    subsequent ``random`` draws coincide.

    Args:
        worker_id: Worker index supplied by ``DataLoader``; only printed.
    """
    seed = SEED
    random.seed(seed)
    print("id: {}, seed: {}".format(worker_id, seed))

In [None]:
def seed_worker(worker_id):
    """Reseed ``random`` from the current torch initial seed.

    Intended as a ``DataLoader`` ``worker_init_fn``. PyTorch documents that
    each worker's torch seed is ``base_seed + worker_id``. Deriving the
    ``random`` seed from it therefore gives each worker its own stream.

    Args:
        worker_id: Worker index supplied by ``DataLoader``; only printed.
    """
    # Reduce the torch seed modulo 2**32 before handing it to `random`.
    derived = torch.initial_seed() % (1 << 32)
    random.seed(derived)
    print("id: {}, seed: {}".format(worker_id, derived))

In [None]:
# Shared 16-element dataset used by both DataLoader demos below.
train_dataset = TrainDataset(16)

## Invalid seeding: every worker reuses the same fixed seed

In [None]:
# Seed the main-process RNGs (torch and `random` are independent) before the
# loader is built and iterated.
torch.manual_seed(SEED)
random.seed(SEED)

# fixed_seed_worker reseeds `random` with the same SEED in both workers,
# so the workers' random draws coincide.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    num_workers=2,
    batch_size=4,
    worker_init_fn=fixed_seed_worker,
)

for data in train_loader:
    print(data)

## Valid seeding: each worker derives its seed from `torch.initial_seed()`

In [None]:
# Seed the main-process RNGs (torch and `random` are independent) before the
# loader is built and iterated.
torch.manual_seed(SEED)
random.seed(SEED)

# seed_worker derives each worker's `random` seed from torch.initial_seed(),
# which differs per worker.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    num_workers=2,
    batch_size=4,
    worker_init_fn=seed_worker,
)

for data in train_loader:
    print(data)