From 6d15dd3a8233c925e8fae681b6db42c31fada58a Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 29 Jan 2021 16:06:49 -0500 Subject: [PATCH] Fix Dirichlet.arg_constraints event_dim --- test/distributions/test_distributions.py | 5 +++++ torch/distributions/dirichlet.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/test/distributions/test_distributions.py b/test/distributions/test_distributions.py index f0bacf7eb282..543fa18ea1be 100644 --- a/test/distributions/test_distributions.py +++ b/test/distributions/test_distributions.py @@ -3926,6 +3926,11 @@ def test_params_constraints(self): except KeyError: continue # ignore optional parameters + # Check param shape is compatible with distribution shape. + self.assertGreaterEqual(value.dim(), constraint.event_dim) + value_batch_shape = value.shape[:value.dim() - constraint.event_dim] + torch.broadcast_shapes(dist.batch_shape, value_batch_shape) + if is_dependent(constraint): continue diff --git a/torch/distributions/dirichlet.py b/torch/distributions/dirichlet.py index 3594d47c7209..3a85110d2e8a 100644 --- a/torch/distributions/dirichlet.py +++ b/torch/distributions/dirichlet.py @@ -40,7 +40,7 @@ class Dirichlet(ExponentialFamily): concentration (Tensor): concentration parameter of the distribution (often referred to as alpha) """ - arg_constraints = {'concentration': constraints.positive} + arg_constraints = {'concentration': constraints.independent(constraints.positive, 1)} support = constraints.simplex has_rsample = True