diff --git a/test/test_dataloader.py b/test/test_dataloader.py index 53519416deead..a813b870ec240 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -1547,26 +1547,6 @@ def test_sampler_reproducibility(self): ls[i].append(next(its[i])) self.assertEqual(ls[0], ls[1]) - def test_serialize_sampler(self): - from torch.utils.data import RandomSampler - sampler = RandomSampler(self.dataset, num_samples=5, replacement=True) - it1 = iter(sampler) - torch.manual_seed(0) - _ = next(it1) - - torch.manual_seed(0) - seed = torch.empty((), dtype=torch.int64).random_() - self.assertEqual(list(sampler._iterators.values())[0].initial_seed(), seed) - - it2 = iter(sampler) - torch.manual_seed(1) - _ = next(it2) - - torch.manual_seed(1) - seed = torch.empty((), dtype=torch.int64).random_() - _ = list(it1) - self.assertEqual(list(sampler._iterators.values())[0].initial_seed(), seed) - def _test_sampler(self, **kwargs): indices = range(2, 12) # using a regular iterable dl = self._get_data_loader(self.dataset, sampler=indices, batch_size=2, **kwargs) diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py index a9923a8149d27..232302e53c541 100644 --- a/torch/utils/data/sampler.py +++ b/torch/utils/data/sampler.py @@ -89,8 +89,6 @@ def __init__(self, data_source: Sized, replacement: bool = False, self.replacement = replacement self._num_samples = num_samples self.generator = generator - # Used to save state of RNG per iterator - self._iterators = {} if not isinstance(self.replacement, bool): raise TypeError("replacement should be a boolean value, but got " @@ -112,6 +110,7 @@ def num_samples(self) -> int: return self._num_samples def __iter__(self) -> Iterator[int]: + n = len(self.data_source) if self.generator is None: seed = int(torch.empty((), dtype=torch.int64).random_().item()) generator = torch.Generator() @@ -119,20 +118,12 @@ def __iter__(self) -> Iterator[int]: else: generator = self.generator - it = self.iter_fn(generator) - self._iterators[it] = generator - yield from it - del self._iterators[it] - - def iter_fn(self, rng) -> Iterator[int]: - n = len(self.data_source) - if self.replacement: for _ in range(self.num_samples // 32): - yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=rng).tolist() - yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=rng).tolist() + yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist() + yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist() else: - yield from torch.randperm(n, generator=rng).tolist() + yield from torch.randperm(n, generator=generator).tolist() def __len__(self) -> int: return self.num_samples