Implement NanMaskedNormal, NanMaskedMultivariateNormal (#3116)
* Implement NanMaskedNormal, NanMaskedMultivariateNormal

* Fix test

* Add test for fully-unobserved data
fritzo committed Jul 10, 2022
1 parent 66defe8 commit 38facc1
Showing 4 changed files with 216 additions and 3 deletions.
14 changes: 14 additions & 0 deletions docs/source/distributions.rst
@@ -281,6 +281,20 @@ MultivariateStudentT
    :undoc-members:
    :show-inheritance:

NanMaskedNormal
---------------
.. autoclass:: pyro.distributions.NanMaskedNormal
    :members:
    :undoc-members:
    :show-inheritance:

NanMaskedMultivariateNormal
---------------------------
.. autoclass:: pyro.distributions.NanMaskedMultivariateNormal
    :members:
    :undoc-members:
    :show-inheritance:

OMTMultivariateNormal
---------------------
.. autoclass:: pyro.distributions.OMTMultivariateNormal
9 changes: 6 additions & 3 deletions pyro/distributions/__init__.py
@@ -53,6 +53,7 @@
from pyro.distributions.logistic import Logistic, SkewLogistic
from pyro.distributions.mixture import MaskedMixture
from pyro.distributions.multivariate_studentt import MultivariateStudentT
from pyro.distributions.nanmasked import NanMaskedMultivariateNormal, NanMaskedNormal
from pyro.distributions.omt_mvn import OMTMultivariateNormal
from pyro.distributions.one_one_matching import OneOneMatching
from pyro.distributions.one_two_matching import OneTwoMatching
@@ -92,9 +93,9 @@
from . import constraints, kl, transforms

__all__ = [
    "AVFMultivariateNormal",
    "AffineBeta",
    "AsymmetricLaplace",
    "AVFMultivariateNormal",
    "BetaBinomial",
    "CoalescentRateLikelihood",
    "CoalescentTimes",
@@ -124,13 +125,15 @@
    "LKJ",
    "LKJCorrCholesky",
    "LinearHMM",
    "Logistic",
    "LogNormalNegativeBinomial",
    "Logistic",
    "MaskedDistribution",
    "MaskedMixture",
    "MixtureOfDiagNormals",
    "MixtureOfDiagNormalsSharedCovariance",
    "MultivariateStudentT",
    "NanMaskedMultivariateNormal",
    "NanMaskedNormal",
    "OMTMultivariateNormal",
    "OneOneMatching",
    "OneTwoMatching",
@@ -142,8 +145,8 @@
    "SineBivariateVonMises",
    "SineSkewed",
    "SkewLogistic",
    "SoftLaplace",
    "SoftAsymmetricLaplace",
    "SoftLaplace",
    "SpanningTree",
    "Stable",
    "TorchDistribution",
99 changes: 99 additions & 0 deletions pyro/distributions/nanmasked.py
@@ -0,0 +1,99 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch

from .torch import MultivariateNormal, Normal


class NanMaskedNormal(Normal):
    """
    Wrapper around :class:`~pyro.distributions.Normal` to allow partially
    observed data as specified by NAN elements in :meth:`log_prob`; the
    ``log_prob`` of these elements will be zero. This is useful for likelihoods
    with missing data.

    Example::

        from math import nan
        data = torch.tensor([0.5, 0.1, nan, 0.9])
        with pyro.plate("data", len(data)):
            pyro.sample("obs", NanMaskedNormal(0, 1), obs=data)
    """

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        ok = value.isfinite()
        if ok.all():
            return super().log_prob(value)

        # Broadcast all tensors.
        value, ok, loc, scale = torch.broadcast_tensors(value, ok, self.loc, self.scale)
        result = value.new_zeros(value.shape)

        # Evaluate ok elements.
        if ok.any():
            marginal = Normal(loc[ok], scale[ok], validate_args=False)
            result[ok] = marginal.log_prob(value[ok])
        return result


class NanMaskedMultivariateNormal(MultivariateNormal):
    """
    Wrapper around :class:`~pyro.distributions.MultivariateNormal` to allow
    partially observed data as specified by NAN elements in the argument to
    :meth:`log_prob`. The ``log_prob`` of these events will marginalize over
    the NAN elements. This is useful for likelihoods with missing data.

    Example::

        from math import nan
        data = torch.tensor([
            [0.1, 0.2, 3.4],
            [0.5, 0.1, nan],
            [0.6, nan, nan],
            [nan, 0.5, nan],
            [nan, nan, nan],
        ])
        with pyro.plate("data", len(data)):
            pyro.sample(
                "obs",
                NanMaskedMultivariateNormal(torch.zeros(3), torch.eye(3)),
                obs=data,
            )
    """

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        ok = value.isfinite()
        if ok.all():
            return super().log_prob(value)

        # Broadcast all tensors. This might waste some computation by eagerly
        # broadcasting, but the optimal implementation is quite complex.
        value, ok, loc = torch.broadcast_tensors(value, ok, self.loc)
        cov = self.covariance_matrix.expand(loc.shape + loc.shape[-1:])

        # Flatten.
        result_shape = value.shape[:-1]
        n = result_shape.numel()
        p = value.shape[-1]
        value = value.reshape(n, p)
        ok = ok.reshape(n, p)
        loc = loc.reshape(n, p)
        cov = cov.reshape(n, p, p)
        result = value.new_zeros(n)

        # Evaluate ok elements.
        for pattern in sorted(set(map(tuple, ok.tolist()))):
            if not any(pattern):
                continue
            # Marginalize out NAN elements.
            col_mask = torch.tensor(pattern)
            row_mask = (ok == col_mask).all(-1)
            ok_value = value[row_mask][:, col_mask]
            ok_loc = loc[row_mask][:, col_mask]
            ok_cov = cov[row_mask][:, col_mask][:, :, col_mask]
            marginal = MultivariateNormal(ok_loc, ok_cov, validate_args=False)
            result[row_mask] = marginal.log_prob(ok_value)

        # Unflatten.
        return result.reshape(result_shape)
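
A quick usage sketch (not part of this commit; the values are illustrative): NanMaskedNormal zeroes the log-density of NAN entries while scoring finite entries exactly as Normal would.

import math
import torch
import pyro.distributions as dist

data = torch.tensor([0.5, 0.1, math.nan, 0.9])
logp = dist.NanMaskedNormal(0.0, 1.0).log_prob(data)
# The NAN entry contributes zero log-density.
assert logp[2].item() == 0.0
# Finite entries match the ordinary Normal log-density.
expected = dist.Normal(0.0, 1.0).log_prob(data[[0, 1, 3]])
assert torch.allclose(logp[[0, 1, 3]], expected)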
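
A companion sketch (also not from the commit) checks the marginalization behavior of NanMaskedMultivariateNormal: a Gaussian marginal is obtained by indexing loc and the covariance matrix by the observed coordinates, so the log_prob of a partially observed row should match an explicitly constructed marginal.

import math
import torch
import pyro.distributions as dist

loc = torch.tensor([0.0, 1.0, -1.0])
cov = torch.tensor([
    [1.0, 0.5, 0.2],
    [0.5, 2.0, 0.3],
    [0.2, 0.3, 1.5],
])
d = dist.NanMaskedMultivariateNormal(loc, covariance_matrix=cov)

row = torch.tensor([0.4, math.nan, -0.7])  # middle coordinate unobserved
idx = torch.tensor([0, 2])  # observed coordinates

# Explicit marginal over the observed coordinates.
marginal = dist.MultivariateNormal(loc[idx], covariance_matrix=cov[idx][:, idx])
assert torch.allclose(d.log_prob(row), marginal.log_prob(row[idx]))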
97 changes: 97 additions & 0 deletions tests/distributions/test_nanmasked.py
@@ -0,0 +1,97 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math

import pytest
import torch

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam
from tests.common import assert_close


@pytest.mark.parametrize("batch_shape", [(), (40,), (11, 9)], ids=str)
def test_normal(batch_shape):
    # Test on full data
    data = torch.randn(batch_shape)
    loc = torch.randn(batch_shape).requires_grad_()
    scale = torch.randn(batch_shape).exp().requires_grad_()
    d = dist.NanMaskedNormal(loc, scale)
    d2 = dist.Normal(loc, scale)
    actual = d.log_prob(data)
    expected = d2.log_prob(data)
    assert_close(actual, expected)

    # Test on partial data.
    ok = torch.rand(batch_shape) < 0.5
    data[~ok] = math.nan
    actual = d.log_prob(data)
    assert actual.shape == expected.shape
    assert actual.isfinite().all()
    loc_grad, scale_grad = torch.autograd.grad(actual.sum(), [loc, scale])
    assert loc_grad.isfinite().all()
    assert scale_grad.isfinite().all()

    # Check identity on fully observed and fully unobserved rows.
    assert_close(actual[ok], expected[ok])
    assert_close(actual[~ok], torch.zeros_like(actual[~ok]))


@pytest.mark.parametrize("batch_shape", [(), (40,), (11, 9)], ids=str)
@pytest.mark.parametrize("p", [1, 2, 3, 10], ids=str)
def test_multivariate_normal(batch_shape, p):
    # Test on full data
    data = torch.randn(batch_shape + (p,))
    loc = torch.randn(batch_shape + (p,)).requires_grad_()
    scale_tril = torch.randn(batch_shape + (p, p))
    scale_tril.tril_()
    scale_tril.diagonal(dim1=-2, dim2=-1).exp_()
    scale_tril.requires_grad_()
    d = dist.NanMaskedMultivariateNormal(loc, scale_tril=scale_tril)
    d2 = dist.MultivariateNormal(loc, scale_tril=scale_tril)
    actual = d.log_prob(data)
    expected = d2.log_prob(data)
    assert_close(actual, expected)

    # Test on partial data.
    ok = torch.rand(batch_shape + (p,)) < 0.5
    data[~ok] = math.nan
    actual = d.log_prob(data)
    assert actual.shape == expected.shape
    assert actual.isfinite().all()
    loc_grad, scale_tril_grad = torch.autograd.grad(actual.sum(), [loc, scale_tril])
    assert loc_grad.isfinite().all()
    assert scale_tril_grad.isfinite().all()

    # Check identity on fully observed and fully unobserved rows.
    observed = ok.all(-1)
    assert_close(actual[observed], expected[observed])
    unobserved = ~ok.any(-1)
    assert_close(actual[unobserved], torch.zeros_like(actual[unobserved]))


def test_multivariate_normal_model():
    def model(data):
        loc = pyro.sample("loc", dist.Normal(torch.zeros(3), torch.ones(3)).to_event(1))
        scale_tril = torch.eye(3)
        with pyro.plate("data", len(data)):
            pyro.sample(
                "obs",
                dist.NanMaskedMultivariateNormal(loc, scale_tril=scale_tril),
                obs=data,
            )

    data = torch.randn(100, 3)
    ok = torch.rand(100, 3) < 0.5
    assert 100 < ok.long().sum() < 200, "weak test"
    data[~ok] = math.nan

    guide = AutoNormal(model)
    svi = SVI(model, guide, Adam({"lr": 1e-4}), Trace_ELBO())
    for step in range(3):
        loss = svi.step(data)
        assert math.isfinite(loss)
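
Beyond these tests, here is a hedged end-to-end sketch of how the univariate variant might be used in practice, fitting a location and scale to partially observed data with SVI; the priors, learning rate, and step count are illustrative assumptions, not part of this commit.

import math
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam

def model(data):
    loc = pyro.sample("loc", dist.Normal(0.0, 10.0))
    scale = pyro.sample("scale", dist.LogNormal(0.0, 1.0))
    with pyro.plate("data", len(data)):
        # NAN entries contribute zero log-density, so missing data is skipped.
        pyro.sample("obs", dist.NanMaskedNormal(loc, scale), obs=data)

# Simulate data with roughly 30% of entries missing.
data = 2.0 + torch.randn(100)
data[torch.rand(100) < 0.3] = math.nan

pyro.clear_param_store()
guide = AutoNormal(model)
svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
for step in range(1000):
    loss = svi.step(data)
assert math.isfinite(loss)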
