In [1]:
# Import modules
import os
from functools import partial
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro import optim
from pyro.infer import SVI, Trace_ELBO, Predictive
import pyro.poutine as poutine

# for CI testing: smoke_test is True when a "CI" environment variable is set,
# so CI jobs can run shortened versions of the expensive cells.
smoke_test = ('CI' in os.environ)
# This notebook targets Pyro 1.0.0; fail fast on any other version.
assert pyro.__version__.startswith('1.0.0')
# Turn on Pyro's runtime validation checks.
pyro.enable_validation(True)
# Seed the random number generators used by Pyro so runs are reproducible.
pyro.set_rng_seed(1)


# Set matplotlib settings (inline rendering in the notebook, default style)
%matplotlib inline
plt.style.use('default')

In [2]:
DATA_URL = "https://d2hg8soec8ck9v.cloudfront.net/datasets/rugged_data.csv"
raw_df = pd.read_csv(DATA_URL, encoding="ISO-8859-1")

# Keep the three columns of interest and drop rows whose GDP is not finite.
# .copy() produces an independent frame, so the column assignments below do
# not write into a view of raw_df (which would raise SettingWithCopyWarning).
df = raw_df.loc[np.isfinite(raw_df["rgdppc_2000"]),
                ["cont_africa", "rugged", "rgdppc_2000"]].copy()
# Regress on log GDP per capita rather than raw GDP.
df["rgdppc_2000"] = np.log(df["rgdppc_2000"])

# Standard linear regression - nothing fancy here
# Add a feature capturing the interaction between "cont_africa" and "rugged".
df["cont_africa_x_rugged"] = df["cont_africa"] * df["rugged"]

# One tensor holding the three features followed by the target column;
# split it into the design matrix (x_data) and the target vector (y_data).
data = torch.tensor(df[["cont_africa", "rugged", "cont_africa_x_rugged", "rgdppc_2000"]].values,
                    dtype=torch.float)
x_data, y_data = data[:, :-1], data[:, -1]

In [3]:
# Bayesian linear regression
class BayesianRegression(PyroModule):
    """Bayesian linear regression with a Normal likelihood.

    Priors:
        weight ~ Normal(0, 1), shape (out_features, in_features)
        bias   ~ Normal(0, 10), shape (out_features,)
        sigma  ~ Uniform(0, 10)  (observation noise, sampled in ``forward``)

    Args:
        in_features: number of input features per row of ``x``.
        out_features: number of outputs of the linear layer. ``forward``
            squeezes the last dimension, so this is presumably always 1.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        weight_prior = dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2)
        bias_prior = dist.Normal(0., 10.).expand([out_features]).to_event(1)
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        # Assigning PyroSample turns the layer's weight and bias into latent
        # random variables drawn from the priors above.
        self.linear.weight = PyroSample(weight_prior)
        self.linear.bias = PyroSample(bias_prior)

    def forward(self, x, y=None):
        """Run the generative model.

        Args:
            x: input features; presumably shape (N, in_features).
            y: observed targets, or None to leave the "obs" site unobserved.

        Returns:
            The regression mean ``linear(x)`` with the last dimension squeezed.
        """
        noise_scale = pyro.sample("sigma", dist.Uniform(0., 10.))
        predicted_mean = self.linear(x).squeeze(-1)
        # One conditionally independent observation per row of x.
        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs", dist.Normal(predicted_mean, noise_scale), obs=y)
        return predicted_mean
# Guide: AutoDiagonalNormal approximates the posterior over the model's latent
# (unobserved) variables with a Normal distribution whose covariance matrix is
# diagonal, i.e. it assumes the latent variables are uncorrelated. Under the
# hood it creates learnable parameters for the model's sample statements.
model = BayesianRegression(in_features=3, out_features=1)
guide = AutoDiagonalNormal(model)

In [4]:
"""
We first review the basic usage pattern of SVI objects in Pyro.
We assume that the user has defined a model and a guide.
The user then creates an optimizer and an SVI object:

