Skip to content

Commit

Permalink
Expose Gaussian algorithms (#3145)
Browse files Browse the repository at this point in the history
* Move Gaussian algorithms into ops.gaussian

* Add docs
  • Loading branch information
fritzo committed Oct 18, 2022
1 parent aab99f8 commit 9009fee
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 172 deletions.
135 changes: 14 additions & 121 deletions pyro/distributions/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
gaussian_tensordot,
matrix_and_mvn_to_gaussian,
mvn_to_gaussian,
sequential_gaussian_filter_sample,
sequential_gaussian_tensordot,
)
from pyro.ops.indexing import Vindex
from pyro.ops.special import safe_log
Expand Down Expand Up @@ -159,115 +161,6 @@ def _sequential_index(samples):
return samples.squeeze(-3)[..., :duration, :]


def _sequential_gaussian_tensordot(gaussian):
    """
    Integrates a Gaussian ``x`` whose rightmost batch dimension is time, computes::

        x[..., 0] @ x[..., 1] @ ... @ x[..., T-1]

    This uses a parallel-scan (pairwise) reduction: each pass contracts
    adjacent pairs along the time dimension, halving the time length, so the
    whole product takes O(log T) contraction passes instead of T-1.

    :param Gaussian gaussian: a batched Gaussian whose rightmost batch dim is
        time and whose event dim is even (two state blocks, prev and curr).
    :returns: a Gaussian with the time dimension eliminated.
    :rtype: Gaussian
    """
    assert isinstance(gaussian, Gaussian)
    assert gaussian.dim() % 2 == 0, "dim is not even"
    batch_shape = gaussian.batch_shape[:-1]
    # Half of the event dim is the hidden state size; presumably the event is
    # laid out as (prev_state, curr_state) — consistent with the pairwise
    # contraction below.
    state_dim = gaussian.dim() // 2
    while gaussian.batch_shape[-1] > 1:
        time = gaussian.batch_shape[-1]
        # Largest even prefix of the time axis; a trailing odd element is
        # carried over unmodified and appended after the contraction.
        even_time = time // 2 * 2
        even_part = gaussian[..., :even_time]
        # Pair up consecutive time steps: (..., T_even) -> (..., T_even//2, 2).
        x_y = even_part.reshape(batch_shape + (even_time // 2, 2))
        x, y = x_y[..., 0], x_y[..., 1]
        # Contract each pair over the shared state_dim middle variables.
        contracted = gaussian_tensordot(x, y, state_dim)
        if time > even_time:
            # Odd time length: keep the leftover last step for the next pass.
            contracted = Gaussian.cat((contracted, gaussian[..., -1:]), dim=-1)
        gaussian = contracted
    # Time dimension is now length 1; squeeze it away.
    return gaussian[..., 0]


def _is_subshape(x, y):
    """Return True iff shape ``x`` broadcasts against ``y`` without enlarging it."""
    broadcasted = broadcast_shape(x, y)
    return broadcasted == y


def _sequential_gaussian_filter_sample(init, trans, sample_shape):
    """
    Draws a reparameterized sample from a Markov product of Gaussians via
    parallel-scan forward-filter backward-sample.

    The forward pass mirrors :func:`_sequential_gaussian_tensordot` but, in
    addition to marginalizing out intermediate states, records each pairwise
    joint on a ``tape``. The backward pass replays the tape in reverse,
    conditioning each joint on the already-sampled neighboring states to fill
    in the states that were marginalized out — doubling (roughly) the number
    of sampled time steps per tape entry.

    :param Gaussian init: initial-state factor over one state block.
    :param Gaussian trans: transition factors; rightmost batch dim is time and
        event dim is two state blocks (prev, curr).
    :param torch.Size sample_shape: shape of samples to draw.
    :returns: a tensor of sampled states of shape
        ``sample_shape + init.batch_shape + (T, state_dim)``.
    """
    assert isinstance(init, Gaussian)
    assert isinstance(trans, Gaussian)
    assert trans.dim() == 2 * init.dim()
    assert _is_subshape(trans.batch_shape[:-1], init.batch_shape)
    state_dim = trans.dim() // 2
    device = trans.precision.device
    # Permutation that reorders a 3-block event (a, b, c) -> (b, a, c); used
    # below to move the block to be marginalized/sampled to the left.
    perm = torch.cat(
        [
            torch.arange(1 * state_dim, 2 * state_dim, device=device),
            torch.arange(0 * state_dim, 1 * state_dim, device=device),
            torch.arange(2 * state_dim, 3 * state_dim, device=device),
        ]
    )

    # Forward filter, similar to _sequential_gaussian_tensordot().
    tape = []
    shape = trans.batch_shape[:-1]  # Note trans may be unbroadcasted.
    gaussian = trans
    while gaussian.batch_shape[-1] > 1:
        time = gaussian.batch_shape[-1]
        even_time = time // 2 * 2  # largest even prefix; odd tail carried over
        even_part = gaussian[..., :even_time]
        # Pair consecutive time steps: (..., T_even) -> (..., T_even//2, 2).
        x_y = even_part.reshape(shape + (even_time // 2, 2))
        x, y = x_y[..., 0], x_y[..., 1]
        # Pad so both factors live on the 3-block event (z_prev, z_mid, z_next),
        # then permute z_mid to the left so it can be marginalized/conditioned.
        x = x.event_pad(right=state_dim)
        y = y.event_pad(left=state_dim)
        joint = (x + y).event_permute(perm)
        # Save the full pairwise joint for the backward-sampling pass.
        tape.append(joint)
        contracted = joint.marginalize(left=state_dim)
        if time > even_time:
            # Odd time length: append the leftover last step unmodified.
            contracted = Gaussian.cat((contracted, gaussian[..., -1:]), dim=-1)
        gaussian = contracted
    # Fold in the initial factor; gaussian is now a joint over (z_0, z_T).
    gaussian = gaussian[..., 0] + init.event_pad(right=state_dim)

    # Backward sample.
    shape = sample_shape + init.batch_shape
    result = gaussian.rsample(sample_shape).reshape(shape + (2, state_dim))
    for joint in reversed(tape):
        # The following comments demonstrate two example computations, one
        # EVEN, one ODD. Ignoring sample_shape and batch_shape, let each zn be
        # a single sampled event of shape (state_dim,).
        if joint.batch_shape[-1] == result.size(-2) - 1:  # EVEN case.
            # Suppose e.g. result = [z0, z2, z4]
            cond = result.repeat_interleave(2, dim=-2)  # [z0, z0, z2, z2, z4, z4]
            cond = cond[..., 1:-1, :]  # [z0, z2, z2, z4]
            cond = cond.reshape(shape + (-1, 2 * state_dim))  # [z0z2, z2z4]
            sample = joint.condition(cond).rsample()  # [z1, z3]
            # Pad with a dummy row so sample interleaves evenly with result;
            # the dummy is dropped by the [:-1] slice below.
            sample = torch.nn.functional.pad(sample, (0, 0, 0, 1))  # [z1, z3, 0]
            result = torch.stack(
                [
                    result,  # [z0, z2, z4]
                    sample,  # [z1, z3, 0]
                ],
                dim=-2,
            )  # [[z0, z1], [z2, z3], [z4, 0]]
            result = result.reshape(shape + (-1, state_dim))  # [z0, z1, z2, z3, z4, 0]
            result = result[..., :-1, :]  # [z0, z1, z2, z3, z4]
        else:  # ODD case.
            assert joint.batch_shape[-1] == result.size(-2) - 2
            # Suppose e.g. result = [z0, z2, z3]
            # The last state was carried over unpaired in the forward pass, so
            # only the states before it need in-between samples.
            cond = result[..., :-1, :].repeat_interleave(2, dim=-2)  # [z0, z0, z2, z2]
            cond = cond[..., 1:-1, :]  # [z0, z2]
            cond = cond.reshape(shape + (-1, 2 * state_dim))  # [z0z2]
            sample = joint.condition(cond).rsample()  # [z1]
            sample = torch.cat([sample, result[..., -1:, :]], dim=-2)  # [z1, z3]
            result = torch.stack(
                [
                    result[..., :-1, :],  # [z0, z2]
                    sample,  # [z1, z3]
                ],
                dim=-2,
            )  # [[z0, z1], [z2, z3]]
            result = result.reshape(shape + (-1, state_dim))  # [z0, z1, z2, z3]

    # Drop z0 (the initial state); return the T transition-step samples.
    return result[..., 1:, :]  # [z1, z2, z3, ...]


def _sequential_gamma_gaussian_tensordot(gamma_gaussian):
"""
Integrates a GammaGaussian ``x`` whose rightmost batch dimension is time, computes::
Expand Down Expand Up @@ -657,9 +550,9 @@ def expand(self, batch_shape, _instance=None):
new._obs = self._obs
new._trans = self._trans

# To save computation in _sequential_gaussian_tensordot(), we expand
# To save computation in sequential_gaussian_tensordot(), we expand
# only _init, which is applied only after
# _sequential_gaussian_tensordot().
# sequential_gaussian_tensordot().
batch_shape = torch.Size(broadcast_shape(self.batch_shape, batch_shape))
new._init = self._init.expand(batch_shape)

Expand All @@ -679,7 +572,7 @@ def log_prob(self, value):
)

# Eliminate time dimension.
result = _sequential_gaussian_tensordot(result.expand(result.batch_shape))
result = sequential_gaussian_tensordot(result.expand(result.batch_shape))

# Combine initial factor.
result = gaussian_tensordot(self._init, result, dims=self.hidden_dim)
Expand All @@ -695,7 +588,7 @@ def rsample(self, sample_shape=torch.Size()):
left=self.hidden_dim
)
trans = trans.expand(trans.batch_shape[:-1] + (self.duration,))
z = _sequential_gaussian_filter_sample(self._init, trans, sample_shape)
z = sequential_gaussian_filter_sample(self._init, trans, sample_shape)
x = self._obs.left_condition(z).rsample()
return x

Expand All @@ -705,7 +598,7 @@ def rsample_posterior(self, value, sample_shape=torch.Size()):
"""
trans = self._trans + self._obs.condition(value).event_pad(left=self.hidden_dim)
trans = trans.expand(trans.batch_shape)
z = _sequential_gaussian_filter_sample(self._init, trans, sample_shape)
z = sequential_gaussian_filter_sample(self._init, trans, sample_shape)
return z

def filter(self, value):
Expand All @@ -726,7 +619,7 @@ def filter(self, value):
logp = self._trans + self._obs.condition(value).event_pad(left=self.hidden_dim)

# Eliminate time dimension.
logp = _sequential_gaussian_tensordot(logp.expand(logp.batch_shape))
logp = sequential_gaussian_tensordot(logp.expand(logp.batch_shape))

# Combine initial factor.
logp = gaussian_tensordot(self._init, logp, dims=self.hidden_dim)
Expand Down Expand Up @@ -780,7 +673,7 @@ def conjugate_update(self, other):
logp = new._trans + new._obs.marginalize(right=new.obs_dim).event_pad(
left=new.hidden_dim
)
logp = _sequential_gaussian_tensordot(logp.expand(logp.batch_shape))
logp = sequential_gaussian_tensordot(logp.expand(logp.batch_shape))
logp = gaussian_tensordot(new._init, logp, dims=new.hidden_dim)
log_normalizer = logp.event_logsumexp()
new._init = new._init - log_normalizer
Expand Down Expand Up @@ -970,8 +863,8 @@ def expand(self, batch_shape, _instance=None):
new.hidden_dim = self.hidden_dim
new.obs_dim = self.obs_dim
# We only need to expand one of the inputs, since batch_shape is determined
# by broadcasting all three. To save computation in _sequential_gaussian_tensordot(),
# we expand only _init, which is applied only after _sequential_gaussian_tensordot().
# by broadcasting all three. To save computation in sequential_gaussian_tensordot(),
# we expand only _init, which is applied only after sequential_gaussian_tensordot().
new._init = self._init.expand(batch_shape)
new._trans = self._trans
new._obs = self._obs
Expand Down Expand Up @@ -1380,8 +1273,8 @@ def expand(self, batch_shape, _instance=None):
new.hidden_dim = self.hidden_dim
new.obs_dim = self.obs_dim
# We only need to expand one of the inputs, since batch_shape is determined
# by broadcasting all three. To save computation in _sequential_gaussian_tensordot(),
# we expand only _init, which is applied only after _sequential_gaussian_tensordot().
# by broadcasting all three. To save computation in sequential_gaussian_tensordot(),
# we expand only _init, which is applied only after sequential_gaussian_tensordot().
new._init = self._init.expand(batch_shape)
new._trans = self._trans
new._obs = self._obs
Expand Down Expand Up @@ -1411,7 +1304,7 @@ def log_prob(self, value):
logp = Gaussian.cat([logp_oh.expand(batch_shape), logp_h.expand(batch_shape)])

# Eliminate time dimension.
logp = _sequential_gaussian_tensordot(logp)
logp = sequential_gaussian_tensordot(logp)

# Combine initial factor.
logp = gaussian_tensordot(self._init, logp, dims=self.hidden_dim)
Expand Down

0 comments on commit 9009fee

Please sign in to comment.