
Add MultivariateNormalDiag distribution #11178

Closed
wants to merge 3 commits into from

Conversation

@samuela (Contributor) commented Sep 1, 2018

No description provided.

@vishwakftw (Contributor)

cc: @fritzo @alicanb @fehiepsi

@alicanb (Collaborator) commented Sep 2, 2018

I'm not sure we actually need this. This can be done with Normal only (+ Independent if you're strict about event_shape)

@fritzo (Collaborator) commented Sep 2, 2018

@alicanb is correct, this can be accomplished via Independent(Normal(loc, scale_diag), 1). Maybe that should be an example in the Normal or Independent docs?
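For readers unfamiliar with the trick, a small sketch of how the shapes work out (the variable names here are illustrative, not from the PR):

```python
import torch
from torch.distributions import Normal, Independent

loc = torch.zeros(4, 3)        # a batch of 4 three-dimensional means
scale_diag = torch.ones(4, 3)  # matching diagonal scales

# Plain Normal treats every dimension as batch: batch_shape (4, 3), event_shape ()
base = Normal(loc, scale_diag)

# Independent reinterprets the last batch dim as an event dim,
# giving batch_shape (4,) and event_shape (3,): a diagonal MVN
diag_mvn = Independent(base, 1)

x = torch.randn(4, 3)
print(base.log_prob(x).shape)      # torch.Size([4, 3]): per-dimension log-probs
print(diag_mvn.log_prob(x).shape)  # torch.Size([4]): one joint log-prob per batch
```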

@alicanb (Collaborator) commented Sep 2, 2018

@fritzo there is an example in Independent docs. (written by you actually 😄 )

@samuela (Contributor, Author) commented Sep 2, 2018

I wasn't aware of the Independent(Normal(loc, scale_diag), 1) trick. I use diagonal normals quite a lot, so it would be nice if there was a direct way to construct them that was a little bit more explicit in intent than Independent(Normal(loc, scale_diag), 1). But I understand that there's a tradeoff here.

Are there any performance benefits to having a dedicated implementation?

@zou3519 (Contributor) commented Sep 25, 2018

Have you tried benchmarking @samuela?

@samuela (Contributor, Author) commented Sep 25, 2018

@zou3519 I have not. I assumed that this PR was sort of doomed based on the discussion.

@fritzo (Collaborator) commented Sep 25, 2018

I think there will be negligible performance benefit, since Independent simply calls its base dist

    def rsample(self, sample_shape=torch.Size()):
        return self.base_dist.rsample(sample_shape)

    def log_prob(self, value):
        log_prob = self.base_dist.log_prob(value)
        return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)

Also IMO it is cleaner to provide a general Independent wrapper to create diagonal distributions of any base type, rather than implement DiagNormal, DiagGamma, DiagCauchy, DiagBeta etc. If you find Independent(Normal(...), ...) cumbersome, we could add a Normal.independent(...) method as we have in Pyro.
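A minimal sketch of what such a convenience method could look like (hypothetical; `torch.distributions.Distribution` has no `independent` method, this just mirrors the Pyro idea):

```python
import torch
from torch.distributions import Distribution, Independent, Normal

def independent(self, reinterpreted_batch_ndims=1):
    """Hypothetical shorthand: wrap this distribution in Independent."""
    return Independent(self, reinterpreted_batch_ndims)

# Attach to the base class so every distribution gains the shorthand
# (a sketch only, not an existing torch.distributions API)
Distribution.independent = independent

d = Normal(torch.zeros(4, 3), torch.ones(4, 3)).independent(1)
print(d.batch_shape, d.event_shape)  # torch.Size([4]) torch.Size([3])
```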

@samuela (Contributor, Author) commented Sep 25, 2018

@fritzo I find myself using a lot of diagonal normal distributions in my work. For me it's just cleaner and communicates intent more clearly to write MultivariateNormalDiag instead of Independent(Normal(...), 1). Right now I've just written a function MultivariateNormalDiag that calls Independent(Normal(...), 1).

Inline review comment on the diff:

    raise ValueError("Incompatible batch shapes: loc {}, scale_diag {}"
                     .format(loc.shape, scale_diag.shape))
    batch_shape = self.loc.shape[:-1]
    super(LowRankMultivariateNormal, self).__init__(batch_shape, event_shape,

Reviewer (Contributor):

typo here?

Author (@samuela):

ah, yes. will fix

@fritzo (Collaborator) commented Feb 4, 2019

Can we replace this complex PR in favor of @samuela's simple wrapper?

def MultivariateNormalDiag(loc, scale_diag):
    if loc.dim() < 1:
        raise ValueError("loc must be at least one-dimensional.")
    return Independent(Normal(loc, scale_diag), 1)

EDIT sorry if I misunderstand the purpose. It would help to add a PR description.
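For context, a quick check of how that wrapper behaves (a sketch, restating the wrapper so it runs standalone):

```python
import torch
from torch.distributions import Independent, Normal

def MultivariateNormalDiag(loc, scale_diag):
    if loc.dim() < 1:
        raise ValueError("loc must be at least one-dimensional.")
    return Independent(Normal(loc, scale_diag), 1)

d = MultivariateNormalDiag(torch.zeros(3), torch.ones(3))
print(d.event_shape)  # torch.Size([3])

# Scalar loc is rejected, matching the guard above
try:
    MultivariateNormalDiag(torch.tensor(0.0), torch.tensor(1.0))
except ValueError as e:
    print(e)  # loc must be at least one-dimensional.
```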

@zdevito zdevito removed their request for review February 13, 2019 01:22
@gchanan gchanan removed their request for review February 28, 2019 16:20
@JakobHavtorn

As far as I am aware, another reason for wanting a separate diagonal class is that torch.distributions.kl_divergence raises a NotImplementedError when called on a diagonal Normal defined as Independent(Normal(loc, scale_diag), 1). It works fine with MultivariateNormal, but that can be memory-intensive.
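Since the KL between diagonal Gaussians factorizes over dimensions, one workaround is to compute it from the base Normal distributions directly. A sketch checking that the sum of per-dimension KLs matches the full-covariance path:

```python
import torch
from torch.distributions import MultivariateNormal, Normal, kl_divergence

loc_p, scale_p = torch.zeros(3), torch.ones(3)
loc_q, scale_q = torch.ones(3), 2 * torch.ones(3)

# KL between independent diagonal Gaussians is the sum of per-dimension KLs
kl_diag = kl_divergence(Normal(loc_p, scale_p), Normal(loc_q, scale_q)).sum(-1)

# The same quantity via the (memory-hungrier) full-covariance route
p = MultivariateNormal(loc_p, covariance_matrix=torch.diag(scale_p ** 2))
q = MultivariateNormal(loc_q, covariance_matrix=torch.diag(scale_q ** 2))
kl_full = kl_divergence(p, q)

print(torch.allclose(kl_diag, kl_full))  # True
```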

@fritzo (Collaborator) commented Jan 18, 2020

torch.distributions.kl_divergence raises a NotImplementedError

Note this works with pyro.distributions. If you'd like this behavior with torch.distributions, feel free to move upstream the Pyro implementations of kl_divergence(Independent,Independent) and kl_divergence(Independent(Normal),MultivariateNormal).

@samuela What do you think of Pyro's abbreviated syntax .to_event(), used as

Normal(...).to_event(1) = Independent(Normal(...), 1)

If there is demand we can move that upstream to torch.distributions.Distribution. I think that would be a better solution than a custom new class that works for only one distribution.

@samuela (Contributor, Author) commented Jan 18, 2020

@fritzo I've actually started using JAX now, so I haven't been following this issue much. .to_event(...) looks interesting, although I would worry about offering two ways to accomplish the same thing.

@JakobHavtorn commented Jan 18, 2020

@fritzo Thank you, it is completely straightforward to get that working.

@lqf96 (Contributor) commented Sep 26, 2020

Is the current log_prob behavior something we want? Because the following implementation:

    def log_prob(self, value):
        log_prob = self.base_dist.log_prob(value)
        return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)

is clearly different from the PDF described in the TensorFlow Distributions documentation:

pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z,
y = inv(scale) @ (x - loc),
Z = (2 pi)**(0.5 k) |det(scale)|,

I think many people expect the log of the PDF when they call log_prob, but PyTorch is clearly not returning the log of the PDF for MultivariateNormalDiag implemented as Independent(Normal(loc, scale), 1).
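A side note on this concern: since the components are independent, the joint PDF is the product of the per-dimension PDFs, so summing per-dimension log-probs is exactly the log of the joint PDF. A quick numerical sketch comparing against MultivariateNormal with a diagonal covariance:

```python
import torch
from torch.distributions import Independent, MultivariateNormal, Normal

loc = torch.tensor([0.5, -1.0, 2.0])
scale = torch.tensor([1.0, 0.5, 2.0])
x = torch.randn(3)

# log_prob via Independent: sum of per-dimension Normal log-densities
lp_independent = Independent(Normal(loc, scale), 1).log_prob(x)

# log_prob via the full multivariate density with diagonal covariance
mvn = MultivariateNormal(loc, covariance_matrix=torch.diag(scale ** 2))
lp_mvn = mvn.log_prob(x)

print(torch.allclose(lp_independent, lp_mvn))  # True
```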

@pytorchbot (Collaborator)

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
Stale pull requests will automatically be closed 30 days after being marked Stale

@pytorchbot pytorchbot added Stale and removed Stale labels Apr 12, 2022

@github-actions github-actions bot added the Stale label Jun 11, 2022
@github-actions github-actions bot closed this Jul 11, 2022