Improve sequential_gaussian_filter_sample() (#3146)
* Allow noise to be injected into gaussian rsampling

* Support antithetic sampling of noise, add docs (see the sketch below)

* Avoid dropping initial time point

* Fix sampling bug

* Add profiling script

* Update profiler to not require grads

* Reduce memory usage

* Fix device placement and dtype

* Expose matrix_and_gaussian_to_gaussian()
fritzo committed Oct 23, 2022
1 parent c573af1 commit 1098e38
Showing 6 changed files with 220 additions and 50 deletions.
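
To make the antithetic-sampling bullet concrete, here is a minimal usage sketch (not part of the commit). Building the factors with mvn_to_gaussian and the names state_dim, duration, and half are illustrative choices, not the commit's own test code:

    # Hypothetical sketch: antithetic sampling via the new `noise` argument.
    import torch

    import pyro.distributions as dist
    from pyro.ops.gaussian import mvn_to_gaussian, sequential_gaussian_filter_sample

    state_dim, duration = 3, 8
    init = mvn_to_gaussian(
        dist.MultivariateNormal(torch.zeros(state_dim), torch.eye(state_dim))
    )
    # One factor per transition, over (x_prev, x_curr) pairs; time is the
    # rightmost batch dim, so duration = 1 + trans.batch_shape[-1].
    trans = mvn_to_gaussian(
        dist.MultivariateNormal(
            torch.zeros(duration - 1, 2 * state_dim), torch.eye(2 * state_dim)
        )
    )
    half = torch.randn(1, duration, state_dim)
    noise = torch.cat([half, -half], dim=0)  # pair each draw with its negation
    z = sequential_gaussian_filter_sample(init, trans, (2,), noise=noise)
    assert z.shape == (2, duration, state_dim)

Zero noise would instead return the posterior-mean trajectory, and scaling the noise varies the sampling temperature, per the docstring added below.
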
1 change: 1 addition & 0 deletions examples/contrib/forecast/bart.py
@@ -31,6 +31,7 @@ def preprocess(args):
arrivals = dataset["counts"][:, :, i].sum(-1)
departures = dataset["counts"][:, i, :].sum(-1)
data = torch.stack([arrivals, departures], dim=-1)
print(f"Loaded data of shape {tuple(data.shape)}")

# This simple example uses no covariates, so we will construct a
# zero-element tensor of the correct length as empty covariates.
83 changes: 83 additions & 0 deletions profiler/gaussianhmm.py
@@ -0,0 +1,83 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import argparse

import torch
from tqdm.auto import tqdm

import pyro.distributions as dist


def random_mvn(batch_shape, dim, requires_grad=False):
rank = dim + dim
loc = torch.randn(batch_shape + (dim,), requires_grad=requires_grad)
cov = torch.randn(batch_shape + (dim, rank))
cov = cov.matmul(cov.transpose(-1, -2))
scale_tril = torch.linalg.cholesky(cov)
scale_tril.requires_grad_(requires_grad)
return dist.MultivariateNormal(loc, scale_tril=scale_tril)


def main(args):
if args.cuda:
torch.set_default_tensor_type("torch.cuda.FloatTensor")

hidden_dim = args.hidden_dim
obs_dim = args.obs_dim
duration = args.duration
batch_shape = (args.batch_size,)

# Initialize parts.
init_dist = random_mvn(batch_shape, hidden_dim, requires_grad=args.grad)
trans_dist = random_mvn(
batch_shape + (duration,), hidden_dim, requires_grad=args.grad
)
obs_dist = random_mvn(batch_shape + (1,), obs_dim, requires_grad=args.grad)
trans_mat = 0.1 * torch.randn(batch_shape + (duration, hidden_dim, hidden_dim))
obs_mat = torch.randn(batch_shape + (1, hidden_dim, obs_dim))

if args.grad:
# Collect parameters.
params = [
init_dist.loc,
init_dist.scale_tril,
trans_dist.loc,
trans_dist.scale_tril,
obs_dist.loc,
obs_dist.scale_tril,
trans_mat.requires_grad_(),
obs_mat.requires_grad_(),
]

# Build a distribution.
d = dist.GaussianHMM(
init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=duration
)

for step in tqdm(range(args.num_steps)):
if not args.grad:
# Time forward only.
d.sample()
continue

# Time forward + backward.
x = d.rsample()
grads = torch.autograd.grad(
x.sum(), params, allow_unused=True, retain_graph=True
)
assert not all(g is None for g in grads)
del x


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="GaussianHMM profiler")
parser.add_argument("--hidden-dim", type=int, default=4)
parser.add_argument("--obs-dim", type=int, default=4)
parser.add_argument("--duration", type=int, default=10000)
parser.add_argument("--batch-size", type=int, default=3)
parser.add_argument("-n", "--num-steps", type=int, default=100)
parser.add_argument("--cuda", action="store_true", default=False)
parser.add_argument("--grad", action="store_true", default=False)
args = parser.parse_args()
main(args)
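
For reference, the script might be invoked as follows (illustrative command lines; flags as defined by the argparse setup above):

    python profiler/gaussianhmm.py -n 100          # time forward sampling only
    python profiler/gaussianhmm.py --grad -n 100   # time forward + backward
    python profiler/gaussianhmm.py --cuda --grad   # same, on the GPU
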
2 changes: 2 additions & 0 deletions pyro/distributions/hmm.py
@@ -589,6 +589,7 @@ def rsample(self, sample_shape=torch.Size()):
)
trans = trans.expand(trans.batch_shape[:-1] + (self.duration,))
z = sequential_gaussian_filter_sample(self._init, trans, sample_shape)
z = z[..., 1:, :] # drop the initial hidden state
x = self._obs.left_condition(z).rsample()
return x

@@ -599,6 +600,7 @@ def rsample_posterior(self, value, sample_shape=torch.Size()):
trans = self._trans + self._obs.condition(value).event_pad(left=self.hidden_dim)
trans = trans.expand(trans.batch_shape)
z = sequential_gaussian_filter_sample(self._init, trans, sample_shape)
z = z[..., 1:, :] # drop the initial hidden state
return z

def filter(self, value):
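
The two new slicing lines above are needed because sequential_gaussian_filter_sample() now returns the initial hidden state as its first time point rather than dropping it internally (see pyro/ops/gaussian.py below). Continuing the hypothetical sketch shown after the commit message:

    # The filter output now includes z0; GaussianHMM.rsample() slices it off.
    z = sequential_gaussian_filter_sample(init, trans)
    assert z.shape == (duration, state_dim)  # includes the initial state
    z = z[..., 1:, :]  # drop the initial hidden state, as in hmm.py above
    assert z.shape == (duration - 1, state_dim)
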
130 changes: 87 additions & 43 deletions pyro/ops/gaussian.py
@@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import math
-from typing import Tuple
+from typing import Optional, Tuple

import torch
from torch.distributions.utils import lazy_property
@@ -148,14 +148,19 @@ def log_density(self, value: torch.Tensor) -> torch.Tensor:
result = (value * result).sum(-1)
return result + self.log_normalizer

-def rsample(self, sample_shape=torch.Size()) -> torch.Tensor:
+def rsample(
+    self, sample_shape=torch.Size(), noise: Optional[torch.Tensor] = None
+) -> torch.Tensor:
"""
Reparameterized sampler.
"""
P_chol = cholesky(self.precision)
loc = self.info_vec.unsqueeze(-1).cholesky_solve(P_chol).squeeze(-1)
shape = sample_shape + self.batch_shape + (self.dim(), 1)
-noise = torch.randn(shape, dtype=loc.dtype, device=loc.device)
+if noise is None:
+    noise = torch.randn(shape, dtype=loc.dtype, device=loc.device)
+else:
+    noise = noise.reshape(shape)
noise = triangular_solve(noise, P_chol, upper=False, transpose=True).squeeze(-1)
sample: torch.Tensor = loc + noise
return sample
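
One quick way to exercise the new argument (an illustrative check, not part of the commit): zero noise makes rsample() deterministic and recovers the Gaussian's mean.

    # Hypothetical check that zero noise yields the mean.
    import torch

    import pyro.distributions as dist
    from pyro.ops.gaussian import mvn_to_gaussian

    mvn = dist.MultivariateNormal(torch.randn(3), torch.eye(3))
    g = mvn_to_gaussian(mvn)
    mean = g.rsample(noise=torch.zeros(g.dim()))
    assert torch.allclose(mean, mvn.mean, atol=1e-5)
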
@@ -348,23 +353,29 @@ def left_condition(self, value):
else:
return self.to_gaussian().left_condition(value)

-def rsample(self, sample_shape=torch.Size()):
+def rsample(
+    self, sample_shape=torch.Size(), noise: Optional[torch.Tensor] = None
+) -> torch.Tensor:
"""
Reparameterized sampler.
"""
if self.matrix.size(-2) > 0:
raise NotImplementedError
shape = sample_shape + self.batch_shape + self.loc.shape[-1:]
-noise = torch.randn(shape, dtype=self.loc.dtype, device=self.loc.device)
-return self.loc + noise * self.scale
+if noise is None:
+    noise = torch.randn(shape, dtype=self.loc.dtype, device=self.loc.device)
+else:
+    noise = noise.reshape(shape)
+sample: torch.Tensor = self.loc + noise * self.scale
+return sample

def to_gaussian(self):
if self._gaussian is None:
mvn = torch.distributions.Independent(
torch.distributions.Normal(self.loc, scale=self.scale), 1
)
y_gaussian = mvn_to_gaussian(mvn)
-self._gaussian = _matrix_and_gaussian_to_gaussian(self.matrix, y_gaussian)
+self._gaussian = matrix_and_gaussian_to_gaussian(self.matrix, y_gaussian)
return self._gaussian

def expand(self, batch_shape):
@@ -435,7 +446,17 @@ def mvn_to_gaussian(mvn):
return Gaussian(log_normalizer, info_vec, precision)


-def _matrix_and_gaussian_to_gaussian(matrix, y_gaussian):
+def matrix_and_gaussian_to_gaussian(
+    matrix: torch.Tensor, y_gaussian: Gaussian
+) -> Gaussian:
"""
Constructs a conditional Gaussian for ``p(y|x)`` where
``y - x @ matrix ~ y_gaussian``.
:param torch.Tensor matrix: A right-acting transformation matrix.
:param Gaussian y_gaussian: A distribution over noise of ``y - x@matrix``.
:rtype: Gaussian
"""
P_yy = y_gaussian.precision
neg_P_xy = matmul(matrix, P_yy)
P_xy = -neg_P_xy
@@ -480,7 +501,7 @@ def matrix_and_mvn_to_gaussian(matrix, mvn):
return AffineNormal(matrix, mvn.base_dist.loc, mvn.base_dist.scale)

y_gaussian = mvn_to_gaussian(mvn)
-result = _matrix_and_gaussian_to_gaussian(matrix, y_gaussian)
+result = matrix_and_gaussian_to_gaussian(matrix, y_gaussian)
assert result.batch_shape == batch_shape
assert result.dim() == x_dim + y_dim
return result
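
Since matrix_and_gaussian_to_gaussian() is now public, a small sketch of its contract may help (illustrative, not from the commit): it builds a joint Gaussian over the concatenated (x, y) for y = x @ matrix + noise, which is what matrix_and_mvn_to_gaussian() above delegates to.

    # Hypothetical sketch of the newly exposed helper.
    import torch

    import pyro.distributions as dist
    from pyro.ops.gaussian import matrix_and_gaussian_to_gaussian, mvn_to_gaussian

    x_dim, y_dim = 2, 3
    matrix = torch.randn(x_dim, y_dim)  # right-acting: y = x @ matrix + noise
    y_gaussian = mvn_to_gaussian(
        dist.MultivariateNormal(torch.zeros(y_dim), torch.eye(y_dim))
    )
    g = matrix_and_gaussian_to_gaussian(matrix, y_gaussian)
    assert g.dim() == x_dim + y_dim  # a joint Gaussian over (x, y)
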
@@ -576,29 +597,39 @@ def sequential_gaussian_tensordot(gaussian: Gaussian) -> Gaussian:
return gaussian[..., 0]


def _is_subshape(x, y):
return broadcast_shape(x, y) == y


def sequential_gaussian_filter_sample(
-    init: Gaussian, trans: Gaussian, sample_shape: Tuple[int, ...] = ()
+    init: Gaussian,
+    trans: Gaussian,
+    sample_shape: Tuple[int, ...] = (),
+    noise: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Draws a reparameterized sample from a Markov product of Gaussians via
parallel-scan forward-filter backward-sample.
:param Gaussian init: A Gaussian representing an initial state.
:param Gaussian trans: A Gaussian representing as series of state transitions,
with time as the rightmost batch dimension.
:param tuple sample_shape: An optional batch shape of samples to draw.
:returns: A reparametrized sample.
with time as the rightmost batch dimension. This must have twice the event
dim as ``init``: ``trans.dim() == 2 * init.dim()``.
:param tuple sample_shape: An optional extra shape of samples to draw.
:param torch.Tensor noise: An optional standard white noise tensor of shape
``sample_shape + batch_shape + (duration, state_dim)``, where
``duration = 1 + trans.batch_shape[-1]`` is the number of time points
to be sampled, and ``state_dim = init.dim()`` is the state dimension.
This is useful for computing the mean (pass zeros), varying temperature
(pass scaled noise), and antithetic sampling (pass ``cat([z,-z])``).
:returns: A reparametrized sample of shape
``sample_shape + batch_shape + (duration, state_dim)``.
:rtype: torch.Tensor
"""
assert isinstance(init, Gaussian)
assert isinstance(trans, Gaussian)
assert trans.dim() == 2 * init.dim()
assert _is_subshape(trans.batch_shape[:-1], init.batch_shape)
state_dim = trans.dim() // 2
batch_shape = broadcast_shape(trans.batch_shape[:-1], init.batch_shape)
if init.batch_shape != batch_shape:
init = init.expand(batch_shape)
dtype = trans.precision.dtype
device = trans.precision.device
perm = torch.cat(
[
@@ -628,9 +659,30 @@ def sequential_gaussian_filter_sample(
gaussian = contracted
gaussian = gaussian[..., 0] + init.event_pad(right=state_dim)

# Generate noise in batch, then allow blocks to be consumed incrementally.
duration = 1 + trans.batch_shape[-1]
shape = torch.Size(sample_shape) + init.batch_shape
result_shape = shape + (duration, state_dim)
noise_stride = shape.numel() * state_dim # noise is consumed in time blocks
noise_position: int = 0
if noise is None:
noise = torch.randn(result_shape, dtype=dtype, device=device)
assert noise.shape == result_shape

def rsample(g: Gaussian, sample_shape: Tuple[int, ...] = ()) -> torch.Tensor:
"""Samples, extracting a time-block of noise."""
nonlocal noise_position
assert noise is not None
numel = torch.Size(sample_shape + g.batch_shape + (g.dim(),)).numel()
assert numel % noise_stride == 0
beg: int = noise_position
end: int = noise_position + numel // noise_stride
assert end <= duration, "too little noise provided"
noise_position = end
return g.rsample(sample_shape, noise=noise[..., beg:end, :])

# Backward sample.
-shape = sample_shape + init.batch_shape
-result = gaussian.rsample(sample_shape).reshape(shape + (2, state_dim))
+result = rsample(gaussian, sample_shape).reshape(shape + (2, state_dim))
for joint in reversed(tape):
# The following comments demonstrate two example computations, one
# EVEN, one ODD. Ignoring sample_shape and batch_shape, let each zn be
@@ -640,32 +692,24 @@
cond = result.repeat_interleave(2, dim=-2) # [z0, z0, z2, z2, z4, z4]
cond = cond[..., 1:-1, :] # [z0, z2, z2, z4]
cond = cond.reshape(shape + (-1, 2 * state_dim)) # [z0z2, z2z4]
-sample = joint.condition(cond).rsample() # [z1, z3]
-sample = torch.nn.functional.pad(sample, (0, 0, 0, 1)) # [z1, z3, 0]
-result = torch.stack(
-    [
-        result, # [z0, z2, z4]
-        sample, # [z1, z3, 0]
-    ],
-    dim=-2,
-) # [[z0, z1], [z2, z3], [z4, 0]]
-result = result.reshape(shape + (-1, state_dim)) # [z0, z1, z2, z3, z4, 0]
-result = result[..., :-1, :] # [z0, z1, z2, z3, z4]
+sample = rsample(joint.condition(cond)) # [z1, z3]
+zipper = result.new_empty(shape + (2 * result.size(-2) - 1, state_dim))
+zipper[..., ::2, :] = result # [z0, _, z2, _, z4]
+zipper[..., 1::2, :] = sample # [_, z1, _, z3, _]
+result = zipper # [z0, z1, z2, z3, z4]
else: # ODD case.
assert joint.batch_shape[-1] == result.size(-2) - 2
# Suppose e.g. result = [z0, z2, z3]
cond = result[..., :-1, :].repeat_interleave(2, dim=-2) # [z0, z0, z2, z2]
cond = cond[..., 1:-1, :] # [z0, z2]
cond = cond.reshape(shape + (-1, 2 * state_dim)) # [z0z2]
-sample = joint.condition(cond).rsample() # [z1]
-sample = torch.cat([sample, result[..., -1:, :]], dim=-2) # [z1, z3]
-result = torch.stack(
-    [
-        result[..., :-1, :], # [z0, z2]
-        sample, # [z1, z3]
-    ],
-    dim=-2,
-) # [[z0, z1], [z2, z3]]
-result = result.reshape(shape + (-1, state_dim)) # [z0, z1, z2, z3]
+sample = rsample(joint.condition(cond)) # [z1]
+zipper = result.new_empty(shape + (2 * result.size(-2) - 2, state_dim))
+zipper[..., ::2, :] = result[..., :-1, :] # [z0, _, z2, _]
+zipper[..., -1, :] = result[..., -1, :] # [_, _, _, z3]
+zipper[..., 1:-1:2, :] = sample # [_, z1, _, _]
+result = zipper # [z0, z1, z2, z3]

-return result[..., 1:, :] # [z1, z2, z3, ...]
+assert noise_position == duration, "too much noise provided"
+assert result.shape == result_shape
+return result # [z0, z1, z2, ...]
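
The "zipper" writes above replace the earlier pad/stack/reshape sequence with two strided assignments into a single preallocated buffer, which likely accounts for the "Reduce memory usage" bullet. A standalone demo of the interleave (illustrative, with state_dim = 1):

    # Interleave evens [z0, z2, z4] and odds [z1, z3] by strided assignment.
    import torch

    evens = torch.tensor([[0.0], [2.0], [4.0]])
    odds = torch.tensor([[1.0], [3.0]])
    zipper = evens.new_empty(2 * evens.size(-2) - 1, 1)
    zipper[..., ::2, :] = evens  # [z0, _, z2, _, z4]
    zipper[..., 1::2, :] = odds  # [_, z1, _, z3, _]
    assert zipper.flatten().tolist() == [0.0, 1.0, 2.0, 3.0, 4.0]
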
12 changes: 7 additions & 5 deletions tests/ops/gaussian.py
@@ -8,30 +8,32 @@
from tests.common import assert_close


-def random_gaussian(batch_shape, dim, rank=None):
+def random_gaussian(batch_shape, dim, rank=None, *, requires_grad=False):
"""
Generate a random Gaussian for testing.
"""
if rank is None:
rank = dim + dim
-log_normalizer = torch.randn(batch_shape)
-info_vec = torch.randn(batch_shape + (dim,))
+log_normalizer = torch.randn(batch_shape, requires_grad=requires_grad)
+info_vec = torch.randn(batch_shape + (dim,), requires_grad=requires_grad)
samples = torch.randn(batch_shape + (dim, rank))
precision = torch.matmul(samples, samples.transpose(-2, -1))
precision.requires_grad_(requires_grad)
result = Gaussian(log_normalizer, info_vec, precision)
assert result.dim() == dim
assert result.batch_shape == batch_shape
return result


-def random_mvn(batch_shape, dim):
+def random_mvn(batch_shape, dim, *, requires_grad=False):
"""
Generate a random MultivariateNormal distribution for testing.
"""
rank = dim + dim
-loc = torch.randn(batch_shape + (dim,))
+loc = torch.randn(batch_shape + (dim,), requires_grad=requires_grad)
cov = torch.randn(batch_shape + (dim, rank))
cov = cov.matmul(cov.transpose(-1, -2))
cov.requires_grad_(requires_grad)
return dist.MultivariateNormal(loc, cov)


