Implement TorchDistribution.mask() for dependent masks (#821)
* Implement TorchDistribution.mask() for dependent masks

* Fix typo
fritzo authored and martinjankowiak committed Feb 26, 2018
1 parent 6df6721 commit cbba2e7
Showing 3 changed files with 130 additions and 10 deletions.
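The new API in a nutshell: `dist.mask(mask)` returns a copy of the distribution whose `log_prob` (and `score_parts`) is multiplied elementwise by a zero-one mask, so masked-out batch entries contribute nothing to the loss. A minimal usage sketch, not part of the commit (the tensors are hypothetical, and the post-0.4 tensor API is used for brevity; at the time of this commit the code used `Variable`s):

    import torch
    from pyro.distributions.torch import Bernoulli

    probs = torch.tensor([0.2, 0.5, 0.9])  # hypothetical per-datum probabilities
    data = torch.tensor([0., 1., 1.])      # hypothetical observations
    mask = torch.tensor([1., 1., 0.])      # 1 = keep, 0 = ignore this datum

    masked = Bernoulli(probs).mask(mask)
    # log_prob is multiplied elementwise by the mask, so the third datum drops out:
    assert (masked.log_prob(data) == Bernoulli(probs).log_prob(data) * mask).all()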
13 changes: 8 additions & 5 deletions examples/dmm/dmm.py
@@ -193,16 +193,19 @@ def model(self, mini_batch, mini_batch_reversed, mini_batch_mask,
            # first compute the parameters of the diagonal gaussian distribution p(z_t | z_{t-1})
            z_mu, z_sigma = self.trans(z_prev)
            # then sample z_t according to dist.Normal(z_mu, z_sigma)
-            with poutine.scale(None, annealing_factor * mini_batch_mask[:, t - 1:t]):
-                z_t = pyro.sample("z_%d" % t, dist.Normal(z_mu, z_sigma))
+            with poutine.scale(None, annealing_factor):
+                z_t = pyro.sample("z_%d" % t,
+                                  dist.Normal(z_mu, z_sigma)
+                                  .mask(mini_batch_mask[:, t - 1:t]))

            # compute the probabilities that parameterize the bernoulli likelihood
            emission_probs_t = self.emitter(z_t)
            # the next statement instructs pyro to observe x_t according to the
            # bernoulli distribution p(x_t|z_t)
-            with poutine.scale(None, mini_batch_mask[:, t - 1:t]):
-                pyro.sample("obs_x_%d" % t, dist.Bernoulli(emission_probs_t),
-                            obs=mini_batch[:, t - 1, :])
+            pyro.sample("obs_x_%d" % t,
+                        dist.Bernoulli(emission_probs_t)
+                        .mask(mini_batch_mask[:, t - 1:t]),
+                        obs=mini_batch[:, t - 1, :])
            # the latent sampled at this time step will be conditioned upon
            # in the next time step so keep track of it
            z_prev = z_t
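The dmm.py change is intended to be behavior-preserving: previously the time-step mask was folded into the `poutine.scale` factor together with the KL annealing factor, whereas now annealing stays in `poutine.scale` and masking rides on the distribution itself. Either way each datum's log-probability ends up weighted by `annealing_factor * mask`. A schematic sketch, with `d`, `name`, and `mask` standing in for the concrete objects above:

    # Before: one scale factor carries both annealing and masking,
    # weighting log_prob by annealing_factor * mask.
    with poutine.scale(None, annealing_factor * mask):
        pyro.sample(name, d)

    # After: annealing stays in poutine.scale; the mask multiplies
    # log_prob inside the distribution, yielding the same weight.
    with poutine.scale(None, annealing_factor):
        pyro.sample(name, d.mask(mask))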
68 changes: 63 additions & 5 deletions pyro/distributions/torch_distribution.py
@@ -4,7 +4,7 @@

from pyro.distributions.distribution import Distribution
from pyro.distributions.score_parts import ScoreParts
-from pyro.distributions.util import sum_rightmost
+from pyro.distributions.util import broadcast_shape, sum_rightmost


class TorchDistributionMixin(Distribution):
@@ -71,9 +71,20 @@ def reshape(self, sample_shape=torch.Size(), extra_event_dims=0):
        :param int extra_event_dims: The number of extra event dimensions that
            will be considered dependent.
        :return: A reshaped copy of this distribution.
-        :rtype: :class:`Reshape`
+        :rtype: :class:`ReshapedDistribution`
        """
-        return Reshape(self, sample_shape, extra_event_dims)
+        return ReshapedDistribution(self, sample_shape, extra_event_dims)

+    def mask(self, mask):
+        """
+        Masks a distribution by a zero-one tensor that is broadcastable to the
+        distribution's ``batch_shape``.
+
+        :param Variable mask: A zero-one valued float tensor.
+        :return: A masked copy of this distribution.
+        :rtype: :class:`MaskedDistribution`
+        """
+        return MaskedDistribution(self, mask)

    def analytic_mean(self):
        return self.mean
@@ -138,7 +149,7 @@ class TorchDistribution(torch.distributions.Distribution, TorchDistributionMixin
    pass


-class Reshape(TorchDistribution):
+class ReshapedDistribution(TorchDistribution):
    """
    Reshapes a distribution by adding ``sample_shape`` to its total shape
    and adding ``extra_event_dims`` to its ``event_shape``.
@@ -159,7 +170,7 @@ def __init__(self, base_dist, sample_shape=torch.Size(), extra_event_dims=0):
        shape = sample_shape + base_dist.batch_shape + base_dist.event_shape
        batch_dim = len(shape) - extra_event_dims - len(base_dist.event_shape)
        batch_shape, event_shape = shape[:batch_dim], shape[batch_dim:]
-        super(Reshape, self).__init__(batch_shape, event_shape)
+        super(ReshapedDistribution, self).__init__(batch_shape, event_shape)

    @property
    def has_rsample(self):
@@ -202,3 +213,50 @@ def mean(self):
    @property
    def variance(self):
        return self.base_dist.variance.expand(self.batch_shape + self.event_shape)


+class MaskedDistribution(TorchDistribution):
+    """
+    Masks a distribution by a zero-one tensor that is broadcastable to the
+    distribution's ``batch_shape``.
+
+    :param Variable mask: A zero-one valued float tensor.
+    """
+    def __init__(self, base_dist, mask):
+        if broadcast_shape(mask.shape, base_dist.batch_shape) != base_dist.batch_shape:
+            raise ValueError("Expected mask.shape to be broadcastable to base_dist.batch_shape, "
+                             "actual {} vs {}".format(mask.shape, base_dist.batch_shape))
+        self.base_dist = base_dist
+        self._mask = mask
+        super(MaskedDistribution, self).__init__(base_dist.batch_shape, base_dist.event_shape)
+
+    @property
+    def has_rsample(self):
+        return self.base_dist.has_rsample
+
+    @property
+    def has_enumerate_support(self):
+        return self.base_dist.has_enumerate_support
+
+    def sample(self, sample_shape=torch.Size()):
+        return self.base_dist.sample(sample_shape)
+
+    def rsample(self, sample_shape=torch.Size()):
+        return self.base_dist.rsample(sample_shape)
+
+    def log_prob(self, value):
+        return self.base_dist.log_prob(value) * self._mask
+
+    def score_parts(self, value):
+        return self.base_dist.score_parts(value) * self._mask
+
+    def enumerate_support(self):
+        return self.base_dist.enumerate_support()
+
+    @property
+    def mean(self):
+        return self.base_dist.mean
+
+    @property
+    def variance(self):
+        return self.base_dist.variance
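Everything except `log_prob` and `score_parts` delegates to the base distribution; the constructor only checks that the mask broadcasts to the base ``batch_shape``. A quick sketch of the resulting behavior, not part of the commit (hypothetical shapes, post-0.4 tensor API for brevity):

    import torch
    from pyro.distributions.torch import Bernoulli

    base = Bernoulli(0.1).reshape(torch.Size([2, 3]))  # batch_shape == (2, 3)
    masked = base.mask(torch.ones(3))                  # (3,) broadcasts to (2, 3)
    assert masked.batch_shape == base.batch_shape

    base.mask(torch.ones(4))  # (4,) does not broadcast to (2, 3): raises ValueError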
59 changes: 59 additions & 0 deletions tests/distributions/test_mask.py
@@ -0,0 +1,59 @@
from __future__ import absolute_import, division, print_function

import pytest
import torch
from torch.autograd import Variable, variable

from pyro.distributions.torch import Bernoulli
from tests.common import assert_equal


def checker_mask(shape):
    mask = variable(0)
    for size in shape:
        mask = mask.unsqueeze(-1) + Variable(torch.arange(size))
    return mask.fmod(2)


@pytest.mark.parametrize('batch_dim,mask_dim',
                         [(b, m) for b in range(3) for m in range(1 + b)])
@pytest.mark.parametrize('event_dim', [0, 1, 2])
def test_mask(batch_dim, event_dim, mask_dim):
    # Construct base distribution.
    shape = torch.Size([2, 3, 4, 5, 6][:batch_dim + event_dim])
    batch_shape = shape[:batch_dim]
    mask_shape = batch_shape[batch_dim - mask_dim:]
    base_dist = Bernoulli(0.1).reshape(shape, event_dim)

    # Construct masked distribution.
    mask = checker_mask(mask_shape)
    dist = base_dist.mask(mask)

    # Check shape.
    sample = base_dist.sample()
    assert dist.batch_shape == base_dist.batch_shape
    assert dist.event_shape == base_dist.event_shape
    assert dist.sample().shape == sample.shape
    assert dist.log_prob(sample).shape == base_dist.log_prob(sample).shape

    # Check values.
    assert_equal(dist.mean, base_dist.mean)
    assert_equal(dist.variance, base_dist.variance)
    assert_equal(dist.enumerate_support(), base_dist.enumerate_support())
    assert_equal(dist.log_prob(sample), base_dist.log_prob(sample) * mask)
    assert_equal(dist.score_parts(sample), base_dist.score_parts(sample) * mask, prec=0)


@pytest.mark.parametrize('batch_shape,mask_shape', [
    ([], [1]),
    ([], [2]),
    ([1], [2]),
    ([2], [3]),
    ([2], [1, 1]),
    ([2, 1], [2]),
])
def test_mask_invalid_shape(batch_shape, mask_shape):
    dist = Bernoulli(0.1).reshape(batch_shape)
    mask = checker_mask(mask_shape)
    with pytest.raises(ValueError):
        dist.mask(mask)
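For reference, `checker_mask` builds a checkerboard of alternating zeros and ones by repeatedly broadcasting an index range against the accumulated grid and taking mod 2. Worked out for `checker_mask((2, 3))`:

    # start:        0
    # after size 2: [0, 1]
    # after size 3: [[0, 1, 2], [1, 2, 3]]
    # fmod(2):      [[0., 1., 0.], [1., 0., 1.]]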
