# Writing guides using EasyGuide

### Summary:
- For simple black-box guides, try `pyro.contrib.autoguide`.
- For more complex guides, use `pyro.contrib.easyguide`.
- Decorate with `@easy_guide(model)`.
- Select multiple model sites using `group = self.group(match="my_regex")`.
- Guide a group of sites by a single distribution using `group.sample(...)`.
- Inspect concatenated group shape using `group.batch_shape`, `group.event_shape`, etc.
- Use `self.plate(...)` instead of `pyro.plate(...)`.
- To be compatible with subsampling, pass the `event_dim` arg to `pyro.param(...)`.
- To MAP-estimate model site "foo", use `foo = self.map_estimate("foo")`.

In [None]:
import torch
from torch.distributions import constraints
import pyro
import pyro.distributions as dist
from pyro.contrib.easyguide import easy_guide, numel

Consider a time-series model with a slowly-varying continuous latent state and Bernoulli observations with a logistic link function.

In [None]:
def model(batch, full_size, subsample=None):
    """Time-series model: Gaussian random-walk latent state, Bernoulli observations.

    The initial state ``state_0`` has a Normal(0, 10) prior; each later state
    ``state_t`` is Normal around the previous state with scale ``drift``
    (a global LogNormal latent). Each observation ``obs_t`` is Bernoulli with
    logit ``state_t``.

    Args:
        batch: tensor of shape ``(batch_size, num_time_steps)`` holding 0/1
            (float) observations, one time series per row. Column ``t`` is
            observed at site ``obs_t`` and written back into ``batch``.
        full_size: number of time series in the full dataset (plate size).
        subsample: optional 1-D index tensor giving the positions of
            ``batch``'s rows in the full dataset. Defaults to
            ``torch.arange(batch_size)``, which is only right when ``batch``
            is the full dataset (or its leading rows).

    Returns:
        ``batch``.
    """
    batch_size, num_time_steps = batch.shape
    if subsample is None:
        # pyro.plate expects row *indices* here, not the data itself.
        subsample = torch.arange(batch_size)
    drift = pyro.sample("drift", dist.LogNormal(-1, 1))
    with pyro.plate("data", full_size, subsample=subsample):
        z = 0.
        # Start at t=0 so the wide prior on the initial state is actually used
        # and every column of ``batch`` is observed.
        for t in range(num_time_steps):
            z = pyro.sample("state_{}".format(t),
                            dist.Normal(z, 10. if t == 0 else drift))
            # Dim 0 is the "data" plate (batch rows); time runs along dim 1.
            batch[:, t] = pyro.sample("obs_{}".format(t),
                                      dist.Bernoulli(logits=z),
                                      obs=batch[:, t])
    return batch

In [None]:
@easy_guide(model)
def guide(self, batch, full_size, subsample=None):
    """Guide: MAP ``drift`` plus a joint low-rank Gaussian over all ``state_*`` sites.

    The ``state_*`` sites are concatenated into one group and guided by a
    single LowRankMultivariateNormal. Its covariance (``state_cov_factor``,
    ``state_cov_diag``) is shared across data points. Its mean
    ``state_loc`` is a separately learned vector for each of the
    ``full_size`` data points.

    Takes the same arguments as ``model``. ``subsample`` holds the plate
    indices of ``batch``'s rows and defaults to the first ``batch_size`` rows.
    """
    rank = 3  # rank of the shared low-rank covariance factor
    if subsample is None:
        subsample = torch.arange(batch.shape[0])

    # MAP-estimate the global drift; the name must match the model site "drift".
    self.map_estimate("drift")

    group = self.group(match="state_.*")
    cov_diag = pyro.param("state_cov_diag",
                          lambda: torch.full(group.event_shape, 0.01),
                          constraint=constraints.positive)
    # Random (not constant) init: with identical columns, every column gets
    # identical gradients and the factor would stay effectively rank 1.
    cov_factor = pyro.param("state_cov_factor",
                            lambda: torch.randn(group.event_shape + (rank,)) * 0.01)
    # self.plate (not pyro.plate) so group.sample() reuses this same plate.
    with self.plate("data", full_size, subsample=subsample):
        # One row per data point in the full dataset; event_dim=1 lets Pyro
        # subsample the leading (plate) dimension down to the current batch.
        loc = pyro.param("state_loc",
                         lambda: torch.full((full_size,) + group.event_shape, 0.5),
                         event_dim=1)
        # torch's argument order is (loc, cov_factor, cov_diag); use keywords.
        group.sample("states", dist.LowRankMultivariateNormal(
            loc, cov_factor=cov_factor, cov_diag=cov_diag))

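Next consider an amortized guide. Instead of learning a separate mean for every data point, it predicts each data point's latent-state means from its observed data with a small neural network, while still sharing the low-rank covariance.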

In [None]:
@easy_guide(model)
def guide(self, batch, full_size, subsample=None):
    """Amortized guide: a linear layer maps each data row to its state means.

    This guide MAP-estimates ``drift`` and guides all ``state_*`` sites
    jointly with a LowRankMultivariateNormal whose covariance is shared
    across data points. The per-row mean comes from ``self.nn(batch)``
    rather than from a per-data-point parameter.

    Takes the same arguments as ``model``. ``subsample`` holds the plate
    indices of ``batch``'s rows and defaults to the first ``batch_size`` rows.
    """
    rank = 3  # rank of the shared low-rank covariance factor
    num_time_steps = batch.shape[-1]
    if subsample is None:
        subsample = torch.arange(batch.shape[0])

    # MAP-estimate the global drift; the name must match the model site "drift".
    self.map_estimate("drift")

    group = self.group(match="state_.*")
    cov_diag = pyro.param("state_cov_diag",
                          lambda: torch.full(group.event_shape, 0.01),
                          constraint=constraints.positive)
    # Random (not constant) init: with identical columns, every column gets
    # identical gradients and the factor would stay effectively rank 1.
    cov_factor = pyro.param("state_cov_factor",
                            lambda: torch.randn(group.event_shape + (rank,)) * 0.01)

    if not hasattr(self, "nn"):
        # Maps one time series (num_time_steps inputs) to the concatenated
        # means of that series' state sites.
        self.nn = torch.nn.Linear(num_time_steps, group.event_shape.numel())
        self.nn.weight.data.fill_(2.0)
        self.nn.bias.data.fill_(-1.)
    # Register the network's parameters with Pyro so SVI optimizes them.
    pyro.module("state_nn", self.nn)
    # self.plate (not pyro.plate) so group.sample() reuses this same plate.
    with self.plate("data", full_size, subsample=subsample):
        loc = self.nn(batch)
        # torch's argument order is (loc, cov_factor, cov_diag); use keywords.
        group.sample("states", dist.LowRankMultivariateNormal(
            loc, cov_factor=cov_factor, cov_diag=cov_diag))