Parallelized DiscreteHMM.sample() #3053

Merged 6 commits on Mar 22, 2022
pyro/distributions/hmm.py (95 changes: 84 additions & 11 deletions)
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.nn.functional as F

from pyro.ops.gamma_gaussian import (
    GammaGaussian,
@@ -82,6 +83,82 @@ def _sequential_logmatmulexp(logits):
    return logits.squeeze(-3)


def _markov_index(x, y):
    """
    Join the ends of two batches of Markov path segments: each path in ``y``
    is re-rooted at the final state of the corresponding path in ``x`` and
    appended to it.
    """
    # For every starting state, look up y's path given x's final state, then
    # concatenate the two segments along the path dimension (-2).
    y = Vindex(y.unsqueeze(-2))[..., x[..., -1:, :]]
    return torch.cat([x, y], -2)
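
For intuition, the same join can be written for a single pair of unbatched
segments of shape ``(segment_len, state_dim)`` with plain indexing instead of
``Vindex``; a minimal sketch (``join_segments`` is an illustrative name, not
part of this PR)::

    # illustrative sketch, not part of pyro.distributions.hmm
    import torch

    def join_segments(x, y):
        end = x[-1]                          # final state of each path in x
        return torch.cat([x, y[:, end]], 0)  # re-root y at those states, append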


def _sequential_index(samples):
    """
    For a tensor ``samples`` whose time dimension is -2 and state dimension
    is -1, where ``samples[..., t, s]`` is the state sampled at time ``t``
    given previous state ``s``, compute the Markov paths that sequential
    indexing would produce.

    For example, for ``samples`` with 3 states and time duration 5::

        tensor([[0, 1, 1],
                [1, 0, 2],
                [2, 1, 0],
                [0, 2, 1],
                [1, 1, 0]])

    computed paths are::

        tensor([[0, 1, 1],
                [1, 0, 0],
                [1, 2, 2],
                [2, 1, 1],
                [0, 1, 1]])

    # path for the 0th state
    #
    # 0 1 1
    # |
    # 1 0 2
    #  \
    # 2 1 0
    #   |
    # 0 2 1
    #    \
    # 1 1 0
    #
    # paths for the 1st and 2nd states
    #
    # 0 1 1
    #   |/
    # 1 0 2
    #  /
    # 2 1 0
    #  \
    #   \
    # 0 2 1
    #    /
    # 1 1 0
"""
# new Markov time dimension at -2
samples = samples.unsqueeze(-2)
batch_shape = samples.shape[:-3]
state_dim = samples.size(-1)
duration = samples.size(-3)
while samples.size(-3) > 1:
time = samples.size(-3)
even_time = time // 2 * 2
even_part = samples[..., :even_time, :, :]
x_y = even_part.reshape(batch_shape + (even_time // 2, 2, -1, state_dim))
x, y = x_y.unbind(-3)
contracted = _markov_index(x, y)
if time > even_time:
padded = F.pad(
input=samples[..., -1:, :, :],
pad=(0, 0, 0, contracted.size(-2) // 2),
)
contracted = torch.cat((contracted, padded), dim=-3)
samples = contracted
return samples.squeeze(-3)[..., :duration, :]
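
The docstring example above can be checked against a naive sequential loop,
which reproduces the paths shown there; a minimal sketch (plain PyTorch;
``naive_paths`` is an illustrative name, not part of this PR)::

    # illustrative sketch, not part of pyro.distributions.hmm
    import torch

    def naive_paths(samples):
        # samples[t, s] is the state sampled at time t given previous state s
        duration, state_dim = samples.shape
        prev = torch.arange(state_dim)  # one path per possible start state
        rows = []
        for t in range(duration):
            prev = samples[t][prev]  # advance every path by one step
            rows.append(prev)
        return torch.stack(rows)

    samples = torch.tensor(
        [[0, 1, 1], [1, 0, 2], [2, 1, 0], [0, 2, 1], [1, 1, 0]]
    )
    print(naive_paths(samples))
    # tensor([[0, 1, 1],
    #         [1, 0, 0],
    #         [1, 2, 2],
    #         [2, 1, 1],
    #         [0, 1, 1]])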


def _sequential_gaussian_tensordot(gaussian):
"""
Integrates a Gaussian ``x`` whose rightmost batch dimension is time, computes::
@@ -276,7 +353,7 @@ class DiscreteHMM(HiddenMarkovModel):
    distribution.

    This uses [1] to parallelize over time, achieving O(log(time)) parallel
    complexity for computing :meth:`log_prob` and :meth:`filter`.
    complexity for computing :meth:`log_prob`, :meth:`filter`, and :meth:`sample`.

    The event_shape of this distribution includes time on the left::

@@ -292,10 +369,6 @@ class DiscreteHMM(HiddenMarkovModel):
        # homogeneous + homogeneous case:
        event_shape = (1,) + observation_dist.event_shape

    The :meth:`sample` method is sequential (not parallized), slow, and memory
    inefficient. It is intended for data generation only and is not recommended
    during inference.

    **References:**

    [1] Simo Sarkka, Angel F. Garcia-Fernandez (2019)
@@ -441,13 +514,13 @@ def sample(self, sample_shape=torch.Size()):
        x = Categorical(logits=init_logits).sample()

        # Sample hidden states over time.
        trans_shape = self.batch_shape + (self.duration, S, S)
        trans_shape = (
            torch.Size(sample_shape) + self.batch_shape + (self.duration, S, S)
        )
        trans_logits = self.transition_logits.expand(trans_shape)
        xs = []
        for t in range(self.duration):
            x = Categorical(logits=Vindex(trans_logits)[..., t, x, :]).sample()
            xs.append(x)
        x = torch.stack(xs, dim=-1)
        # Draw one transition sample per (time step, previous state) in
        # parallel, stitch them into per-initial-state Markov paths, then
        # select the path belonging to the sampled initial state.
        xs = Categorical(logits=trans_logits).sample()
        xs = _sequential_index(xs)
        x = Vindex(xs)[..., :, x]

        # Sample observations conditioned on hidden states.
        # Note the simple sample-then-slice approach here generalizes to all
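
With this change, ``DiscreteHMM.sample`` draws entire batches of sequences at
once instead of looping over time steps; a minimal usage sketch (the shapes
and parameter values here are made up for illustration)::

    # illustrative usage sketch; parameter values are arbitrary
    import torch
    import pyro.distributions as dist

    state_dim, duration = 3, 5
    hmm = dist.DiscreteHMM(
        initial_logits=torch.randn(state_dim),
        transition_logits=torch.randn(state_dim, state_dim),
        observation_dist=dist.Normal(torch.randn(state_dim), 1.0),
        duration=duration,
    )
    y = hmm.sample(torch.Size([10]))  # ten sequences sampled in parallel
    assert y.shape == (10, duration)  # sample_shape + event_shape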