Add MultivariateNormalDiag distribution #11178
Conversation
I'm not sure we actually need this. This can be done with `Independent`.
@alicanb is correct, this can be accomplished via `Independent(Normal(loc, scale_diag), 1)`.
@fritzo there is an example in …
I wasn't aware of the `Independent` distribution. Are there any performance benefits to having a dedicated implementation?
Have you tried benchmarking, @samuela?
@zou3519 I have not. I assumed that this PR was sort of doomed based on the discussion.
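For anyone who does want to benchmark, a minimal timing sketch along these lines would do; the shapes and trial counts here are arbitrary choices, not from the thread:

```python
import timeit

import torch
from torch.distributions import Independent, MultivariateNormal, Normal

loc = torch.randn(1000, 10)
scale_diag = torch.rand(1000, 10) + 0.1
x = torch.randn(1000, 10)

# Diagonal normal via Independent vs. a full MultivariateNormal with a
# diagonal covariance matrix.
diag = Independent(Normal(loc, scale_diag), 1)
full = MultivariateNormal(loc, covariance_matrix=torch.diag_embed(scale_diag ** 2))

print("Independent(Normal):", timeit.timeit(lambda: diag.log_prob(x), number=100))
print("MultivariateNormal: ", timeit.timeit(lambda: full.log_prob(x), number=100))
```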
I think there will be negligible performance benefit, since `Independent` just forwards to the base distribution:

```python
def rsample(self, sample_shape=torch.Size()):
    return self.base_dist.rsample(sample_shape)

def log_prob(self, value):
    log_prob = self.base_dist.log_prob(value)
    return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)
```

Also IMO it is cleaner to provide a general `Independent` wrapper than a special-purpose class.
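To make the forwarding concrete, here is a small check, a sketch using only the public `torch.distributions` API, showing that `Independent` changes the shape semantics but not the underlying computation:

```python
import torch
from torch.distributions import Independent, Normal

base = Normal(torch.zeros(5, 3), torch.ones(5, 3))  # batch_shape (5, 3)
wrapped = Independent(base, 1)                      # batch (5,), event (3,)
x = torch.randn(5, 3)

# rsample is forwarded unchanged; log_prob just sums the last dimension.
assert wrapped.rsample().shape == base.rsample().shape
assert torch.allclose(wrapped.log_prob(x), base.log_prob(x).sum(-1))
```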
@fritzo I find myself using a lot of diagonal normal distributions in my work. For me it's just cleaner and communicates intent more clearly to write `MultivariateNormalDiag(loc, scale_diag)` than `Independent(Normal(loc, scale_diag), 1)`.
raise ValueError("Incompatible batch shapes: loc {}, scale_diag {}" | ||
.format(loc.shape, scale_diag.shape)) | ||
batch_shape = self.loc.shape[:-1] | ||
super(LowRankMultivariateNormal, self).__init__(batch_shape, event_shape, |
typo here?
ah, yes. will fix
Can we replace this complex PR in favor of @samuela's simple wrapper?

```python
def MultivariateNormalDiag(loc, scale_diag):
    if loc.dim() < 1:
        raise ValueError("loc must be at least one-dimensional.")
    return Independent(Normal(loc, scale_diag), 1)
```

EDIT: sorry if I misunderstand the purpose. It would help to add a PR description.
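A quick usage sketch of the wrapper above; the shapes are illustrative, not from the thread:

```python
import torch
from torch.distributions import Independent, Normal

def MultivariateNormalDiag(loc, scale_diag):
    if loc.dim() < 1:
        raise ValueError("loc must be at least one-dimensional.")
    return Independent(Normal(loc, scale_diag), 1)

d = MultivariateNormalDiag(torch.zeros(4, 3), torch.ones(4, 3))
print(d.batch_shape, d.event_shape)          # torch.Size([4]) torch.Size([3])
print(d.sample().shape)                      # torch.Size([4, 3])
print(d.log_prob(torch.zeros(4, 3)).shape)   # torch.Size([4])
```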
As far as I am aware, another reason for wanting a separate diagonal class is `kl_divergence` support.
Note this works with `pyro.distributions`. If you'd like this behavior with `torch.distributions`, feel free to move upstream the Pyro implementations of `kl_divergence(Independent, Independent)` and `kl_divergence(Independent(Normal), MultivariateNormal)`. @samuela What do you think of Pyro's abbreviated syntax `.to_event()`, used as `Normal(...).to_event(1)` = `Independent(Normal(...), 1)`? If there is demand we can move that upstream to `torch.distributions.Distribution`. I think that would be a better solution than a custom new class that works for only one distribution.
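A minimal sketch of what moving the `kl_divergence(Independent, Independent)` registration upstream might look like, assuming torch's public `register_kl` decorator and the private `_sum_rightmost` helper quoted earlier; this is a guess at the Pyro implementation, not a copy of it, and newer torch versions ship an equivalent registration:

```python
from torch.distributions import Independent, kl_divergence, register_kl
from torch.distributions.utils import _sum_rightmost

@register_kl(Independent, Independent)
def _kl_independent_independent(p, q):
    # Only defined when both sides reinterpret the same number of batch dims.
    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
        raise NotImplementedError
    # KL factorizes over independent dimensions: compute it elementwise on
    # the base distributions, then sum over the reinterpreted event dims.
    kl = kl_divergence(p.base_dist, q.base_dist)
    return _sum_rightmost(kl, p.reinterpreted_batch_ndims)
```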
@fritzo I've started using JAX now actually, so I haven't been following this issue much.
@fritzo Thank you, it is completely straightforward to get that working.
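For reference, the `.to_event()` sugar fritzo mentions is a thin layer over `Independent`; a hedged sketch of how it could be attached to `torch.distributions.Distribution` (the method name comes from Pyro, the base-class placement is an assumption):

```python
from torch.distributions import Distribution, Independent

def to_event(self, reinterpreted_batch_ndims=None):
    # Default: reinterpret all batch dimensions as event dimensions,
    # mirroring Pyro's convention.
    if reinterpreted_batch_ndims is None:
        reinterpreted_batch_ndims = len(self.batch_shape)
    return Independent(self, reinterpreted_batch_ndims)

# Hypothetical monkey-patch for demonstration; upstream this would be a
# method on Distribution itself.
Distribution.to_event = to_event
```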
Is the current `Independent.log_prob` correct? The implementation

```python
def log_prob(self, value):
    log_prob = self.base_dist.log_prob(value)
    return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)
```

is clearly different from the PDF described in the TensorFlow Distributions document. I think many people expect the log of the PDF when they call `log_prob`.
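For what it's worth, the summed-rightmost form can be checked numerically against a diagonal-covariance `MultivariateNormal` (a sketch with arbitrary shapes); summing per-dimension Normal log-densities is exactly the log of the joint PDF of independent marginals:

```python
import torch
from torch.distributions import Independent, MultivariateNormal, Normal

loc = torch.randn(3)
scale_diag = torch.rand(3) + 0.1
x = torch.randn(3)

diag = Independent(Normal(loc, scale_diag), 1)
full = MultivariateNormal(loc, covariance_matrix=torch.diag(scale_diag ** 2))

# Both are the log of the joint density of a diagonal Gaussian.
print(torch.allclose(diag.log_prob(x), full.log_prob(x)))  # True
```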
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as `Stale`.
No description provided.