In [1]:
# Original tutorial: https://pyro.ai/examples/bayesian_regression_ii.html

%reset -sf
import logging
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
import pyro.optim as optim
from pyro.infer import SVI, Trace_ELBO, Predictive, MCMC, NUTS
from pyro.infer.autoguide import AutoMultivariateNormal, init_to_mean



# Fix the random seed so repeated runs of the notebook draw the same samples.
pyro.set_rng_seed(1)
# Fail fast when the installed Pyro is not a 1.0.0 release.
assert pyro.__version__.startswith('1.0.0')

In [2]:
%matplotlib inline
# Use matplotlib's 'default' style sheet for every figure.
plt.style.use('default')

# Log only the bare message text, at INFO level and above.
logging.basicConfig(format='%(message)s', level=logging.INFO)
# Enable validation checks
pyro.enable_validation(True)
# True when a 'CI' environment variable is set; the training cell uses it to
# cut the number of optimisation steps.
smoke_test = ('CI' in os.environ)
# Re-seed at the top of this cell.
pyro.set_rng_seed(1)
# Source of the data set; the CSV is decoded as ISO-8859-1.
DATA_URL = "https://d2hg8soec8ck9v.cloudfront.net/datasets/rugged_data.csv"
rugged_data = pd.read_csv(DATA_URL, encoding="ISO-8859-1")

In [3]:
# Hand-written model and guide: every latent site and variational parameter is declared explicitly.

def model(is_cont_africa, ruggedness, log_gdp):
    """Linear regression of ``log_gdp`` with an interaction term.

    Latent sample sites and their priors:

    * ``a``     -- intercept, Normal(0, 10)
    * ``bA``    -- coefficient on ``is_cont_africa``, Normal(0, 1)
    * ``bR``    -- coefficient on ``ruggedness``, Normal(0, 1)
    * ``bAR``   -- coefficient on the product of both, Normal(0, 1)
    * ``sigma`` -- observation noise scale, Uniform(0, 10)

    The ``obs`` site is Normal(mean, sigma), conditioned on ``log_gdp`` inside
    a plate named ``data`` whose size is ``len(ruggedness)``.

    Args:
        is_cont_africa: predictor multiplied by ``bA`` and used in the
            interaction term (presumably a 0/1 indicator -- confirm with caller).
        ruggedness: predictor multiplied by ``bR`` and used in the interaction
            term; its length determines the plate size.
        log_gdp: observed values for the ``obs`` site.
    """
    intercept = pyro.sample("a", dist.Normal(0., 10.))
    coef_africa = pyro.sample("bA", dist.Normal(0., 1.))
    coef_rugged = pyro.sample("bR", dist.Normal(0., 1.))
    coef_interaction = pyro.sample("bAR", dist.Normal(0., 1.))
    noise_scale = pyro.sample("sigma", dist.Uniform(0., 10.))

    # Terms are summed left to right, one per line.
    mean = (
        intercept
        + coef_africa * is_cont_africa
        + coef_rugged * ruggedness
        + coef_interaction * is_cont_africa * ruggedness
    )

    n_obs = len(ruggedness)
    with pyro.plate("data", n_obs):
        pyro.sample("obs", dist.Normal(mean, noise_scale), obs=log_gdp)

def guide(is_cont_africa, ruggedness, log_gdp):
    """Hand-written variational guide: one independent Normal per latent site.

    Registers the learnable parameters ``a_loc``/``a_scale``,
    ``sigma_loc``/``sigma_scale`` and the length-3 vectors
    ``weights_loc``/``weights_scale`` (one entry each for ``bA``, ``bR`` and
    ``bAR``), then samples the sites ``a``, ``bA``, ``bR``, ``bAR`` and
    ``sigma`` from Normal distributions built from those parameters. All scale
    parameters and ``sigma_loc`` are constrained to be positive.

    Args:
        is_cont_africa: unused here; accepted so the guide can be called with
            the same arguments as ``model``.
        ruggedness: unused here; see above.
        log_gdp: unused here; see above.
    """
    a_loc = pyro.param('a_loc', torch.tensor(0.))
    a_scale = pyro.param('a_scale', torch.tensor(1.),
                         constraint=constraints.positive)
    sigma_loc = pyro.param('sigma_loc', torch.tensor(1.),
                           constraint=constraints.positive)
    sigma_scale = pyro.param('sigma_scale', torch.tensor(0.05),
                             constraint=constraints.positive)
    # Random initial locations for the three regression coefficients.
    weights_loc = pyro.param('weights_loc', torch.randn(3))
    weights_scale = pyro.param('weights_scale', torch.ones(3),
                               constraint=constraints.positive)

    # Only the sample statements matter in a guide: the drawn values are
    # recorded under the site names, so they are not bound to locals and no
    # regression mean is computed here.
    pyro.sample("a", dist.Normal(a_loc, a_scale))
    pyro.sample("bA", dist.Normal(weights_loc[0], weights_scale[0]))
    pyro.sample("bR", dist.Normal(weights_loc[1], weights_scale[1]))
    pyro.sample("bAR", dist.Normal(weights_loc[2], weights_scale[2]))
    # NOTE(review): this Normal is unbounded, whereas the model's prior on
    # "sigma" is Uniform(0, 10). A draw outside that interval has zero prior
    # density; with the small initial scale (0.05) around a positive location
    # this looks unlikely in practice, but confirm it is acceptable.
    pyro.sample("sigma", dist.Normal(sigma_loc, sigma_scale))

