diff --git a/references/classification/sampler.py b/references/classification/sampler.py index cfe95dd085a..a55e25a16b1 100644 --- a/references/classification/sampler.py +++ b/references/classification/sampler.py @@ -15,7 +15,7 @@ class RASampler(torch.utils.data.Sampler): https://github.com/facebookresearch/deit/blob/main/samplers.py """ - def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): + def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0): if num_replicas is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available!") @@ -32,11 +32,12 @@ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): self.total_size = self.num_samples * self.num_replicas self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) self.shuffle = shuffle + self.seed = seed def __iter__(self): # Deterministically shuffle based on epoch g = torch.Generator() - g.manual_seed(self.epoch) + g.manual_seed(self.seed + self.epoch) if self.shuffle: indices = torch.randperm(len(self.dataset), generator=g).tolist() else: diff --git a/references/classification/train.py b/references/classification/train.py index 689735d8717..8a942b99a5f 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -9,7 +9,7 @@ import torchvision import transforms import utils -from references.classification.sampler import RASampler +from sampler import RASampler from torch import nn from torch.utils.data.dataloader import default_collate from torchvision.transforms.functional import InterpolationMode