Skip to content

Commit

Permalink
add Half and BFloat16 tests for poisson
Browse files Browse the repository at this point in the history
  • Loading branch information
CaoE committed Nov 5, 2023
1 parent 0555666 commit f91eb43
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions test/distributions/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,7 @@ def _check_sampler_discrete(self, torch_dist, ref_dist, message,
num_samples=10000, failure_rate=1e-3):
"""Runs a Chi2-test for the support, but ignores tail instead of combining"""
torch_samples = torch_dist.sample((num_samples,)).squeeze()
torch_samples = torch_samples.float() if torch_samples.dtype == torch.bfloat16 else torch_samples
torch_samples = torch_samples.cpu().numpy()
unique, counts = np.unique(torch_samples, return_counts=True)
pmf = ref_dist.pmf(unique)
Expand Down Expand Up @@ -1463,11 +1464,15 @@ def ref_log_prob(ref_rate, idx, x, log_prob):
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_poisson_sample(self):
set_rng_seed(1) # see Note [Randomized statistical tests]
for rate in [0.1, 1.0, 5.0]:
self._check_sampler_discrete(Poisson(rate),
scipy.stats.poisson(rate),
f'Poisson(lambda={rate})',
failure_rate=1e-3)
saved_dtype = torch.get_default_dtype()
for dtype in [torch.float, torch.double, torch.bfloat16, torch.half]:
torch.set_default_dtype(dtype)
for rate in [0.1, 1.0, 5.0]:
self._check_sampler_discrete(Poisson(rate),
scipy.stats.poisson(rate),
f'Poisson(lambda={rate})',
failure_rate=1e-3)
torch.set_default_dtype(saved_dtype)

@unittest.skipIf(not TEST_CUDA, "CUDA not found")
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
Expand Down

0 comments on commit f91eb43

Please sign in to comment.