Skip to content

Commit

Permalink
Update on "Convert generator in Sampler back to lazy construction"
Browse files Browse the repository at this point in the history
Fixes #63609


- Revert #63026 
  - Sampler is expected to be re-seeded if the user specifies a seed before each epoch
  - Cannot attach the generator to `self` in `__iter__`, because multiple concurrent iterators would break that use case
- Add tests to prevent the same case for different Samplers


Differential Revision: [D30451774](https://our.internmc.facebook.com/intern/diff/D30451774)

[ghstack-poisoned]
  • Loading branch information
ejguan committed Sep 29, 2021
2 parents 725eb13 + 0f695b8 commit d506308
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 33 deletions.
20 changes: 0 additions & 20 deletions test/test_dataloader.py
Expand Up @@ -1547,26 +1547,6 @@ def test_sampler_reproducibility(self):
ls[i].append(next(its[i]))
self.assertEqual(ls[0], ls[1])

def test_serialize_sampler(self):
from torch.utils.data import RandomSampler
sampler = RandomSampler(self.dataset, num_samples=5, replacement=True)
it1 = iter(sampler)
torch.manual_seed(0)
_ = next(it1)

torch.manual_seed(0)
seed = torch.empty((), dtype=torch.int64).random_()
self.assertEqual(list(sampler._iterators.values())[0].initial_seed(), seed)

it2 = iter(sampler)
torch.manual_seed(1)
_ = next(it2)

torch.manual_seed(1)
seed = torch.empty((), dtype=torch.int64).random_()
_ = list(it1)
self.assertEqual(list(sampler._iterators.values())[0].initial_seed(), seed)

def _test_sampler(self, **kwargs):
indices = range(2, 12) # using a regular iterable
dl = self._get_data_loader(self.dataset, sampler=indices, batch_size=2, **kwargs)
Expand Down
17 changes: 4 additions & 13 deletions torch/utils/data/sampler.py
Expand Up @@ -89,8 +89,6 @@ def __init__(self, data_source: Sized, replacement: bool = False,
self.replacement = replacement
self._num_samples = num_samples
self.generator = generator
# Used to save state of RNG per iterator
self._iterators = {}

if not isinstance(self.replacement, bool):
raise TypeError("replacement should be a boolean value, but got "
Expand All @@ -112,27 +110,20 @@ def num_samples(self) -> int:
return self._num_samples

def __iter__(self) -> Iterator[int]:
    """Lazily yield ``num_samples`` indices into ``data_source``.

    When ``self.generator`` is ``None``, a fresh generator is seeded from
    the global torch RNG on *every* call, so re-seeding the global RNG
    before each epoch re-seeds the sampler (the lazy behavior this commit
    restores; see pytorch/pytorch#63609). Nothing is attached to ``self``,
    so multiple simultaneous iterators stay independent.
    """
    n = len(self.data_source)
    if self.generator is None:
        # Draw a seed from the default RNG so epoch-level torch.manual_seed
        # calls control this sampler's shuffling order.
        seed = int(torch.empty((), dtype=torch.int64).random_().item())
        generator = torch.Generator()
        generator.manual_seed(seed)
    else:
        generator = self.generator

    if self.replacement:
        # Sample with replacement in chunks of 32 to amortize the cost of
        # the randint calls; the final call covers the remainder.
        for _ in range(self.num_samples // 32):
            yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
        yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
    else:
        # Without replacement: one full permutation of the dataset indices.
        yield from torch.randperm(n, generator=generator).tolist()

def __len__(self) -> int:
    """Total number of indices this sampler yields in one pass."""
    # The num_samples property already accounts for the configured sample
    # count, so simply report it.
    total = self.num_samples
    return total
Expand Down

0 comments on commit d506308

Please sign in to comment.