-
-
Notifications
You must be signed in to change notification settings - Fork 987
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement NanMaskedNormal, NanMaskedMultivariateNormal (#3116)
* Implement NanMaskedNormal, NanMaskedMultivariateNormal
* Fix test
* Add test for fully-unobserved data
- Loading branch information
Showing
4 changed files
with
216 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# Copyright Contributors to the Pyro project. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import torch | ||
|
||
from .torch import MultivariateNormal, Normal | ||
|
||
|
||
class NanMaskedNormal(Normal):
    """
    Wrapper around :class:`~pyro.distributions.Normal` to allow partially
    observed data as specified by NAN elements in :meth:`log_prob`; the
    ``log_prob`` of these elements will be zero. This is useful for likelihoods
    with missing data.

    Missing elements are detected with ``value.isfinite()``, so ``+inf`` and
    ``-inf`` entries are masked in the same way as NAN entries.

    Example::

        from math import nan
        data = torch.tensor([0.5, 0.1, nan, 0.9])
        with pyro.plate("data", len(data)):
            pyro.sample("obs", NanMaskedNormal(0, 1), obs=data)
    """

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the elementwise log density, scoring missing elements as zero.

        :param torch.Tensor value: Observations; non-finite entries are
            treated as missing.
        :return: Log density with the broadcast shape of ``value``, ``loc``
            and ``scale``. Observed entries hold the usual Normal log density
            and missing entries hold exactly zero.
        :rtype: torch.Tensor
        """
        ok = value.isfinite()
        if ok.all():
            # Nothing is missing, so defer to the plain Normal implementation.
            return super().log_prob(value)

        # Broadcast all tensors.
        value, ok, loc, scale = torch.broadcast_tensors(
            value, ok, self.loc, self.scale
        )

        # Allocate the output in the promoted dtype of all operands. The
        # masked assignment below requires matching dtypes, so allocating in
        # ``value``'s dtype would fail whenever the parameters are of higher
        # precision than the data (e.g. float64 parameters, float32 data).
        dtype = torch.promote_types(
            value.dtype, torch.promote_types(loc.dtype, scale.dtype)
        )
        result = value.new_zeros(value.shape, dtype=dtype)

        # Evaluate ok elements. Only the observed entries are ever passed to
        # the density, so NANs never enter the computation graph.
        if ok.any():
            marginal = Normal(loc[ok], scale[ok], validate_args=False)
            result[ok] = marginal.log_prob(value[ok])
        return result
|
||
|
||
class NanMaskedMultivariateNormal(MultivariateNormal):
    """
    Wrapper around :class:`~pyro.distributions.MultivariateNormal` to allow
    partially observed data as specified by NAN elements in the argument to
    :meth:`log_prob`. The ``log_prob`` of these events will marginalize over
    the NAN elements. This is useful for likelihoods with missing data.

    Missing elements are detected with ``value.isfinite()``, so ``+inf`` and
    ``-inf`` entries are marginalized out in the same way as NAN entries.
    Events with no observed element at all get a ``log_prob`` of zero.

    Example::

        from math import nan
        data = torch.tensor([
            [0.1, 0.2, 3.4],
            [0.5, 0.1, nan],
            [0.6, nan, nan],
            [nan, 0.5, nan],
            [nan, nan, nan],
        ])
        with pyro.plate("data", len(data)):
            pyro.sample(
                "obs",
                NanMaskedMultivariateNormal(torch.zeros(3), torch.eye(3)),
                obs=data,
            )
    """

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the log density of each event, marginalizing over its
        missing elements.

        :param torch.Tensor value: Observations whose rightmost dimension is
            the event dimension; non-finite entries are treated as missing.
        :return: Log density of shape ``broadcast(value, loc).shape[:-1]``.
            Each event is scored under the marginal over its observed
            coordinates; fully unobserved events score exactly zero.
        :rtype: torch.Tensor
        """
        ok = value.isfinite()
        if ok.all():
            # Nothing is missing, so defer to the plain implementation.
            return super().log_prob(value)

        # Broadcast all tensors. This might waste some computation by eagerly
        # broadcasting, but the optimal implementation is quite complex.
        value, ok, loc = torch.broadcast_tensors(value, ok, self.loc)
        cov = self.covariance_matrix.expand(loc.shape + loc.shape[-1:])

        # Flatten the batch dimensions so that every event is one row.
        result_shape = value.shape[:-1]
        n = result_shape.numel()
        p = value.shape[-1]
        value = value.reshape(n, p)
        ok = ok.reshape(n, p)
        loc = loc.reshape(n, p)
        cov = cov.reshape(n, p, p)

        # Allocate the output in the promoted dtype of all operands. The
        # masked assignment below requires matching dtypes, so allocating in
        # ``value``'s dtype would fail whenever the parameters are of higher
        # precision than the data.
        dtype = torch.promote_types(
            value.dtype, torch.promote_types(loc.dtype, cov.dtype)
        )
        result = value.new_zeros(n, dtype=dtype)

        # Evaluate ok elements, one batched call per distinct missingness
        # pattern so that rows sharing a pattern are scored together.
        for pattern in sorted(set(map(tuple, ok.tolist()))):
            if not any(pattern):
                # Fully unobserved rows keep their zero log density.
                continue
            # Marginalize out NAN elements. The mask is built on ``ok``'s
            # device so the comparison below also works off the CPU.
            col_mask = torch.tensor(pattern, device=ok.device)
            row_mask = (ok == col_mask).all(-1)
            ok_value = value[row_mask][:, col_mask]
            ok_loc = loc[row_mask][:, col_mask]
            # Selecting the observed rows and columns of the covariance gives
            # the covariance of the Gaussian marginal.
            ok_cov = cov[row_mask][:, col_mask][:, :, col_mask]
            marginal = MultivariateNormal(ok_loc, ok_cov, validate_args=False)
            result[row_mask] = marginal.log_prob(ok_value)

        # Unflatten.
        return result.reshape(result_shape)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Copyright Contributors to the Pyro project. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import math | ||
|
||
import pytest | ||
import torch | ||
|
||
import pyro | ||
import pyro.distributions as dist | ||
from pyro.infer import SVI, Trace_ELBO | ||
from pyro.infer.autoguide import AutoNormal | ||
from pyro.optim import Adam | ||
from tests.common import assert_close | ||
|
||
|
||
@pytest.mark.parametrize("batch_shape", [(), (40,), (11, 9)], ids=str)
def test_normal(batch_shape):
    """Compare NanMaskedNormal against Normal on full and partial data."""
    # Fully observed data: the masked distribution must agree with Normal.
    data = torch.randn(batch_shape)
    loc = torch.randn(batch_shape).requires_grad_()
    scale = torch.randn(batch_shape).exp().requires_grad_()
    masked = dist.NanMaskedNormal(loc, scale)
    reference = dist.Normal(loc, scale)
    actual = masked.log_prob(data)
    expected = reference.log_prob(data)
    assert_close(actual, expected)

    # Partially observed data: knock out roughly half of the elements.
    ok = torch.rand(batch_shape) < 0.5
    missing = ~ok
    data[missing] = math.nan
    actual = masked.log_prob(data)
    assert actual.shape == expected.shape
    assert actual.isfinite().all()

    # Gradients with respect to both parameters must stay finite.
    for grad in torch.autograd.grad(actual.sum(), [loc, scale]):
        assert grad.isfinite().all()

    # Observed elements match Normal; missing elements contribute zero.
    assert_close(actual[ok], expected[ok])
    masked_out = actual[missing]
    assert_close(masked_out, torch.zeros_like(masked_out))
|
||
|
||
@pytest.mark.parametrize("batch_shape", [(), (40,), (11, 9)], ids=str)
@pytest.mark.parametrize("p", [1, 2, 3, 10], ids=str)
def test_multivariate_normal(batch_shape, p):
    """Compare NanMaskedMultivariateNormal against MultivariateNormal."""
    vector_shape = batch_shape + (p,)

    # Fully observed data: the masked distribution must agree with the
    # reference distribution.
    data = torch.randn(vector_shape)
    loc = torch.randn(vector_shape).requires_grad_()
    # Random lower-triangular factor with a positive (exponentiated) diagonal.
    scale_tril = torch.randn(batch_shape + (p, p))
    scale_tril.tril_()
    scale_tril.diagonal(dim1=-2, dim2=-1).exp_()
    scale_tril.requires_grad_()
    masked = dist.NanMaskedMultivariateNormal(loc, scale_tril=scale_tril)
    reference = dist.MultivariateNormal(loc, scale_tril=scale_tril)
    actual = masked.log_prob(data)
    expected = reference.log_prob(data)
    assert_close(actual, expected)

    # Partially observed data: knock out roughly half of the elements.
    ok = torch.rand(vector_shape) < 0.5
    data[~ok] = math.nan
    actual = masked.log_prob(data)
    assert actual.shape == expected.shape
    assert actual.isfinite().all()

    # Gradients with respect to both parameters must stay finite.
    for grad in torch.autograd.grad(actual.sum(), [loc, scale_tril]):
        assert grad.isfinite().all()

    # Fully observed rows match the reference distribution.
    fully_observed = ok.all(-1)
    assert_close(actual[fully_observed], expected[fully_observed])

    # Fully unobserved rows contribute zero.
    fully_missing = ~ok.any(-1)
    missing_rows = actual[fully_missing]
    assert_close(missing_rows, torch.zeros_like(missing_rows))
|
||
|
||
def test_multivariate_normal_model():
    """Smoke-test SVI on a model whose likelihood sees partially missing data."""

    def model(data):
        # Latent mean with a standard Normal prior; identity scale factor.
        prior = dist.Normal(torch.zeros(3), torch.ones(3)).to_event(1)
        loc = pyro.sample("loc", prior)
        scale_tril = torch.eye(3)
        likelihood = dist.NanMaskedMultivariateNormal(loc, scale_tril=scale_tril)
        with pyro.plate("data", len(data)):
            pyro.sample("obs", likelihood, obs=data)

    # Random data with roughly half of the elements marked missing.
    data = torch.randn(100, 3)
    ok = torch.rand(100, 3) < 0.5
    num_observed = ok.long().sum()
    assert 100 < num_observed < 200, "weak test"
    data[~ok] = math.nan

    # A few optimization steps are enough to check that the loss stays finite.
    guide = AutoNormal(model)
    optimizer = Adam({"lr": 1e-4})
    svi = SVI(model, guide, optimizer, Trace_ELBO())
    for _ in range(3):
        loss = svi.step(data)
        assert math.isfinite(loss)