In [4]:
# Utility function to print latent sites' quantile information.
def summary(samples):
    """Build a table of summary statistics for every sampled site.

    Args:
        samples: mapping from site name to an array-like of draws that
            ``pandas.DataFrame`` can wrap.

    Returns:
        dict mapping each site name to a DataFrame with one row per column of
        the wrapped draws and the columns
        ``mean, std, 5%, 25%, 50%, 75%, 95%``.
    """
    quantile_levels = [.05, 0.25, 0.5, 0.75, 0.95]
    kept_columns = ["mean", "std", "5%", "25%", "50%", "75%", "95%"]
    return {
        site_name: pd.DataFrame(draws)
        .describe(percentiles=quantile_levels)
        .T[kept_columns]
        for site_name, draws in samples.items()
    }

# Prepare training data.
# The frame is built in a single chain (select columns -> keep rows with a
# finite GDP -> log-transform GDP) so that no column is ever assigned on a
# filtered slice, which is what triggers pandas' SettingWithCopyWarning.
df = (
    rugged_data[["cont_africa", "rugged", "rgdppc_2000"]]
    .loc[lambda frame: np.isfinite(frame["rgdppc_2000"])]
    .assign(rgdppc_2000=lambda frame: np.log(frame["rgdppc_2000"]))
)
# Column order is preserved: cont_africa, rugged, log(rgdppc_2000).
train = torch.tensor(df.values, dtype=torch.float)

In [5]:
# Fit the hand-written guide to the model with stochastic variational inference.
svi = SVI(model, guide, optim.Adam({"lr": .05}), loss=Trace_ELBO())

# Split the training tensor into its three columns.
is_cont_africa, ruggedness, log_gdp = train[:, 0], train[:, 1], train[:, 2]

# Start optimisation from an empty parameter store.
pyro.clear_param_store()

# A smoke test only needs to show that the loop runs.
num_iters = 2 if smoke_test else 5000
for i in range(num_iters):
    elbo = svi.step(is_cont_africa, ruggedness, log_gdp)
    # Report the loss every 500 steps.
    if not i % 500:
        logging.info("Elbo loss: {}".format(elbo))

Elbo loss: 5795.467590510845
Elbo loss: 415.8889375925064
Elbo loss: 250.01099556684494
Elbo loss: 247.14538943767548
Elbo loss: 249.2295293211937
Elbo loss: 250.9435819387436
Elbo loss: 249.57079249620438
Elbo loss: 248.7939082980156
Elbo loss: 248.6645919084549
Elbo loss: 250.3539196252823


In [6]:
# Draw samples of the latent sites from the fitted guide.
num_samples = 1000
predictive = Predictive(model, guide=guide, num_samples=num_samples)
# The arguments are forwarded to `model` and `guide`, so they must follow
# their signature: (is_cont_africa, ruggedness, log_gdp).
# The "obs" site is dropped; each remaining site is flattened to a 1-D
# NumPy array of length num_samples.
svi_samples = {k: v.reshape(num_samples).detach().cpu().numpy()
               for k, v in predictive(is_cont_africa, ruggedness, log_gdp).items()
               if k != "obs"}

In [7]:
# Print the summary table of every latent site estimated with SVI.
svi_site_stats = summary(svi_samples)
for site in svi_site_stats:
    values = svi_site_stats[site]
    print("Site: {}".format(site))
    print(values, "\n")

