Disallow scalar parameters in Dirichlet and Categorical #11589
torch/distributions/dirichlet.py
@@ -54,6 +54,8 @@ class Dirichlet(ExponentialFamily):
     has_rsample = True

     def __init__(self, concentration, validate_args=None):
+        if concentration.dim() < 1:
+            raise ValueError("`concentration` parameter must be at least one-dimensional.")
         self.concentration, = broadcast_all(concentration)
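For illustration, the check in this hunk can be exercised as below. This is a minimal sketch, assuming a PyTorch build that already includes this change; the variable names (`scalar_rejected`, `categorical_rejected`, `expanded`) are made up for the example and are not part of the PR.

```python
import torch
from torch.distributions import Categorical, Dirichlet

# A 0-dimensional (scalar) concentration has no event axis, so the
# constructor is expected to reject it with a ValueError.
try:
    Dirichlet(torch.tensor(1.0))
    scalar_rejected = False
except ValueError:
    scalar_rejected = True

# The PR applies the same kind of check to Categorical parameters.
try:
    Categorical(probs=torch.tensor(0.5))
    categorical_rejected = False
except ValueError:
    categorical_rejected = True

# A one-dimensional concentration is accepted; expand() then changes
# only the batch shape and leaves the event shape untouched.
vec = Dirichlet(torch.tensor([0.5, 0.5, 0.5]))
expanded = vec.expand(torch.Size([4]))
print(scalar_rejected, categorical_rejected)
print(expanded.batch_shape, expanded.event_shape)
```

On builds that predate this change, the first two constructors may succeed silently, which is the behavior the PR sets out to remove.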
Nice catch, @neerajprad! I think we should also add this to [...]
Good point; I'll address it in this PR itself.
Hmm... I didn't know that we had a real_vector constraint as well! :) It seems like we are only using it in [...]
IIRC the purpose of the [...]
Thanks for providing that context, @fritzo. I’ll remove the last commit. I don’t have a strong opinion on its removal, and will leave it until further discussion.
2aec45e to 45e314d
I think this should be good to merge, pending any further comments.
soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
This adds a small check in `Dirichlet` and `Categorical` `__init__` methods to ensure that scalar parameters are not admissible.

Motivation

Currently, `Dirichlet` throws no error when provided with a scalar parameter, but if we `expand` a scalar instance, it inherits the empty event shape from the original instance and gives unexpected results.

The alternative to this check is to promote `event_shape` to be `torch.Size((1,))` if the original instance was a scalar, but that seems to add a bit more complexity (and changes the behavior of `expand` in that it would affect the `event_shape` as well as the `batch_shape` now). Does this seem reasonable? cc. @alicanb, @fritzo.

Additionally, based on review comments, this removes `real_vector` constraint. This was only being used in `MultivariateNormal`, but I am happy to revert this if we want to keep it around for backwards compatibility.
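The shape bookkeeping behind the motivation can be sketched in plain Python, with tuples standing in for `torch.Size`. This is an illustrative model only, not PyTorch's implementation: the helpers `split_shape`, `expand_shapes`, and `promote_scalar` are hypothetical names, and the sketch assumes the last axis of the parameter is taken as the event axis.

```python
def split_shape(param_shape):
    """Split a parameter shape into (batch_shape, event_shape).

    Assumes the last axis, if any, is the event axis.
    """
    param_shape = tuple(param_shape)
    return param_shape[:-1], param_shape[-1:]


def expand_shapes(batch_shape, event_shape, new_batch_shape):
    """Model expand(): replace the batch shape, carry the event shape over."""
    return tuple(new_batch_shape), tuple(event_shape)


def promote_scalar(batch_shape, event_shape):
    """Model the alternative from the description: give a scalar
    instance an event shape of (1,) instead of rejecting it."""
    if event_shape == ():
        return batch_shape, (1,)
    return batch_shape, event_shape


# A scalar parameter has an empty shape, so its event shape is empty too.
scalar_batch, scalar_event = split_shape(())
# Expanding carries the empty event shape along, which is the problem case.
expanded_scalar = expand_shapes(scalar_batch, scalar_event, (4,))

# A vector parameter keeps a proper event shape through expand().
vector_batch, vector_event = split_shape((3,))
expanded_vector = expand_shapes(vector_batch, vector_event, (4,))

# The alternative would change the event shape as a side effect.
promoted = promote_scalar(scalar_batch, scalar_event)

print(expanded_scalar, expanded_vector, promoted)
```

Rejecting scalars in `__init__` keeps `expand` restricted to the batch shape, whereas the promotion route would make it touch the event shape as well.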