From a990ff70012301301b641f3199025f8ffa45fcc1 Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Tue, 2 Feb 2021 14:22:35 -0800 Subject: [PATCH] [SobolEngine] Fix edge case of dtype of first sample (#51578) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51578 https://github.com/pytorch/pytorch/pull/49710 introduced an edge case in which drawing a single sample resulted in ignoring the `dtype` arg to `draw`. This fixes this and adds a unit test to cover this behavior. Test Plan: Unit tests Reviewed By: danielrjiang Differential Revision: D26204393 fbshipit-source-id: 441a44dc035002e7bbe6b662bf6d1af0e2cd88f4 --- test/test_torch.py | 12 ++++++++++++ torch/quasirandom.py | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/test/test_torch.py b/test/test_torch.py index 8686bd1782d9..27ab59380bc4 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -1566,6 +1566,18 @@ def test_sobolengine_draw(self, scramble: bool = False): def test_sobolengine_draw_scrambled(self): self.test_sobolengine_draw(scramble=True) + def test_sobolengine_first_point(self): + for dtype in (torch.float, torch.double): + engine = torch.quasirandom.SobolEngine(2, scramble=False) + sample = engine.draw(1, dtype=dtype) + self.assertTrue(torch.all(sample == 0)) + self.assertEqual(sample.dtype, dtype) + for dtype in (torch.float, torch.double): + engine = torch.quasirandom.SobolEngine(2, scramble=True, seed=123456) + sample = engine.draw(1, dtype=dtype) + self.assertTrue(torch.all(sample != 0)) + self.assertEqual(sample.dtype, dtype) + def test_sobolengine_continuing(self, scramble: bool = False): ref_sample = self._sobol_reference_samples(scramble=scramble) engine = torch.quasirandom.SobolEngine(2, scramble=scramble, seed=123456) diff --git a/torch/quasirandom.py b/torch/quasirandom.py index 6f8a82fb2daf..c2268082efc8 100644 --- a/torch/quasirandom.py +++ b/torch/quasirandom.py @@ -83,7 +83,7 @@ def draw(self, n=1, out=None, dtype=torch.float32): """ if self.num_generated == 0: if n == 1: - result = self._first_point + result = self._first_point.to(dtype) else: result, self.quasi = torch._sobol_engine_draw( self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated, dtype=dtype,