Skip to content

Commit

Permalink
Update on "Convert generator in Sampler back to lazy construction"
Browse files Browse the repository at this point in the history
Fixes #63609


Differential Revision: [D30451774](https://our.internmc.facebook.com/intern/diff/D30451774)

[ghstack-poisoned]
  • Loading branch information
ejguan committed Aug 20, 2021
2 parents 1b1f1e9 + a68daaa commit 8251153
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 11 deletions.
9 changes: 9 additions & 0 deletions test/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1509,6 +1509,15 @@ def test_sampler_reproducibility(self):
l2 = list(sampler) + list(sampler)
self.assertEqual(l1, l2)

its = (iter(sampler), iter(sampler))
ls = ([], [])
for idx in range(len(sampler)):
for i in range(2):
if idx == 0:
torch.manual_seed(0)
ls[i].append(next(its[i]))
self.assertEqual(ls[0], ls[1])

def _test_sampler(self, **kwargs):
indices = range(2, 12) # using a regular iterable
dl = self._get_data_loader(self.dataset, sampler=indices, batch_size=2, **kwargs)
Expand Down
20 changes: 9 additions & 11 deletions torch/utils/data/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def __init__(self, data_source: Sized, replacement: bool = False,
self.replacement = replacement
self._num_samples = num_samples
self.generator = generator
self._gen: Optional[torch.Generator] = None

if not isinstance(self.replacement, bool):
raise TypeError("replacement should be a boolean value, but got "
Expand All @@ -113,19 +112,17 @@ def num_samples(self) -> int:
def __iter__(self) -> Iterator[int]:
    """Lazily yield random indices into ``data_source``.

    When no ``generator`` was supplied, a fresh ``torch.Generator`` is
    constructed *inside* ``__iter__`` (not cached on ``self``) and seeded
    from the global RNG via ``torch.empty(()).random_()``.  Because the
    construction is lazy and per-iteration, calling ``torch.manual_seed``
    immediately before creating an iterator makes that iterator
    reproducible, and no generator state leaks between epochs.

    Yields:
        int: indices in ``[0, len(self.data_source))``; with replacement
        when ``self.replacement`` is true, otherwise a permutation.
    """
    n = len(self.data_source)
    if self.generator is None:
        # Seed a fresh generator from the global RNG on every iteration.
        generator = torch.Generator()
        generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
    else:
        generator = self.generator

    if self.replacement:
        # Draw in chunks of 32 to amortize per-call overhead of torch.randint,
        # then a final partial chunk for the remainder.
        for _ in range(self.num_samples // 32):
            yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
        yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
    else:
        yield from torch.randperm(n, generator=generator).tolist()

def __len__(self) -> int:
    """Return the number of samples this sampler yields per iteration."""
    count = self.num_samples
    return count
Expand All @@ -145,7 +142,8 @@ def __init__(self, indices: Sequence[int], generator=None) -> None:
self.generator = generator

def __iter__(self) -> Iterator[int]:
    """Lazily yield the elements of ``self.indices`` in permuted order.

    Implemented as a generator function (rather than returning a generator
    expression) so ``torch.randperm`` — and thus any consumption of the
    global RNG when ``self.generator`` is None — happens lazily, on first
    ``next()``, keeping iteration reproducible after ``torch.manual_seed``.

    Yields:
        int: each element of ``self.indices`` exactly once, random order.
    """
    for i in torch.randperm(len(self.indices), generator=self.generator):
        yield self.indices[i]

def __len__(self) -> int:
    """Return how many indices this sampler yields: one per subset entry."""
    n = len(self.indices)
    return n
Expand Down Expand Up @@ -188,7 +186,7 @@ def __init__(self, weights: Sequence[float], num_samples: int,

def __iter__(self) -> Iterator[int]:
    """Lazily yield ``num_samples`` indices drawn by weighted multinomial sampling.

    A generator function, so the ``torch.multinomial`` draw is deferred
    until the iterator is first advanced — consistent with the other
    samplers' lazy RNG consumption.  ``yield from`` replaces the redundant
    ``yield from iter(...)`` wrapper; a list is already iterable.

    Yields:
        int: indices in ``[0, len(self.weights))``, sampled proportionally
        to ``self.weights``, with or without replacement per
        ``self.replacement``.
    """
    rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)
    yield from rand_tensor.tolist()

def __len__(self) -> int:
    """Return the configured draw count for this sampler."""
    total = self.num_samples
    return total
Expand Down

0 comments on commit 8251153

Please sign in to comment.