Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disallow scalar parameters in Dirichlet and Categorical #11589

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions torch/distributions/categorical.py
Expand Up @@ -45,8 +45,12 @@ def __init__(self, probs=None, logits=None, validate_args=None):
if (probs is None) == (logits is None):
raise ValueError("Either `probs` or `logits` must be specified, but not both.")
if probs is not None:
if probs.dim() < 1:
raise ValueError("`probs` parameter must be at least one-dimensional.")
self.probs = probs / probs.sum(-1, keepdim=True)
else:
if logits.dim() < 1:
raise ValueError("`logits` parameter must be at least one-dimensional.")
self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
self._param = self.probs if probs is not None else self.logits
self._num_events = self._param.size()[-1]
Expand Down
4 changes: 3 additions & 1 deletion torch/distributions/dirichlet.py
Expand Up @@ -54,7 +54,9 @@ class Dirichlet(ExponentialFamily):
has_rsample = True

def __init__(self, concentration, validate_args=None):
self.concentration, = broadcast_all(concentration)
if concentration.dim() < 1:
raise ValueError("`concentration` parameter must be at least one-dimensional.")
self.concentration = concentration
batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:]
super(Dirichlet, self).__init__(batch_shape, event_shape, validate_args=validate_args)

Expand Down