# Writing guides using EasyGuide

### Summary:
- For simple black-box guides, try `pyro.contrib.autoguide`.
- For more complex guides, use `pyro.contrib.easyguide`.
- Decorate with `@easy_guide(model)`.
- Select multiple model sites using `group = self.group(match="my_regex")`.
- Guide a group of sites by a single distribution using `group.sample(...)`.
- Inspect the shape of the concatenated group using `group.batch_shape`, `group.event_shape`, etc.
- Use `self.plate(...)` instead of `pyro.plate(...)`.
- To be compatible with subsampling, pass the `event_dim` arg to `pyro.param(...)`.
- To MAP estimate model site "foo", use `foo = self.map_estimate("foo")`.

In [None]:
import torch
import pyro
import pyro.distributions as dist
from pyro.contrib.easyguide import easy_guide, numel
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from torch.distributions import constraints

# Enable Pyro's runtime validation (e.g. checks on distribution arguments and
# observed values) so modeling mistakes surface early rather than silently.
pyro.enable_validation(True)

Consider a time-series model with a slowly-varying continuous latent state and Bernoulli observations with a logistic link function.

In [None]:
def model(batch, full_size, num_time_steps=None):
    """Time-series model with a random-walk latent state and Bernoulli observations.

    Each series has a latent state ``z_t`` that follows a Gaussian random walk
    (initial scale 10, step scale given by the global ``drift``), and each
    observation is ``Bernoulli(logits=z_t)``.

    :param batch: Either a tensor of observations with shape
        ``(batch_size, num_time_steps)`` (used when ``num_time_steps`` is None),
        or a sequence of length ``num_time_steps`` whose entries are the
        observations at each step or ``None`` (``None`` means "sample it").
    :param int full_size: Number of series in the full dataset.
    :param int num_time_steps: Only pass this with the sequence form of
        ``batch``; the whole dataset (``full_size`` series) is then treated as
        one batch. If None, both sizes are read from ``batch.shape``.
    :returns: Tensor of observations with shape ``(batch_size, num_time_steps)``.
    """
    if num_time_steps is None:
        # Tensor input: rows are series, columns are time steps. Read the shape
        # before converting, then split along the time axis so that batch[t]
        # holds the observations of every series at step t.
        batch_size, num_time_steps = batch.shape
        batch = list(batch.unbind(-1))
    else:
        batch = list(batch)
        batch_size = full_size
    drift = pyro.sample("drift", dist.LogNormal(-1, 1))
    with pyro.plate("data", full_size, subsample_size=batch_size):
        z = 0.
        for t in range(num_time_steps):
            z = pyro.sample("state_{}".format(t),
                            dist.Normal(z, 10. if t == 0 else drift))
            batch[t] = pyro.sample("obs_{}".format(t),
                                   dist.Bernoulli(logits=z),
                                   obs=batch[t])
    # Re-assemble the per-step observations into (batch_size, num_time_steps).
    return torch.stack(batch, dim=-1)

In [None]:
# Simulate a synthetic dataset by running the model generatively: passing None
# for every time step makes the model sample its own observations.
full_size, num_time_steps = 100, 52
data = model(batch=[None] * num_time_steps,
             full_size=full_size,
             num_time_steps=num_time_steps)
# Sanity check: one row per series, one column per time step.
assert tuple(data.shape) == (full_size, num_time_steps), data.shape

Consider a possible guide for this model where we point-estimate the `drift` parameter using a `Delta` distribution, and then model the local time series using shared uncertainty but local means, via a `LowRankMultivariateNormal` distribution. There is a single global sample site, which we can model with a `param` and a `sample` statement. Then we define a global pair of uncertainty parameters, `cov_diag` and `cov_factor`, shared by all series. Next we define a local `loc` parameter using `pyro.param(..., event_dim=...)` and sample all of a series' states jointly at an auxiliary sample site. Finally we unpack that joint sample into one `Delta` sample site per time step.

In [None]:
def guide(batch, full_size):
    """Hand-written guide: MAP ``drift`` plus a low-rank Gaussian over each series' states.

    The covariance (``cov_diag``, ``cov_factor``) is shared by all series; the
    mean (``state_loc``) is a separate row per series.

    :param batch: Observation tensor of shape ``(batch_size, num_time_steps)``.
    :param int full_size: Number of series in the full dataset.
    """
    batch_size, num_time_steps = batch.shape

    # MAP estimate the drift.
    drift_loc = pyro.param("drift_loc", lambda: torch.tensor(0.1),
                           constraint=constraints.positive)
    pyro.sample("drift", dist.Delta(drift_loc))

    # Model local states using shared uncertainty + local mean.
    rank = 3
    # Shared across series, so these have only the event shape (num_time_steps,).
    # The diagonal must stay positive for the covariance to be valid.
    cov_diag = pyro.param("state_cov_diag",
                          lambda: torch.full((num_time_steps,), 0.01),
                          constraint=constraints.positive)
    cov_factor = pyro.param("state_cov_factor",
                            lambda: torch.full((num_time_steps, rank), 0.01))
    # NOTE(review): pyro.plate's `subsample` expects a 1-D tensor of row
    # indices, not observations, so a random subsample of size batch_size is
    # drawn here. That means `state_loc` rows are not tied to the actual rows
    # of `batch`. To fix this properly, pass the true row indices via
    # `subsample=indices`; that requires model, guide and train to accept an
    # extra argument.
    with pyro.plate("data", full_size, subsample_size=batch_size):
        # Sample local mean: one row per series, subsampled to the plate indices.
        loc = pyro.param("state_loc",
                         lambda: torch.full((full_size, num_time_steps), 0.5),
                         event_dim=1)
        # torch's argument order is (loc, cov_factor, cov_diag).
        states = pyro.sample("states",
                             dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag),
                             infer={"is_auxiliary": True})
        # Unpack the joint states into one sample site per time step.
        for t in range(num_time_steps):
            pyro.sample("state_{}".format(t), dist.Delta(states[:, t]))

In [None]:
def train(guide, num_epochs=100, batch_size=10):
    """Fit ``guide`` to the module-level ``data`` with SVI over minibatches of rows.

    Clears the param store and reseeds the RNG first, so that different guides
    start from the same state.

    :param guide: Guide callable accepting ``(batch, full_size)``.
    :param int num_epochs: Number of passes over ``data``.
    :param int batch_size: Number of series (rows of ``data``) per SVI step.
    """
    pyro.get_param_store().clear()
    pyro.set_rng_seed(123456789)
    svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
    for epoch in range(num_epochs):
        epoch_loss = 0
        # Step through the rows in chunks. The previous while-loop never
        # advanced `pos` and so never terminated.
        for pos in range(0, len(data), batch_size):
            batch = data[pos:pos + batch_size]
            epoch_loss += svi.step(batch, full_size=len(data))
        print("epoch {} loss = {}".format(epoch, epoch_loss))

In [None]:
# Fit the hand-written guide.
train(guide=guide)

In [None]:
@easy_guide(model)
def guide(self, batch, full_size):
    """EasyGuide version of the low-rank guide.

    MAP-estimates ``drift`` and guides all ``state_*`` sites of each series
    jointly with a low-rank Gaussian. The covariance is shared across series
    and the mean is one row per series.

    :param batch: Observation tensor of shape ``(batch_size, num_time_steps)``.
    :param int full_size: Number of series in the full dataset.
    """
    # MAP estimate the drift; the argument is the *model* site name.
    self.map_estimate("drift")

    # Model local states using shared uncertainty + local mean.
    rank = 3
    group = self.group(match="state_.*")
    # The diagonal must stay positive for the covariance to be valid.
    cov_diag = pyro.param("state_cov_diag",
                          lambda: torch.full(group.event_shape, 0.01),
                          constraint=constraints.positive)
    cov_factor = pyro.param("state_cov_factor",
                            lambda: torch.full(group.event_shape + (rank,), 0.01))
    # NOTE(review): `subsample` expects row indices, not observations, so a
    # random subsample of size len(batch) is drawn here. That means
    # `state_loc` rows are not tied to the actual rows of `batch`; pass the
    # true indices via `subsample=...` once the call signature carries them.
    with self.plate("data", full_size, subsample_size=len(batch)):
        # Sample local mean: one row per series in the full dataset.
        loc = pyro.param("state_loc",
                         lambda: torch.full((full_size,) + tuple(group.event_shape), 0.5),
                         event_dim=1)
        # Automatically sample the joint latent, then unpack and replay model sites.
        # torch's argument order is (loc, cov_factor, cov_diag).
        group.sample("states", dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag))

In [None]:
# Fit the EasyGuide version of the low-rank guide.
train(guide=guide)

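Next consider an amortized guide, in which a neural network maps each observed series to the mean of its latent states, instead of learning a separate `loc` parameter for every series.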

In [None]:
@easy_guide(model)
def guide(self, batch, full_size):
    """Amortized EasyGuide.

    A linear network maps each observed series (a row of ``batch``) to the
    mean of its ``state_*`` latents. The low-rank covariance is shared across
    series, and ``drift`` is MAP-estimated.

    :param batch: Observation tensor of shape ``(batch_size, num_time_steps)``.
    :param int full_size: Number of series in the full dataset.
    """
    # MAP estimate the drift; the argument is the *model* site name.
    self.map_estimate("drift")

    group = self.group(match="state_.*")
    if not hasattr(self, "nn"):
        # Build the network once: one row of observations in, one flattened
        # vector of group latents out.
        self.nn = torch.nn.Linear(batch.size(-1), numel(group.event_shape))
        self.nn.weight.data.fill_(2.0)
        self.nn.bias.data.fill_(-1.)
    # Register the network's weights with Pyro's param store so SVI optimizes
    # them. This may be redundant if EasyGuide already registers submodules
    # (not visible here); SVI deduplicates parameters by tensor identity.
    pyro.module("state_nn", self.nn)
    loc = self.nn(batch)
    rank = 3
    # The diagonal must stay positive for the covariance to be valid.
    cov_diag = pyro.param("state_cov_diag",
                          lambda: torch.full(group.event_shape, 0.01),
                          constraint=constraints.positive)
    cov_factor = pyro.param("state_cov_factor",
                            lambda: torch.full(group.event_shape + (rank,), 0.01))
    # The plate indexes no per-series parameters here, so only the subsample
    # size matters. `subsample` would expect row indices, not observations.
    with self.plate("data", full_size, subsample_size=len(batch)):
        # torch's argument order is (loc, cov_factor, cov_diag).
        group.sample("states", dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag))

In [None]:
# Fit the amortized guide.
train(guide=guide)