# Interactive prior predictive checks

This notebook demonstrates how to interactively examine model priors using [ipywidgets](https://ipywidgets.readthedocs.io/en/stable/).

⚠️ This notebook is intended to be run interactively. Please run locally or [Open in Colab](https://colab.research.google.com/github/pyro-ppl/pyro/blob/dev/tutorial/source/prior-predictive.ipynb).

The first step in [Bayesian workflow](https://arxiv.org/abs/2011.01808) is to create a model. The second step is to check prior samples from the model. This notebook shows how to interactively check prior samples and tune the parameters of the top-level prior distributions while visualizing model outputs.

#### Summary

- Wrap your model in a plotting function.
- Use [ipywidgets.interact()](https://ipywidgets.readthedocs.io/en/stable/examples/Using%20Interact.html) to create sliders for each parameter of your prior.
- For expensive models, cache sampling using the [ResamplingCache](https://docs.pyro.ai/en/stable/infer.util.html#pyro.infer.resampler.ResamplingCache).

In [1]:
import os
from ipywidgets import interact, FloatSlider
import matplotlib.pyplot as plt
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer.resampler import ResamplingCache

assert pyro.__version__.startswith('1.8.1')
smoke_test = ('CI' in os.environ)  # for CI testing only

In [2]:
def model(T: int = 1000, data=None):
    # Sample parameters from the prior.
    df = pyro.sample("df", dist.LogNormal(0, 1))
    p_scale = pyro.sample("p_scale", dist.LogNormal(0, 1))  # process noise
    m_scale = pyro.sample("m_scale", dist.LogNormal(0, 1))  # measurement noise
    
    # Simulate a time series.
    with pyro.plate("dt", T):
        process_noise = pyro.sample("process_noise", dist.StudentT(df, 0, p_scale))
    trend = pyro.deterministic("trend", process_noise.cumsum(-1))
    with pyro.plate("t", T):
        return pyro.sample("obs", dist.Normal(trend, m_scale), obs=data)

In [3]:
def plot_trajectory(df=1.0, p_scale=1.0, m_scale=1.0):
    pyro.set_rng_seed(12345)
    data = {
        "df": torch.as_tensor(df),
        "p_scale": torch.as_tensor(p_scale),
        "m_scale": torch.as_tensor(m_scale),
    }
    trajectory = poutine.condition(model, data)()
    plt.figure(figsize=(8, 4)).patch.set_color("white")
    plt.plot(trajectory)
    plt.xlabel("time")
    plt.ylabel("obs")

Now we can examine what model trajectories look like for particular values of the top-level latent variables.

In [4]:
interact(
    plot_trajectory,
    df=FloatSlider(value=1.0, min=0.01, max=10.0),
    p_scale=FloatSlider(value=0.1, min=0.01, max=1.0),
    m_scale=FloatSlider(value=1.0, min=0.01, max=10.0),
);


But to tune the parameters of our priors, we'd like to look at an ensemble of trajectories, each with top-level parameters sampled from the current prior. Let's rewrite our model so we can pass in the prior parameters.

In [5]:
def model2(T: int = 1000, data=None, df0=0, df1=1, p0=0, p1=1, m0=0, m1=1):
    # Sample parameters from the prior.
    df = pyro.sample("df", dist.LogNormal(df0, df1))
    p_scale = pyro.sample("p_scale", dist.LogNormal(p0, p1))  # process noise
    m_scale = pyro.sample("m_scale", dist.LogNormal(m0, m1))  # measurement noise
    
    # Simulate a time series.
    with pyro.plate("dt", T):
        process_noise = pyro.sample("process_noise", dist.StudentT(df, 0, p_scale))
    trend = pyro.deterministic("trend", process_noise.cumsum(-1))
    with pyro.plate("t", T):
        return pyro.sample("obs", dist.Normal(trend, m_scale), obs=data)

In [6]:
def plot_trajectories(**kwargs):
    pyro.set_rng_seed(12345)
    with pyro.plate("trajectories", 20, dim=-2):
        trajectories = model2(**kwargs)
    plt.figure(figsize=(8, 5)).patch.set_color("white")
    plt.plot(trajectories.T)
    plt.xlabel("time")
    plt.ylabel("obs")

In [7]:
interact(
    plot_trajectories,
    df0=FloatSlider(value=0.0, min=-5, max=5),
    df1=FloatSlider(value=1.0, min=0.1, max=10),
    p0=FloatSlider(value=0.0, min=-5, max=5),
    p1=FloatSlider(value=1.0, min=0.1, max=10),
    m0=FloatSlider(value=0.0, min=-5, max=5),
    m1=FloatSlider(value=1.0, min=0.1, max=10),
);


Yikes! It looks like our initial priors generated very weird trajectories, but we can slide to find better priors.  Try increasing `df0`.
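
One plausible reason the initial trajectories look so erratic is that `LogNormal(0, 1)` puts most of its mass on small values of `df`, and Student's t has infinite variance whenever `df <= 2`. The sketch below uses only the standard library to compute how much prior mass falls in that heavy-tailed region for a few values of `df0`. The helper `prob_df_below` is hypothetical and not part of Pyro.

```python
import math
from statistics import NormalDist


def prob_df_below(threshold: float, df0: float, df1: float) -> float:
    """P(df < threshold) when df ~ LogNormal(df0, df1).

    Hypothetical helper for illustration; it is not part of Pyro.
    If df is log-normal, then log(df) ~ Normal(df0, df1).
    """
    return NormalDist(mu=df0, sigma=df1).cdf(math.log(threshold))


# Student's t has infinite variance for df <= 2, so this is the prior
# mass placed on very heavy-tailed process noise.
for df0 in [0.0, 2.0, 4.0]:
    p = prob_df_below(2.0, df0, 1.0)
    print(f"df0={df0}: median df={math.exp(df0):.1f}, P(df < 2)={p:.4f}")
```

With `df0=0` roughly three quarters of the prior mass sits below `df = 2`, while with `df0=4` that mass drops well below 1%, which is consistent with the trajectories calming down as the slider moves right.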

## ResamplingCache

For more expensive simulations, sampling may be too slow to regenerate samples every time a slider moves. As a computational trick, we can cache samples drawn from past priors and reuse them for similar priors, provided we importance sample or resample. Pyro provides a [ResamplingCache](https://docs.pyro.ai/en/stable/infer.util.html#pyro.infer.resampler.ResamplingCache) to aid in interactively visualizing expensive models.
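
To make the resampling idea concrete, here is a minimal standard-library sketch of importance resampling for a single log-normal parameter. It illustrates the general technique only and is not how `ResamplingCache` is implemented. The names `lognormal_logpdf` and `resample` are hypothetical, and the cached values stand in for parameter draws that would each be paired with an expensive simulation.

```python
import math
import random


def lognormal_logpdf(x: float, mu: float, sigma: float) -> float:
    """Log density of LogNormal(mu, sigma) at x > 0."""
    z = (math.log(x) - mu) / sigma
    return -0.5 * z * z - math.log(x * sigma * math.sqrt(2 * math.pi))


def resample(cached, old, new, num_samples, rng):
    """Importance-resample cached draws from prior `old` to approximate `new`.

    `cached` holds parameter values drawn from LogNormal(*old). Each value
    is weighted by the density ratio new(x) / old(x), then values are
    redrawn with replacement in proportion to those weights.
    """
    weights = [
        math.exp(lognormal_logpdf(x, *new) - lognormal_logpdf(x, *old))
        for x in cached
    ]
    return rng.choices(cached, weights=weights, k=num_samples)


rng = random.Random(0)
old_prior = (0.0, 1.0)  # (mu, sigma) used when the cache was filled
new_prior = (0.5, 1.0)  # (mu, sigma) after moving a slider
cached = [rng.lognormvariate(*old_prior) for _ in range(5000)]
reused = resample(cached, old_prior, new_prior, num_samples=2000, rng=rng)
```

No new draws are simulated in `resample`; it only reweights what is already in the cache. The approximation degrades as the new prior moves far from the old one, because a few cached samples then dominate the weights.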

We'll start with our original model and create two wrappers: a cached model and a plotting function.

In [8]:
def conditioned_model(df, p_scale, m_scale):
    # Note we unsqueeze to support batching.
    df = torch.as_tensor(df)[..., None]
    p_scale = torch.as_tensor(p_scale)[..., None]
    m_scale = torch.as_tensor(m_scale)[..., None]
    
    data = {"df": df, "p_scale": p_scale, "m_scale": m_scale}
    with pyro.plate("trajectories", len(df), dim=-2):
        return poutine.condition(model, data)()

cache = ResamplingCache(conditioned_model, batch_size=100)

The `.sample()` method takes a dictionary of prior distributions. Note that these prior distributions cannot depend on each other.

In [9]:
def plot_cached(df0, df1, p0, p1, m0, m1):
    prior = {
        "df": dist.LogNormal(df0, df1),
        "p_scale": dist.LogNormal(p0, p1),
        "m_scale": dist.LogNormal(m0, m1),
    }
    samples = cache.sample(prior, num_samples=20)
    trajectories = torch.stack(samples)
    plt.figure(figsize=(8, 5)).patch.set_color("white")
    plt.plot(trajectories.T)
    plt.xlabel("time")
    plt.ylabel("obs")

In [10]:
interact(
    plot_cached,
    df0=FloatSlider(value=0.0, min=-5, max=5),
    df1=FloatSlider(value=1.0, min=0.1, max=10),
    p0=FloatSlider(value=0.0, min=-5, max=5),
    p1=FloatSlider(value=1.0, min=0.1, max=10),
    m0=FloatSlider(value=0.0, min=-5, max=5),
    m1=FloatSlider(value=1.0, min=0.1, max=10),
);


After deciding on good prior parameters, we can then hard-code those into the model:

In [11]:
def model(T: int = 1000, data=None):
    df = pyro.sample("df", dist.LogNormal(4, 1))  # <-- changed 0 to 4
    p_scale = pyro.sample("p_scale", dist.LogNormal(1, 1))  # <-- changed 0 to 1
    m_scale = pyro.sample("m_scale", dist.LogNormal(0, 1))

    with pyro.plate("dt", T):
        process_noise = pyro.sample("process_noise", dist.StudentT(df, 0, p_scale))
    trend = pyro.deterministic("trend", process_noise.cumsum(-1))
    with pyro.plate("t", T):
        return pyro.sample("obs", dist.Normal(trend, m_scale), obs=data)
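
As a quick non-interactive sanity check, it can help to translate the hard-coded log-normal parameters back into the natural scale. The sketch below is a standard-library illustration (the helper `lognormal_summary` is hypothetical, not a Pyro function) that reports the median and central 90% interval implied by each prior in the model above.

```python
import math
from statistics import NormalDist


def lognormal_summary(mu: float, sigma: float, mass: float = 0.9):
    """Return (lower, median, upper) for LogNormal(mu, sigma).

    The interval is the central one containing `mass` of the probability.
    """
    z = NormalDist().inv_cdf(0.5 + mass / 2)
    return math.exp(mu - z * sigma), math.exp(mu), math.exp(mu + z * sigma)


# (mu, sigma) pairs copied from the hard-coded model above.
priors = {"df": (4, 1), "p_scale": (1, 1), "m_scale": (0, 1)}
for name, (mu, sigma) in priors.items():
    lower, median, upper = lognormal_summary(mu, sigma)
    print(f"{name}: median={median:.2f}, 90% interval=({lower:.2f}, {upper:.2f})")
```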