Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement NanMaskedNormal, NanMaskedMultivariateNormal #3116

Merged
merged 3 commits into from
Jul 10, 2022
Merged

Conversation

fritzo
Copy link
Member

@fritzo fritzo commented Jul 9, 2022

This implements two distributions to serve as likelihoods for partially observed data, where unobserved elements are specified as NAN values. This is new functionality beyond pyro.mask() and Distribution.mask() in that it allows NAN values within an event of MultivariateNormal; in this case we can analytically marginalize out the missing value. The NanMaskedNormal is similar to Normal.mask(...), but I've included it for easier compatibility with the nontrivial NanMaskedMultivariateNormal.

My motivating example is a Bayesian multivariate linear regression model with learned multivariate noise distribution and partially observed response as specified in a pandas dataframe. Each of the response columns is differently partially observed.

Tested

  • unit test of NanMaskedNormal
  • unit test of NanMaskedMultivariateNormal
  • end-to-end smoke test of NanMaskedMultivariateNormal

@fehiepsi
Copy link
Member

fehiepsi commented Jul 9, 2022

Nice, I remember that this is requested by many forum users.

Copy link
Collaborator

@martinjankowiak martinjankowiak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm. obviously there are various ways the computation could be sped-up in different regimes but since this is probably most useful in the relatively low dimensional setting anyway...

result = value.new_zeros(n)

# Evaluate ok elements.
for pattern in sorted(set(map(tuple, ok.tolist()))):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh i thought you were computing one big marginalized covariance with 0s/1s where appropriate so that everything could be vectorized (no for loop)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😄 that's beyond my linear algebra skills / patience. In practice I'm working with 3 columns so there are at most 7 patterns.

ok_value = value[row_mask][:, col_mask]
ok_loc = loc[row_mask][:, col_mask]
ok_cov = cov[row_mask][:, col_mask][:, :, col_mask]
marginal = MultivariateNormal(ok_loc, ok_cov, validate_args=False)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do these invocation not need covariance_matrix=?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i guess one nice thing about this pattern is that you don't need to worry about factors of log 2pi explicitly...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

covariance_matrix is the default first argument, so no kwarg is necessary.

@martinjankowiak martinjankowiak merged commit 38facc1 into dev Jul 10, 2022
@martinjankowiak martinjankowiak deleted the nan-masked branch July 10, 2022 17:08
OlaRonning pushed a commit to aleatory-science/pyro that referenced this pull request Aug 2, 2022
* Implement NanMaskedNormal, NanMaskedMultivariateNormal

* Fix test

* Add test for fully-unobserved data
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants