Skip to content

Commit

Permalink
[SobolEngine] Fix edge case of dtype of first sample (#51578)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #51578

#49710 introduced an edge case in which
drawing a single sample resulted in ignoring the `dtype` arg to `draw`. This
fixes this and adds a unit test to cover this behavior.

Test Plan: Unit tests

Reviewed By: danielrjiang

Differential Revision: D26204393

fbshipit-source-id: 441a44dc035002e7bbe6b662bf6d1af0e2cd88f4
  • Loading branch information
Balandat authored and facebook-github-bot committed Feb 2, 2021
1 parent 4746b3d commit a990ff7
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
12 changes: 12 additions & 0 deletions test/test_torch.py
Expand Up @@ -1566,6 +1566,18 @@ def test_sobolengine_draw(self, scramble: bool = False):
def test_sobolengine_draw_scrambled(self):
self.test_sobolengine_draw(scramble=True)

def test_sobolengine_first_point(self):
for dtype in (torch.float, torch.double):
engine = torch.quasirandom.SobolEngine(2, scramble=False)
sample = engine.draw(1, dtype=dtype)
self.assertTrue(torch.all(sample == 0))
self.assertEqual(sample.dtype, dtype)
for dtype in (torch.float, torch.double):
engine = torch.quasirandom.SobolEngine(2, scramble=True, seed=123456)
sample = engine.draw(1, dtype=dtype)
self.assertTrue(torch.all(sample != 0))
self.assertEqual(sample.dtype, dtype)

def test_sobolengine_continuing(self, scramble: bool = False):
ref_sample = self._sobol_reference_samples(scramble=scramble)
engine = torch.quasirandom.SobolEngine(2, scramble=scramble, seed=123456)
Expand Down
2 changes: 1 addition & 1 deletion torch/quasirandom.py
Expand Up @@ -83,7 +83,7 @@ def draw(self, n=1, out=None, dtype=torch.float32):
"""
if self.num_generated == 0:
if n == 1:
result = self._first_point
result = self._first_point.to(dtype)
else:
result, self.quasi = torch._sobol_engine_draw(
self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated, dtype=dtype,
Expand Down

0 comments on commit a990ff7

Please sign in to comment.