Skip to content

Commit

Permalink
Fix Dirichlet.arg_constraints event_dim
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo committed Jan 29, 2021
1 parent dbfaf96 commit 6d15dd3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
5 changes: 5 additions & 0 deletions test/distributions/test_distributions.py
Expand Up @@ -3926,6 +3926,11 @@ def test_params_constraints(self):
except KeyError:
continue # ignore optional parameters

# Check param shape is compatible with distribution shape.
self.assertGreaterEqual(value.dim(), constraint.event_dim)
value_batch_shape = value.shape[:value.dim() - constraint.event_dim]
torch.broadcast_shapes(dist.batch_shape, value_batch_shape)

if is_dependent(constraint):
continue

Expand Down
2 changes: 1 addition & 1 deletion torch/distributions/dirichlet.py
Expand Up @@ -40,7 +40,7 @@ class Dirichlet(ExponentialFamily):
concentration (Tensor): concentration parameter of the distribution
(often referred to as alpha)
"""
arg_constraints = {'concentration': constraints.positive}
arg_constraints = {'concentration': constraints.independent(constraints.positive, 1)}
support = constraints.simplex
has_rsample = True

Expand Down

0 comments on commit 6d15dd3

Please sign in to comment.