Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disallow scalar parameters in Dirichlet and Categorical #11589

Closed
wants to merge 4 commits into from

Conversation

neerajprad
Copy link
Contributor

@neerajprad neerajprad commented Sep 12, 2018

This adds a small check in Dirichlet and Categorical __init__ methods to ensure that scalar parameters are not admissible.

Motivation
Currently, Dirichlet throws no error when provided with a scalar parameter, but if we expand a scalar instance, it inherits the empty event shape from the original instance and gives unexpected results.

The alternative to this check is to promote event_shape to be torch.Size((1,)) if the original instance was a scalar, but that seems to add a bit more complexity (and changes the behavior of expand in that it would affect the event_shape as well as the batch_shape now). Does this seem reasonable? cc. @alicanb, @fritzo.

In [4]: d = dist.Dirichlet(torch.tensor(1.))

In [5]: d.sample()
Out[5]: tensor(1.0000)

In [6]: d.log_prob(d.sample())
Out[6]: tensor(0.)

In [7]: e = d.expand([3])

In [8]: e.sample()
Out[8]: tensor([0.3953, 0.1797, 0.4250])  # interpreted as events

In [9]: e.log_prob(e.sample())
Out[9]: tensor(0.6931)  # wrongly summed out

In [10]: e.batch_shape
Out[10]: torch.Size([3])

In [11]: e.event_shape
Out[11]: torch.Size([])  # cannot be empty

Additionally, based on review comments, this removes real_vector constraint. This was only being used in MultivariateNormal, but I am happy to revert this if we want to keep it around for backwards compatibility.

@neerajprad neerajprad changed the title Diallow scalar parameter in Dirichlet Disallow scalar parameter in Dirichlet Sep 12, 2018
@@ -54,6 +54,8 @@ class Dirichlet(ExponentialFamily):
has_rsample = True

def __init__(self, concentration, validate_args=None):
if concentration.dim() < 1:
raise ValueError("`concentration` parameter must be at least one-dimensional.")
self.concentration, = broadcast_all(concentration)

This comment was marked as off-topic.

@alicanb
Copy link
Collaborator

alicanb commented Sep 12, 2018

Nice catch @neerajprad ! I think we should also add this to Categoricals as well. While browsing I also realized we have a real_vector constraint, used in MultivariateNormal. We can get rid of that since we already do .all() while validating args (just to be clear, i'm asking your opinion, not bikeshedding 😄 )

@neerajprad
Copy link
Contributor Author

I think we should also add this to Categoricals as well.

Good point; I'll address it in this PR itself.

While browsing I also realized we have a real_vector constraint, used in MultivariateNormal. We can get rid of that since we already do .all() while validating args (just to be clear, i'm asking your opinion, not bikeshedding 😄 )

Hmm..I didn't quite know that we had a real_vector constraint as well! :) It seems like we are only using it in MultivariateNormal. Can we just remove the constraint altogether?

@neerajprad neerajprad changed the title Disallow scalar parameter in Dirichlet Disallow scalar parameters in Dirichlet and Categorical Sep 12, 2018
@fritzo
Copy link
Collaborator

fritzo commented Sep 13, 2018

IIRC the purpose of the real_vector constraint is to be able to correctly output a result of shape batch_shape from the .check() method, which is useful for verifying masked batches of distributions. If you think the real_vector constraint should be removed, please add an IndependentConstraint first, and ensure all transforms and bijections are correctly registered. This is sufficiently complex to warrant a separate PR.

@neerajprad
Copy link
Contributor Author

Thanks for providing that context, @fritzo. I’ll remove the last commit. I don’t have a strong opinion on its removal, and will leave it until further discussion.

@neerajprad
Copy link
Contributor Author

I think this should be good to merge, pending any further comments.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants