In [2]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import os
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from torch.distributions import constraints
from pyro.nn import PyroModule, PyroParam, PyroSample
from pyro.nn.module import to_pyro_module_
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam

In [3]:
class Linear(nn.Module):
    def __init__(self, in_size, out_size):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_size, out_size))
        self.bias = nn.Parameter(torch.randn(out_size))

    def forward(self, input_):
        return self.bias + input_ @ self.weight


In [4]:
linear = Linear(2,2)
to_pyro_module_(linear)  # to pyro module: this operates in-place

print(linear.bias)

Parameter containing:
tensor([ 0.8123, -1.1436], requires_grad=True)


In [5]:
example_input = torch.randn(3, 2)
example_output = linear(example_input)
print(example_output)

tensor([[ 1.6131, -2.0254],
        [ 2.5195, -2.0059],
        [ 1.0724, -0.9205]], grad_fn=<AddBackward0>)


In [79]:
class BayesianLinear(PyroModule):
    def __init__(self, in_size, out_size):
        super().__init__()
        # self.weight = PyroSample(dist.Normal(0, 1).expand([in_size, out_size]).to_event(2))
        self.weight = torch.randn(in_size, out_size)  # replace with some fixed constant will just not update these values
        # self.weight = PyroModule[nn.Linear](in_size, out_size) # this will not work
        self.bias = PyroSample(dist.Normal(0, 1).expand([out_size]).to_event(1))

    def forward(self, input_):
        return self.bias + input_ @ self.weight


In [80]:
blinear = Linear(2,2)
x = torch.randn(3, 2)
y = linear(x)
print(y)

tensor([[ 2.0048, -1.9638],
        [-0.1089, -0.6803],
        [-2.0456, -0.1828]], grad_fn=<AddBackward0>)


In [95]:
class Model(PyroModule):
    def __init__(self, in_size, out_size):
        super().__init__()
        self.linear = BayesianLinear(in_size, out_size)  # this is a PyroModule
        self.obs_scale = PyroSample(dist.LogNormal(0, 1))
        # self.obs_scale = pyro.sample("sigma", dist.Uniform(0., 1.).expand([1]).to_event(1))

    def forward(self, input, output=None):
        obs_loc = self.linear(input)  # this samples linear.bias and linear.weight
        obs_scale = self.obs_scale    # this samples self.obs_scale
        with pyro.plate("instances", len(input)):
            return pyro.sample("obs", dist.Normal(obs_loc, obs_scale).to_event(1),
                               obs=output)


In [97]:
pyro.clear_param_store()
pyro.set_rng_seed(1)

in_size, out_size = 2,3
model = Model(in_size,out_size)
x = torch.randn(10, in_size)
y = torch.randn(10, out_size)

guide = AutoNormal(model)  # unlearned posterior dist. AutoDiagonalNormal
svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())  # parameters to optimize are determined by guide()
print(x,y,y.numel())
for step in range(1000):
    loss = svi.step(x, y) / y.numel()  # data in step() are passed to both model() and guide()
    if step % 100 == 0:
        # print(model.linear.weight)
        print("step {} loss = {:0.4g}".format(step, loss))

tensor([[-0.5912,  0.2738],
        [-0.9649, -0.2358],
        [ 1.8793, -0.0721],
        [ 0.1578, -0.7735],
        [ 0.1991,  0.0457],
        [ 0.1530, -0.4757],
        [-0.1110,  0.2927],
        [-0.1578, -0.0288],
        [ 2.3571, -1.0373],
        [ 1.5748, -0.6298]]) tensor([[-0.9274,  0.5451,  0.0663],
        [-0.4370,  0.7626,  0.4415],
        [ 1.1651,  2.0154,  0.1374],
        [ 0.9386, -0.1860, -0.6446],
        [ 1.5392, -0.8696,  0.2579],
        [ 1.0950, -0.5065,  0.0998],
        [-0.6540,  0.7317, -1.4344],
        [-0.5008,  0.0938, -1.2597],
        [ 0.2546, -0.5020, -1.0412],
        [ 0.7323, -1.0483, -0.4709]]) 30
step 0 loss = 1.617
step 100 loss = 1.522
step 200 loss = 1.489
step 300 loss = 1.5
step 400 loss = 1.468
step 500 loss = 1.493
step 600 loss = 1.491
step 700 loss = 1.501
step 800 loss = 1.474
step 900 loss = 1.488


In [50]:
# examine the optimized parameter values
guide.requires_grad_(False)

for name, value in pyro.get_param_store().items():
    print(name, pyro.param(name))

AutoNormal.locs.linear.bias Parameter containing:
tensor([-1.2041])
AutoNormal.scales.linear.bias tensor([0.2056])
AutoNormal.locs.obs_scale Parameter containing:
tensor(-0.8343)
AutoNormal.scales.obs_scale tensor(0.4281)


In [51]:
# evaluate the trained model: posterior dist.
from pyro.infer import Predictive


def summary(samples):
    site_stats = {}
    for k, v in samples.items():
        site_stats[k] = {
            "mean": torch.mean(v, 0),
            "std": torch.std(v, 0),
            "5%": v.kthvalue(int(len(v) * 0.05), dim=0)[0],
            "95%": v.kthvalue(int(len(v) * 0.95), dim=0)[0],
        }
    return site_stats

x_data = torch.randn(2, in_size)

predictive = Predictive(model, guide=guide, num_samples=20,
                        return_sites=("linear.weight", "obs", "_RETURN"))
samples = predictive(x_data)
pred_summary = summary(samples)
# print(samples)
print(pred_summary)

{'obs': {'mean': tensor([[-0.2483],
        [-1.2851]]), 'std': tensor([[0.6681],
        [0.5240]]), '5%': tensor([[-1.3640],
        [-2.5116]]), '95%': tensor([[ 0.8016],
        [-0.6935]])}, '_RETURN': {'mean': tensor([[-0.2483],
        [-1.2851]]), 'std': tensor([[0.6681],
        [0.5240]]), '5%': tensor([[-1.3640],
        [-2.5116]]), '95%': tensor([[ 0.8016],
        [-0.6935]])}}


In [30]:
x_data = torch.randn(1, in_size)

predictive = Predictive(model, guide=guide, num_samples=8000,
                        return_sites=("linear.weight", "obs", "_RETURN"))
samples = predictive(x_data)
pred_summary = summary(samples)
# print(samples)
print(pred_summary)

{'linear.weight': {'mean': tensor([[[-0.1136],
         [-0.4577]]]), 'std': tensor([[[0.6488],
         [0.6626]]]), '5%': tensor([[[-1.1910],
         [-1.5464]]]), '95%': tensor([[[0.9701],
         [0.6256]]])}, 'obs': {'mean': tensor([[-1.2794]]), 'std': tensor([[1.0144]]), '5%': tensor([[-2.9126]]), '95%': tensor([[0.3741]])}, '_RETURN': {'mean': tensor([[-1.2794]]), 'std': tensor([[1.0144]]), '5%': tensor([[-2.9126]]), '95%': tensor([[0.3741]])}}


In [72]:
class BayesianRegression(PyroModule):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        # self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
        # self.linear.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
        mean = self.linear(x).squeeze(-1)
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs1", dist.Normal(mean, sigma), obs=y)
        return mean

pyro.clear_param_store()
pyro.set_rng_seed(1)

in_size, out_size = 2,1
model = BayesianRegression(in_size,out_size)
x = torch.randn(2, in_size)
y = model(x)

guide = AutoNormal(model)  # unlearned posterior dist. AutoDiagonalNormal
svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())  # parameters to optimize are determined by guide()
print(x,y,y.numel())
for step in range(1000):
    loss = svi.step(x, y) / y.numel()  # data in step() are passed to both model() and guide()
    if step % 100 == 0:
        print(model.linear.weight)
        print("step {} loss = {:0.4g}".format(step, loss))

tensor([[ 0.7244, -0.7022],
        [ 1.1661,  0.2605]]) tensor([0.3460, 0.2065], grad_fn=<SqueezeBackward1>) 2
Parameter containing:
tensor([[ 0.3643, -0.3121]], requires_grad=True)
step 0 loss = 3.812


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.