# High-dimensional Bayesian workflow: autoguides and reparametrizers

This tutorial describes a workflow for incrementally building pipelines to analyze high-dimensional data in Pyro. This workflow has evolved over a few years of applying Pyro to models with $10^5$ or more latent variables. We build on [Gelman et al. (2020)](https://arxiv.org/abs/2011.01808)'s concept of *Bayesian workflow*, and focus on aspects particular to high-dimensional models: approximate inference and numerical stability. While the individual components of the pipeline deserve their own tutorials, this tutorial focuses on incrementally combining those components.


Workflow efficiency demands that code changes to upstream components don't break previous coding effort on downstream components. Pyro's approaches to this challenge include strategies for variational approximations ([pyro.infer.autoguide](https://docs.pyro.ai/en/stable/infer.autoguide.html)) and strategies for transforming model coordinate systems to improve geometry ([pyro.infer.reparam](https://docs.pyro.ai/en/stable/infer.reparam.html)).

#### Summary

- Great models can only be achieved by iterative development.
- Iterate quickly by building a pipeline that is robust to code changes.
- Start with a simple model and [mean-field inference](https://docs.pyro.ai/en/dev/infer.autoguide.html#autonormal).
- Avoid NaNs by intelligently [initializing](https://docs.pyro.ai/en/dev/infer.autoguide.html#module-pyro.infer.autoguide.initialization) and [.clamp()](https://pytorch.org/docs/stable/generated/torch.clamp.html)ing.
- [Reparametrize](https://docs.pyro.ai/en/dev/infer.reparam.html) the model to improve geometry.
- Create a custom variational family by combining [AutoGuides](https://docs.pyro.ai/en/dev/infer.autoguide.html) or [EasyGuides](https://docs.pyro.ai/en/dev/contrib.easyguide.html).

#### Table of contents

1. [Clean the data](#Clean-the-data)
2. [Create a generative model](#Create-a-generative-model)
3. [Create an initialization heuristic](#Create-an-initialization-heuristic)
4. [Sanity check using mean-field inference](#Sanity-check-using-mean-field-inference)
5. [Reparametrize the model](#Reparametrize-the-model)
6. [Customize the variational family (autoguides, easyguides, custom guides)](#Customize-the-variational-family)

## Overview

Consider the problem of sampling from the posterior distribution of a probabilistic model with $10^5$ or more continuous latent variables, but whose data fits entirely in memory.
(For larger datasets, consider [amortized variational inference](http://pyro.ai/examples/svi_part_ii.html).) Inference in such high-dimensional models can be challenging even when posteriors are known to be [unimodal](https://en.wikipedia.org/wiki/Unimodality) or even [log-concave](https://arxiv.org/abs/1404.5886), due to correlations among latent variables.

To perform inference in such high-dimensional models in Pyro, we have evolved a [workflow](https://arxiv.org/abs/2011.01808) to incrementally build data analysis pipelines combining variational inference, MCMC, reparametrization effects, and ad-hoc initialization strategies. Our workflow is summarized as a sequence of steps, where validation after any step might suggest backtracking to change design decisions at a previous step.

1. Clean the data.
2. Create a generative model.
3. Create an initialization heuristic.
4. Sanity check using MAP or mean-field inference.
5. Reparameterize the model, evaluating results under mean field VI.
6. Customize the variational family (autoguides, easyguides, custom guides).

The crux of efficient workflow is to ensure changes don't break your pipeline. That is, after you build a number of pipeline stages, validate results, and decide to change one component in the pipeline, you'd like to minimize code changes needed in other components. The remainder of this tutorial describes these steps individually, then describes nuances of interactions among stages, then provides an example.

In [None]:
import os
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.distributions import constraints
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import (
    AutoDelta,
    AutoNormal,
    AutoMultivariateNormal,
    AutoLowRankMultivariateNormal,
    AutoGuideList,
    AutoStructured,
    AutoGaussian,
    init_to_sample,  # used as a fallback in custom init_loc_fn below
)
from pyro.infer.reparam import (
    AutoReparam,
    LocScaleReparam,
    HaarReparam,
    DiscreteCosineReparam,
    NeuTraReparam,
)
from pyro.optim import ClippedAdam
from pyro.ops.special import sparse_multinomial_likelihood
import matplotlib.pyplot as plt

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.7.0')

## Clean the data <a name="Clean-the-data"/>

While in principle you could be Bayesian about everything, it's best to focus probabilistic inference on the truly uncertain parts of your problem and to make unambiguous decisions deterministically during data preprocessing.
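For example, rather than being Bayesian about which rare categories are noise, you might deterministically drop them before modeling. This is a minimal sketch, assuming a dense `counts` tensor of shape `(time, place, category)`; the helper `drop_rare_categories` and its `min_total` threshold are hypothetical, not part of Pyro.

```python
import torch


def drop_rare_categories(counts, min_total=10):
    """Hypothetical helper: drop categories (last dim) observed fewer than min_total times.

    counts has shape (T, P, S). Returns the filtered counts and a boolean mask
    over the original S categories, so downstream code can map results back.
    """
    totals = counts.sum(dim=(0, 1))  # total observations per category, shape (S,)
    keep = totals >= min_total
    return counts[..., keep], keep


# Toy data with T=2 time steps, P=1 place, S=3 categories; the last category is rare.
counts = torch.tensor([[[5.0, 7.0, 0.0]], [[6.0, 2.0, 1.0]]])
filtered, keep = drop_rare_categories(counts, min_total=5)
print(filtered.shape)  # torch.Size([2, 1, 2])
```

Keeping the mask makes the preprocessing decision explicit and reversible, so it can be revisited later in the workflow without touching the model.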

## Create a generative model

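The first step to using Pyro is creating a generative model, either a Python function or a [pyro.nn.PyroModule](https://docs.pyro.ai/en/dev/nn.html#pyro.nn.module.PyroModule). Start simple. Start with a shallow hierarchy and later add latent variables to share statistical strength. Start with a slice of your data, then add a [plate](https://docs.pyro.ai/en/dev/primitives.html#pyro.primitives.plate) over multiple slices. Start with simple distributions like [Normal](https://docs.pyro.ai/en/dev/distributions.html#normal), [LogNormal](https://docs.pyro.ai/en/dev/distributions.html#lognormal), [Poisson](https://docs.pyro.ai/en/dev/distributions.html#poisson), and [Multinomial](https://docs.pyro.ai/en/dev/distributions.html#multinomial), then consider overdispersed versions like [StudentT](https://docs.pyro.ai/en/dev/distributions.html#studentt), [Gamma](https://docs.pyro.ai/en/dev/distributions.html#gamma), [GammaPoisson](https://docs.pyro.ai/en/dev/distributions.html#gammapoisson)/[NegativeBinomial](https://docs.pyro.ai/en/dev/distributions.html#negativebinomial), and [DirichletMultinomial](https://docs.pyro.ai/en/dev/distributions.html#dirichletmultinomial). Keep your model simple and readable so you can share it and get feedback from domain experts. Use [weakly informative priors](http://www.stat.columbia.edu/~gelman/presentations/weakpriorstalk.pdf).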

## Create an initialization heuristic

In high-dimensional models, convergence can be slow and NaNs arise easily, even when sampling from [weakly informative priors](http://www.stat.columbia.edu/~gelman/presentations/weakpriorstalk.pdf). We recommend heuristically initializing a point estimate for each latent variable, aiming for a value of the right order of magnitude. Often you can initialize to a simple statistic of the data, e.g. a mean or standard deviation.

Pyro's autoguides provide a number of [initialization strategies](https://docs.pyro.ai/en/dev/infer.autoguide.html#module-pyro.infer.autoguide.initialization) for initializing the location parameter of many variational families, specified as `init_loc_fn`. You can create a custom initializer as a function that accepts a Pyro sample site dict and returns an initial value for that site, which should lie in `site["fn"].support`. The function can dispatch on `site["name"]` and use `site["fn"]`, e.g. `site["fn"].shape()`, `site["fn"].support`, `site["fn"].mean`, or sampling via `site["fn"].sample()`. For example:

In [None]:
def init_loc_fn(site):
    if site["name"] == "x":
        return torch.zeros(site["fn"].shape())
    if site["name"] == "y":
        return torch.ones(site["fn"].shape())
    if site["name"] == "z":
        return torch.randn(site["fn"].shape()).mul_(0.01)
    return init_to_sample(site)  # fallback

As you evolve a model, you'll add, remove, and rename latent variables. We find it useful to raise an error with a reminder to update the `init_loc_fn` whenever the model changes.

In [None]:
def init_loc_fn(site):
    if site["name"] == "x":
        return torch.zeros(site["fn"].shape())
    if site["name"] == "y":
        return torch.ones(site["fn"].shape())
    # ...
    raise NotImplementedError(f"TODO initialize latent variable {site['name']}")

## Sanity check using mean field inference

Mean field Normal inference is cheap and robust, and is a good way to sanity check your posterior point estimate, even if the posterior uncertainty is implausibly narrow. We recommend starting with an [AutoNormal](https://docs.pyro.ai/en/dev/infer.autoguide.html#autonormal) guide, and possibly setting `init_scale` to a small value like `init_scale=0.01` or `init_scale=0.001`.

Note that while MAP estimation via [AutoDelta](https://docs.pyro.ai/en/dev/infer.autoguide.html#autodelta) is even cheaper and more robust than mean-field `AutoNormal`, `AutoDelta` is coordinate-system dependent and is not invariant to reparametrization. Because in our experience most models benefit from some reparametrization, we recommend `AutoNormal` over `AutoDelta`, since `AutoNormal` is less sensitive to reparametrization.
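A one-variable worked example shows why a MAP point estimate depends on the coordinate system. Take $x \sim \mathrm{LogNormal}(0, 1)$ and compare its mode in $x$ coordinates with its mode in $u = \log x$ coordinates, where $u \sim \mathrm{Normal}(0, 1)$:

$$
\log p_X(x) = -\log x - \tfrac{1}{2}(\log x)^2 + \text{const}
\quad\Longrightarrow\quad
\frac{d}{dx}\log p_X(x) = -\frac{1 + \log x}{x} = 0
\quad\Longrightarrow\quad
\hat{x}_{\mathrm{MAP}} = e^{-1}
$$

$$
\log p_U(u) = -\tfrac{1}{2}u^2 + \text{const}
\quad\Longrightarrow\quad
\hat{u}_{\mathrm{MAP}} = 0
\quad\Longrightarrow\quad
e^{\hat{u}_{\mathrm{MAP}}} = 1
$$

The two estimates disagree ($e^{-1} \ne 1$) because the change of variables multiplies the density by the Jacobian $|dx/du| = e^{u}$, which shifts the mode. Quantiles such as the median ($x = 1$ in both coordinate systems) are unaffected.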

In [None]:
guide = AutoNormal(model, init_loc_fn=init_loc_fn, init_scale=0.01)

def fit_svi(model, guide, data, lr=0.01, num_steps=1001):
    # Decay the learning rate by a total factor of 10 over the course of training.
    optim = ClippedAdam({"lr": lr, "lrd": 0.1 ** (1 / num_steps)})
    svi = SVI(model, guide, optim, Trace_ELBO())
    losses = []
    for step in range(num_steps):
        loss = svi.step(data)
        losses.append(loss)
        if step % 100 == 0:
            print(f"step {step} loss = {loss:0.6g}")
    return losses

## Reparametrize the model

Reparametrizing a model preserves its distribution while changing its geometry. Reparametrizing is simply a change of coordinates. When reparametrizing we aim to warp a model's geometry to remove correlations and to lift inconvenient topological manifolds into simpler higher dimensional flat Euclidean space.

Whereas many probabilistic programming languages require users to rewrite models to change coordinates, Pyro implements a library of about 15 different reparametrization effects, including decentering (Gorinova et al. 2020), Haar wavelet transforms, and neural transport (Hoffman et al. 2019), as well as strategies to automatically apply effects and machinery to create custom reparametrization effects. Using these reparametrizers you can separate modeling from inference: first specify a model in a form that is natural to domain experts, then, in inference code, reparametrize the model to have a geometry that is more amenable to variational inference.

For example, consider a mixed effects model.

In [None]:
def model(data):
    i_plate = pyro.plate("i", data.shape[-2], dim=-2)
    j_plate = pyro.plate("j", data.shape[-1], dim=-1)
    a = pyro.sample("a", dist.Normal(0, 1))
    b = pyro.sample("b", dist.Normal(0, 1))
    with i_plate:
        ai = pyro.sample("ai", dist.Normal(a, 1))
    with j_plate:
        bj = pyro.sample("bj", dist.Normal(b, 1))
    with i_plate, j_plate:
        pyro.sample("obs", dist.Normal(ai + bj, 1), obs=data)

The geometry might improve if we changed
```diff
- ai = pyro.sample("ai", dist.Normal(a, 1))
+ ai = pyro.sample("ai", dist.Normal(0, 1)) + a
```
but that might make the model less interpretable. Instead, we can leave the model as written and reparametrize it with `poutine.reparam`:

In [None]:
reparam_model = poutine.reparam(model, config={"ai": LocScaleReparam()})

Alternatively, we can automatically apply a set of recommended reparametrizers using the `AutoReparam` strategy:

In [None]:
reparam_model = AutoReparam()(model)

## Customize the variational family

When creating a new model, we recommend starting with mean field variational inference using an [AutoNormal](https://docs.pyro.ai/en/dev/infer.autoguide.html#autonormal) guide. This mean field guide is good at finding the neighborhood of your model's mode, but it ignores correlations between latent variables. A first step in capturing correlations is to reparametrize the model as above: using a `LocScaleReparam` or `HaarReparam` already allows the guide to capture some correlations among latent variables.

The next step toward capturing correlated uncertainty is to customize the variational family by trying other autoguides, building on [EasyGuide](https://docs.pyro.ai/en/latest/contrib.easyguide.html), or creating a custom guide using Pyro primitives. We recommend increasing guide complexity gradually via these steps:
1. Start with an [AutoNormal](https://docs.pyro.ai/en/dev/infer.autoguide.html#autonormal) guide.
2. Try [AutoLowRankMultivariateNormal](https://docs.pyro.ai/en/dev/infer.autoguide.html#autolowrankmultivariatenormal), which can model the principal components of correlated uncertainty. (For models with only ~100 latent variables you might also try [AutoMultivariateNormal](https://docs.pyro.ai/en/dev/infer.autoguide.html#automultivariatenormal) or [AutoGaussian](https://docs.pyro.ai/en/dev/infer.autoguide.html#autogaussian).)
3. Try combining multiple guides with [AutoGuideList](https://docs.pyro.ai/en/dev/infer.autoguide.html#autoguidelist). For example, if `AutoLowRankMultivariateNormal` is too expensive for all the latent variables, you can use `AutoGuideList` to combine an `AutoLowRankMultivariateNormal` guide over a few top-level global latent variables with a cheaper `AutoNormal` guide over the more numerous local latent variables.
4. Try using `AutoGuideList` to combine an autoguide with a custom guide function built using `pyro.sample`, `pyro.param`, and `pyro.plate`. Given a `partial_guide()` function that covers just a few latent variables, you can `AutoGuideList.append(partial_guide)` just as you append autoguides.
5. Consider customizing one of Pyro's autoguides that leverage model structure, e.g. [AutoStructured](https://docs.pyro.ai/en/latest/infer.autoguide.html#autostructured), [AutoNormalMessenger](https://docs.pyro.ai/en/latest/infer.autoguide.html#autonormalmessenger), [AutoHierarchicalNormalMessenger](https://docs.pyro.ai/en/latest/infer.autoguide.html#autohierarchicalnormalmessenger), or [AutoRegressiveMessenger](https://docs.pyro.ai/en/latest/infer.autoguide.html#autoregressivemessenger).
6. For models with local correlations, consider building on [EasyGuide](https://docs.pyro.ai/en/latest/contrib.easyguide.html), a framework for building guides over groups of variables.

While fully custom guides built from `pyro.sample` primitives offer the most flexible variational family, they are also the most brittle, because each change to the model or reparametrizer requires a corresponding change to the guide. The author recommends avoiding completely low-level guides and instead using `AutoGuide` or `EasyGuide` for at least some parts of the model, thereby speeding up model iteration.

## Example

In [None]:
dataset = torch.load(os.path.expanduser("~/pyro-cov/results/nextstrain.data.pt"))
def summarize(x, name=""):
    if isinstance(x, dict):
        for k, v in sorted(x.items()):
            summarize(v, name + "." + k if name else k)
    elif isinstance(x, torch.Tensor):
        print(f"{name}: {type(x).__name__} of shape {tuple(x.shape)}")
    elif isinstance(x, list):
        print(f"{name}: {type(x).__name__} of length {len(x)}")
    else:
        print(f"{name}: {type(x).__name__}")
summarize(dataset)

In [None]:
def model(dataset):
    features = dataset["features"]
    counts = dataset["counts"]
    assert features.shape[0] == counts.shape[-1]
    S, M = features.shape
    T, P, S = counts.shape
    time = torch.arange(float(T)) * 5.5 / dataset["time_step_days"]
    time -= time.mean()
    strain_plate = pyro.plate("strain", S, dim=-1)
    place_plate = pyro.plate("place", P, dim=-2)
    time_plate = pyro.plate("time", T, dim=-3)

    # Model each region as multivariate logistic growth.
    rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2))
    init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2))

    coef = pyro.sample("coef", dist.Laplace(torch.zeros(M), 0.01).to_event(1))
    with strain_plate:
        rate_loc = pyro.deterministic("rate_loc", 0.01 * coef @ features.T)
    with place_plate, strain_plate:
        rate = pyro.sample("rate", dist.Normal(rate_loc, rate_scale))
        init = pyro.sample("init", dist.Normal(0, init_scale))
    logits = init + rate * time[:, None, None]
    
    # Observe sequences via a multinomial likelihood.
    with time_plate, place_plate:
        pyro.sample(
            "obs",
            dist.Multinomial(logits=logits.unsqueeze(-2), validate_args=False),
            obs=counts.unsqueeze(-2),
        )

In [None]:
def model(dataset):
    features = dataset["features"]
    counts = dataset["counts"]
    sparse_counts = dataset["sparse_counts"]
    assert features.shape[0] == counts.shape[-1]
    S, M = features.shape
    T, P, S = counts.shape
    time = torch.arange(float(T)) * 5.5 / dataset["time_step_days"]
    time -= time.mean()

    # Model each region as multivariate logistic growth.
    rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2))
    init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2))
    coef = pyro.sample("coef", dist.Laplace(torch.zeros(M), 0.01).to_event(1))
    with pyro.plate("strain", S, dim=-1):
        rate_loc = pyro.deterministic("rate_loc", 0.01 * coef @ features.T)
        with pyro.plate("place", P, dim=-2):
            rate = pyro.sample("rate", dist.Normal(rate_loc, rate_scale))
            init = pyro.sample("init", dist.Normal(0, init_scale))
    logits = init + rate * time[:, None, None]

    # Observe sequences via a cheap sparse multinomial likelihood.
    logits = logits.log_softmax(-1)
    sparse_logits = logits[sparse_counts["index"].unbind(0)]
    pyro.factor(
        "obs",
        sparse_multinomial_likelihood(
            sparse_counts["total"], sparse_logits, sparse_counts["value"]
        )
    )

In [None]:
def fit_svi(lr=0.1, num_steps=1001, log_every=1):
    pyro.clear_param_store()
    pyro.set_rng_seed(20211205)
    guide = AutoNormal(model, init_scale=0.01)
    optim = ClippedAdam({"lr": lr, "lrd": 0.1 ** (1 / num_steps)})
    svi = SVI(model, guide, optim, Trace_ELBO())
    losses = []
    num_obs = dataset["counts"].ne(0).sum().item()
    for step in range(num_steps):
        loss = svi.step(dataset) / num_obs
        losses.append(loss)
        if step % log_every == 0:
            print(f"step {step: >4d} loss = {loss:0.6g}")
    plt.plot(losses)
    return guide

In [None]:
%%time
guide = fit_svi(log_every=10, num_steps=101)

In [None]:
def plot_volcano(guide):
    with torch.no_grad(), poutine.mask(mask=False), pyro.plate("particles", 100, dim=-3):
        coef = poutine.trace(guide).get_trace().nodes["coef"]["value"]
    coef = coef.squeeze()
    mean = coef.mean(0)
    std = coef.std(0)
    z_score = mean.abs() / std
    plt.scatter(mean.exp().numpy(), z_score.numpy(), lw=0, s=5, alpha=0.5, color="black")
    plt.yscale("symlog")
    plt.xscale("log")
    plt.ylim(0, None)
    plt.xlabel("$R_m/R_{wt}$")
    plt.ylabel("z-score")
    plt.title(f"Volcano plot of {len(mean)} mutations")
plot_volcano(guide)