Site: a
       mean       std        5%       25%       50%       75%       95%
0  9.177097  0.059518  9.078332  9.140592  9.178284  9.217113  9.271454 

Site: bA
       mean       std        5%       25%       50%       75%       95%
0 -1.890339  0.122807 -2.088208 -1.978823 -1.887191 -1.803397 -1.700566 

Site: bR
       mean       std        5%       25%       50%       75%       95%
0 -0.157294  0.039547 -0.222703 -0.183125 -0.157319 -0.132543 -0.091143 

Site: bAR
       mean       std        5%       25%       50%       75%       95%
0  0.304762  0.067878  0.194515  0.259582  0.305156  0.349307  0.415692 

Site: sigma
      mean       std        5%       25%       50%       75%       95%
0  0.90384  0.048398  0.824407  0.870968  0.902914  0.936399  0.983218 



In [8]:
# Alternative to variational inference: sample the posterior with MCMC,
# using the NUTS kernel on the same model (no guide is involved).
nuts_kernel = NUTS(model)

# 200 warm-up steps followed by 1000 retained samples.
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200)
mcmc.run(is_cont_africa, ruggedness, log_gdp)

# Convert every site's samples to a NumPy array on the CPU.
hmc_samples = {
    site_name: draws.detach().cpu().numpy()
    for site_name, draws in mcmc.get_samples().items()
}

Sample: 100%|██████████| 1200/1200 [00:29, 40.18it/s, step size=3.57e-01, acc. prob=0.888]


In [9]:
# Print the summary table of every latent site sampled with MCMC.
hmc_site_stats = summary(hmc_samples)
for site in hmc_site_stats:
    values = hmc_site_stats[site]
    print("Site: {}".format(site))
    print(values, "\n")

Site: a
       mean       std        5%       25%       50%       75%       95%
0  9.186106  0.144717  8.951578  9.091258  9.185141  9.274197  9.420919 

Site: bA
       mean       std        5%       25%       50%       75%       95%
0 -1.845644  0.218941 -2.218035 -1.982632 -1.843927 -1.709069 -1.491678 

Site: bR
       mean       std        5%       25%       50%       75%       95%
0 -0.186122  0.078746 -0.310428 -0.243365 -0.185347 -0.131482 -0.058996 

Site: bAR
       mean       std        5%       25%       50%      75%       95%
0  0.351446  0.122918  0.152581  0.272872  0.349664  0.43567  0.553147 

Site: sigma
      mean       std        5%       25%       50%       75%     95%
0  0.95116  0.051794  0.868952  0.916505  0.947113  0.983869  1.0391 



In [10]:
"""
As a comparison to the results obtained above with the diagonal Normal guide,
we will now use a guide that samples all latent variables jointly from a
multivariate normal distribution, parameterized by a Cholesky factor of its
covariance matrix. This allows us to capture the correlations between the
latent variables. If we wrote this guide manually, we would need to combine all
the latent variables so that we could sample them jointly from a Multivariate Normal.

"""
# Build the guide automatically from the model.
# NOTE: this rebinds the name `guide`; cells executed after this one see the
# automatic guide instead of the hand-written `guide` function.
guide = AutoMultivariateNormal(model, init_loc_fn=init_to_mean)

# Same optimisation set-up as for the hand-written guide, with a learning
# rate of 0.01.
svi = SVI(model, guide, optim.Adam({"lr": .01}), loss=Trace_ELBO())

# Split the training tensor into its three columns.
is_cont_africa, ruggedness, log_gdp = train[:, 0], train[:, 1], train[:, 2]

# Discard the parameters learned for the previous guide.
pyro.clear_param_store()

# `num_iters` is reused from the earlier training cell.
for i in range(num_iters):
    elbo = svi.step(is_cont_africa, ruggedness, log_gdp)
    # Report the loss every 500 steps.
    if not i % 500:
        logging.info("Elbo loss: {}".format(elbo))

Elbo loss: 707.9290481805801
Elbo loss: 426.1432435512543
Elbo loss: 259.178986787796
Elbo loss: 249.957903444767
Elbo loss: 247.92973798513412
Elbo loss: 247.04828417301178
Elbo loss: 247.64889878034592
Elbo loss: 247.19207900762558
Elbo loss: 250.07714819908142
Elbo loss: 249.06675231456757
