In [None]:
%matplotlib notebook
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import YouTubeVideo
from tqdm.notebook import tqdm
import corner

import torch
import pyro
print(f"Torch version: {torch.__version__}")
print(f"Pyro version: {pyro.__version__}")

%load_ext autoreload
%autoreload 2
from astro_utils import plot_params, plot_lc, plot_lc_folded, featurize_lc, plot_lc_features, make_train_plots

# Pyro basics

Example: Bayesian linear regression 

$$
y_i = w x_i + b +  \epsilon_i
$$

with $N$ observations $(x_i, y_i)$, Gaussian noise and Gaussian priors for $w$ and  $b$

The generative process in this case is 

- Sample $w \sim \mathcal{N}(\mu_w, \sigma_w^2)$
- Sample $b \sim \mathcal{N}(\mu_b, \sigma_b^2)$
- For $i=1,2,\ldots, N$
    - Sample $y_i \sim \mathcal{N}(w x_i + b, \sigma_\epsilon^2)$
    
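where $\mu_w, \sigma_w, \mu_b, \sigma_b, \sigma_\epsilon$ are hyperparameters

Note that the Pyro model written below does not fix $\sigma_\epsilon$: it gives the noise standard deviation its own prior and infers it together with $w$ and $b$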



## Writing a model

To write the model we use the submodules and primitives

- `pyro.distributions` to define prior/likelihoods 
- `pyro.sample` to define random variables (RV): Expects name, distribution and optionally observations
- `pyro.plate` for conditionally independent RV: Expects name, and size

In [None]:
from pyro.distributions import Normal, Uniform

def model(x, y=None):
    """Generative model for Bayesian linear regression.

    Latent sites: weight "w" and bias "b" with Normal(0, 1) priors, and the
    noise standard deviation "s" with a Uniform(0, 1) prior.

    Args:
        x: Inputs; one observation site is created per element (``len(x)``).
        y: Optional observed targets. When given, the "y" site is conditioned
            on them; when ``None``, "y" is sampled from the likelihood.

    Returns:
        The value of the "y" sample site.
    """
    weight = pyro.sample("w", Normal(0.0, 1.0))
    bias = pyro.sample("b", Normal(0.0, 1.0))
    noise_std = pyro.sample("s", Uniform(0.0, 1.0))

    n_obs = len(x)
    # The plate declares the observations conditionally independent given
    # the latent variables.
    with pyro.plate('dataset', size=n_obs):
        likelihood = Normal(x*weight + bias, noise_std)
        y_site = pyro.sample("y", likelihood, obs=y)
    return y_site

We can use 

- `pyro.infer.Predictive` 

to draw samples from the model

In [None]:
# No guide and no observations are passed, so Predictive simply runs the
# model forward: every draw comes from the priors.
predictive = pyro.infer.Predictive(model, num_samples=500)

# Evaluation grid, cast to float32 before it is handed to torch
hatx = np.linspace(-6, 6, num=100).astype('float32')
apriori_samples = predictive(torch.from_numpy(hatx))

# Plot samples from the priors
figure = corner.corner(np.stack([apriori_samples[var].detach().numpy()[:, 0] for var in ['b', 'w', 's']]).T,
                       smooth=1., labels=["bias", "weight", "noise_std"], bins=20,
                       quantiles=[0.16, 0.5, 0.84],
                       show_titles=True, title_kwargs={"fontsize": 12})

# Plot the *prior* predictive of y given x (nothing has been conditioned on
# data yet): individual draws plus their median and 5%-95% quantiles.
y_trace = apriori_samples["y"].detach().numpy()
med = np.median(y_trace, axis=0)
qua = np.quantile(y_trace, (0.05, 0.95), axis=0)

fig, ax = plt.subplots(figsize=(7, 3), tight_layout=True)
ax.plot(hatx, y_trace.T, lw=0.5, alpha=0.3)
ax.plot(hatx, med, c='k', label='median')
ax.plot(hatx, qua[0], c='k', ls='--', label='5%-95% quantiles')
ax.plot(hatx, qua[1], c='k', ls='--')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('Prior predictive')
ax.legend();

## Inference

In the Bayesian setting we want the posterior distribution 

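$$
p(\theta | \mathcal{D}) = \frac{p(\mathcal{D}|\theta) p(\theta)}{\int p(\mathcal{D}|\theta) p(\theta) \,d\theta}
$$

where $\mathcal{D}$ is our dataset and $\theta$ collects the latent variables: $w$ and $b$, plus the noise standard deviation when it is also given a prior (as in the Pyro model above)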

For complex models the posterior is intractable. So we either do

- MCMC: Construct a Markov chain whose samples are (asymptotically) distributed according to the actual posterior: Sampling based
- Variational Inference: Choose a simpler distribution that is as close as possible to the actual posterior: Optimization based



### Variational Inference

Propose an approximate (simple) posterior $q_\phi(\theta)$, e.g. factorized Gaussian

Optimize $\phi$ so that $q_\phi$ approximates $p(\theta|\mathcal{D})$

This is typically done by maximizing a lower bound on the evidence

$$
\mathcal{ELBO}(\phi) = \mathbb{E}_{\theta \sim q_\phi}[ \log p(\mathcal{D}|\theta)] - \text{KL}[q_\phi(\theta) \,\|\, p(\theta)]
$$

- First term: maximize the expected log-likelihood of the data under $q_\phi$
- Second term: keep the approximate posterior close to the prior, as measured by the Kullback-Leibler (KL) divergence

Once $q$ has been trained we use it as a replacement for $p(\theta|\mathcal{D})$ to calculate the **posterior predictive distribution**

$$
p(\mathbf{y}|\mathbf{x}, \mathcal{D}) = \int p(\mathbf{y}|\mathbf{x}, \theta) p(\theta| \mathcal{D}) \,d\theta
$$



## VI with Pyro

We use `pyro.infer.SVI` to perform **Stochastic Variational Inference**, which expects

- A generative model
- An approximate posterior (guide)
- Cost function: Typically ELBO
- Optimizer: How to optimize the ELBO, typically gradient descent based

We can use `pyro.infer.autoguide` to create approximate posteriors from predefined recipes, for example a factorized Gaussian posterior (`AutoDiagonalNormal`)

In [None]:
# Runtime validation is useful for debugging; clearing the parameter store
# makes sure the guide built below starts from a clean slate.
pyro.enable_validation(True)
pyro.clear_param_store()

# Approximate posterior (guide), generated automatically from the model
from pyro.infer.autoguide import AutoDiagonalNormal as approx_posterior
guide = approx_posterior(model)

# Ingredients of Stochastic Variational Inference: the objective and the
# optimizer used on it.
elbo_loss = pyro.infer.Trace_ELBO()
optimizer = pyro.optim.ClippedAdam({"lr": 1e-2})

svi = pyro.infer.SVI(model=model,
                     guide=guide,
                     optim=optimizer,
                     loss=elbo_loss)

Let's consider a dataset with two observations $\mathcal{D} = \{ (-2, -2), (2, 2) \}$

`svi.step(x, y)` performs a gradient ascent step to maximize the ELBO

In [None]:
fig, ax = plt.subplots(1, 3, figsize=(10, 2.5), dpi=80, tight_layout=True)

nepochs = 1000
plot_every = 10  # Number of epochs between two redraws of the live figure
loss = np.zeros(shape=(nepochs, ))
# Per-epoch history of the guide parameters:
# axis 1 -> (first, second) parameter tensor of the guide,
# axis 2 -> one entry per latent variable of the model (w, b, s)
params = np.zeros(shape=(nepochs, 2, 3))

# Observed data
x = torch.tensor([-2., 2.])
y = torch.tensor([-2., 2.])

for k in tqdm(range(nepochs)):
    # One optimization step; the returned value is stored as this epoch's loss
    loss[k] = svi.step(x, y)

    # NOTE(review): this relies on the order in which guide.parameters()
    # yields its tensors. For autoguides the second tensor may be the
    # unconstrained parameterization of the scale rather than the scale
    # itself -- confirm what plot_params expects.
    phi = [param.detach().numpy() for param in guide.parameters()]
    params[k, 0, :] = phi[0] # Locations
    params[k, 1, :] = phi[1] # Scales
    # Redraw periodically, and once more after the last step so the final
    # figure also contains the last epochs.
    if k % plot_every == 0 or k == nepochs - 1:
        plot_params(ax, k+1, loss, params)
        fig.canvas.draw()

Again we can use `pyro.infer.Predictive` to draw samples from the model

This time we use the guide to sample $w$, $b$ and the noise standard deviation $s$

In [None]:
# Latent variables are now drawn from the trained guide; the model then
# generates y on the evaluation grid for each draw.
predictive = pyro.infer.Predictive(model, guide=guide, num_samples=1000)
posterior_samples = predictive(torch.from_numpy(hatx))

# Corner plot of the approximate posterior of b, w and s
latent_names = ['b', 'w', 's']
latent_draws = np.stack(
    [posterior_samples[name].detach().numpy()[:, 0] for name in latent_names],
    axis=1,
)
figure = corner.corner(latent_draws,
                       smooth=1.,
                       labels=["bias", "weight", "noise_std"],
                       bins=20,
                       quantiles=[0.16, 0.5, 0.84],
                       show_titles=True,
                       title_kwargs={"fontsize": 12})

# Summaries of the posterior predictive of y given x:
# median curve and 5%-95% quantile band across draws
y_trace = posterior_samples["y"].detach().numpy()
med = np.median(y_trace, axis=0)
qua = np.quantile(y_trace, (0.05, 0.95), axis=0)
band_low, band_high = qua[0], qua[1]

fig, ax = plt.subplots(figsize=(7, 3), tight_layout=True)
ax.plot(hatx, med)
ax.fill_between(hatx, band_low, band_high, alpha=0.5)

# Training points, with error bars of twice the median inferred noise std
noise_std_median = posterior_samples['s'].median().item()
ax.errorbar(x.numpy(), y.numpy(), yerr=2*noise_std_median,
            fmt='none', c='k', zorder=100);
