Disallow scalar parameters in Dirichlet and Categorical (#11589)
Summary:
This adds a small check to the `__init__` methods of `Dirichlet` and `Categorical` that rejects scalar (zero-dimensional) parameters with a `ValueError`.

**Motivation**
Currently, `Dirichlet` raises no error when given a scalar parameter, but if we `expand` such an instance, the result inherits the empty event shape of the original and gives unexpected results.

The alternative to this check is to promote `event_shape` to `torch.Size((1,))` when the original instance is a scalar, but that adds more complexity and changes the behavior of `expand`, which would then affect the `event_shape` as well as the `batch_shape`. Does this seem reasonable? cc. alicanb, fritzo.

```python
# Assumes `import torch` and `import torch.distributions as dist` were run earlier in the session.
In [4]: d = dist.Dirichlet(torch.tensor(1.))

In [5]: d.sample()
Out[5]: tensor(1.0000)

In [6]: d.log_prob(d.sample())
Out[6]: tensor(0.)

In [7]: e = d.expand([3])

In [8]: e.sample()
Out[8]: tensor([0.3953, 0.1797, 0.4250])  # interpreted as events

In [9]: e.log_prob(e.sample())
Out[9]: tensor(0.6931)  # wrongly summed out

In [10]: e.batch_shape
Out[10]: torch.Size([3])

In [11]: e.event_shape
Out[11]: torch.Size([])  # cannot be empty
```
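To make the shapes in the session above easier to follow, here is a minimal, torch-free sketch of the bookkeeping involved. `split_shapes` and `expand` are hypothetical helpers written only for illustration; they model shapes as plain tuples and assume, as the session suggests, that `expand` replaces the batch shape and leaves the event shape untouched.

```python
def split_shapes(param_shape):
    """Split a parameter shape as the Dirichlet constructor in this commit does:
    all leading dimensions form the batch shape, the last one the event shape."""
    param_shape = tuple(param_shape)
    return param_shape[:-1], param_shape[-1:]


def expand(batch_shape, event_shape, new_batch_shape):
    """Illustrative model of `expand`: only the batch shape is replaced,
    the event shape is inherited unchanged."""
    return tuple(new_batch_shape), tuple(event_shape)


# A scalar parameter has shape (), so both slices come back empty.
batch, event = split_shapes(())
print(batch, event)  # () ()

# Expanding to [3] gives batch (3,) but the event shape stays empty,
# matching the e.batch_shape / e.event_shape outputs in the session.
print(expand(batch, event, [3]))  # ((3,), ())

# A vector parameter, by contrast, gets a proper one-dimensional event shape.
print(split_shapes((3,)))  # ((), (3,))
```

In this model the expanded scalar instance ends up with three batch entries and no event dimension, which is the inconsistent state the check is meant to rule out.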

Additionally, based on review comments, this removes the `real_vector` constraint. It was only being used in `MultivariateNormal`, but I am happy to revert this if we want to keep it around for backwards compatibility.
Pull Request resolved: #11589

Differential Revision: D9818271

Pulled By: soumith

fbshipit-source-id: f9bbba90ed6f04e0b5bdfa169e70ca20b280fc74
neerajprad authored and facebook-github-bot committed Sep 14, 2018
1 parent c391c20 commit cda71e2
Showing 2 changed files with 7 additions and 1 deletion.
4 changes: 4 additions & 0 deletions torch/distributions/categorical.py
@@ -45,8 +45,12 @@ def __init__(self, probs=None, logits=None, validate_args=None):
         if (probs is None) == (logits is None):
             raise ValueError("Either `probs` or `logits` must be specified, but not both.")
         if probs is not None:
+            if probs.dim() < 1:
+                raise ValueError("`probs` parameter must be at least one-dimensional.")
             self.probs = probs / probs.sum(-1, keepdim=True)
         else:
+            if logits.dim() < 1:
+                raise ValueError("`logits` parameter must be at least one-dimensional.")
             self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
         self._param = self.probs if probs is not None else self.logits
         self._num_events = self._param.size()[-1]
4 changes: 3 additions & 1 deletion torch/distributions/dirichlet.py
@@ -54,7 +54,9 @@ class Dirichlet(ExponentialFamily):
     has_rsample = True

     def __init__(self, concentration, validate_args=None):
-        self.concentration, = broadcast_all(concentration)
+        if concentration.dim() < 1:
+            raise ValueError("`concentration` parameter must be at least one-dimensional.")
+        self.concentration = concentration
         batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:]
         super(Dirichlet, self).__init__(batch_shape, event_shape, validate_args=validate_args)
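
As a rough illustration of what the added guards enforce, the sketch below applies the same rule to plain shape tuples. `check_at_least_1d` is a hypothetical helper, not part of `torch.distributions`; the actual check in this commit calls `dim()` on the tensor, while this version uses the length of a shape tuple as a stand-in.

```python
def check_at_least_1d(name, shape):
    """Hypothetical stand-in for the `dim() < 1` guards added in this commit.

    Raises ValueError for a scalar (empty) shape; otherwise returns the
    (batch_shape, event_shape) split used by the Dirichlet constructor.
    """
    shape = tuple(shape)
    if len(shape) < 1:
        raise ValueError(
            "`{}` parameter must be at least one-dimensional.".format(name)
        )
    return shape[:-1], shape[-1:]


# A one-element vector is accepted and gets a non-empty event shape.
print(check_at_least_1d("concentration", (1,)))  # ((), (1,))

# A scalar shape is rejected up front instead of failing later in `expand`.
try:
    check_at_least_1d("concentration", ())
except ValueError as err:
    print(err)
```

Under this rule, callers who previously passed a scalar would wrap it in a one-element vector, so the event shape is `(1,)` from the start.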