"""
pyro.clear_param_store()
optimizer = pyro.optim.Adam({"lr": 0.001, "betas": (0.90, 0.999)},{"clip_norm": 10.0})
svi = SVI(model, guide, optimizer, loss=pyro.infer.Trace_ELBO())

for j in range(1000):
    # calculate the loss and take a gradient step
    loss = svi.step(x_data, y_data)
    # Loss scaling 
    # loss = svi.step(x_data, y_data) / N_data
    if j % 100 == 0:
        print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(data)))

[iteration 0001] loss: 4.6074
[iteration 0101] loss: 4.1388
[iteration 0201] loss: 4.1059
[iteration 0301] loss: 3.8348
[iteration 0401] loss: 3.7966
[iteration 0501] loss: 3.6876
[iteration 0601] loss: 3.5647
[iteration 0701] loss: 3.5466
[iteration 0801] loss: 3.4873
[iteration 0901] loss: 3.4300


In [5]:
def my_custom_L2_regularizer(my_parameters):
    """Return the L2 penalty: the sum of squared entries over all tensors.

    Args:
        my_parameters: iterable of tensors (e.g. the guide's parameters).

    Returns:
        A 0-dim tensor with the total sum of squares, or the float ``0.0``
        when ``my_parameters`` is empty.
    """
    squared_norms = (param.pow(2.0).sum() for param in my_parameters)
    return sum(squared_norms, 0.0)

In [6]:
"""
If we want more control, we can directly manipulate the differentiable loss
method of the various ELBO classes. 
For example, (assuming we know all the parameters in advance) 
this is equivalent to the previous code snippet:
"""
pyro.clear_param_store()
model = BayesianRegression(3, 1)
guide = AutoDiagonalNormal(model)

loss_fn = pyro.infer.Trace_ELBO().differentiable_loss

# We actually need to trace the parameters of the guide not of the model
trace = poutine.trace(guide).get_trace(x_data)
params = [trace.nodes[name]["value"].unconstrained() for name in trace.param_nodes]

# define optimizer and loss function
optimizer = torch.optim.Adam(params, lr=0.001)
loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
# compute loss
for j in range(1000):
    # calculate the loss and take a gradient step and add a custom loss
    loss = loss_fn(model, guide,x_data,y_data) + my_custom_L2_regularizer(params)
    loss.backward()
    
    # Clip the gradients after the backward step
    nn.utils.clip_grad_value_(params, 5.0)
    # take a step and zero the parameter gradients
    optimizer.step()
    optimizer.zero_grad()
    if j % 100 == 0:
        print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(data)))

{'obs'}
  guide_vars - aux_vars - model_vars))


[iteration 0001] loss: 13.0063
[iteration 0101] loss: 12.0189
[iteration 0201] loss: 11.3696
[iteration 0301] loss: 10.3619
[iteration 0401] loss: 9.5758
[iteration 0501] loss: 8.7246
[iteration 0601] loss: 8.1958
[iteration 0701] loss: 7.6023
[iteration 0801] loss: 7.1693
[iteration 0901] loss: 6.6601


In [8]:
"""
For some models the loss gradient can explode during training,
leading to overflow and NaN values.
One way to protect against this is with gradient clipping.
The optimizers in pyro.optim take an optional dictionary of clip_args
which allows clipping either the gradient
norm or the gradient value to fall within the given limit.
To change the basic example above:

- optimizer = pyro.optim.Adam({"lr": 0.001, "betas": (0.90, 0.999)})
+ optimizer = pyro.optim.Adam({"lr": 0.001, "betas": (0.90, 0.999)}, {"clip_norm": 10.0})
"""

'\nFor some models the loss gradient can explode during training,\nleading to overflow and NaN values.\nOne way to protect against this is with gradient clipping.\nThe optimizers in pyro.optim take an optional dictionary of clip_args\nwhich allows clipping either the gradient\nnorm or the gradient value to fall within the given limit.\nTo change the basic example above:\n\n- optimizer = pyro.optim.Adam({"lr": 0.001, "betas": (0.90, 0.999)})\n+ optimizer = pyro.optim.Adam({"lr": 0.001, "betas": (0.90, 0.999)}, {"clip_norm": 10.0})\n'