In [1]:
import sys
sys.path.append("..")

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import pyro
from pyro.infer.mcmc import MCMC
import pyro.distributions as dist

from kernel.sghmc import SGHMC

# Global seed for everything stochastic below: the torch.randperm train/val
# split and the SGHMC sampler. Set once, before any random draw, so that
# Restart & Run All reproduces the same results.
SEED = 101
pyro.set_rng_seed(SEED)

In [2]:
train_dataset = datasets.MNIST('./data', train=True, download=True)

test_dataset = datasets.MNIST('./data', train=False, download=True)

# Random 5/6 : 1/6 train/validation split of the official training set.
# The torch RNG used by randperm was seeded in the imports cell.
n_train = len(train_dataset) * 5 // 6
perm = torch.randperm(len(train_dataset))
train_idx = perm[:n_train]
val_idx = perm[n_train:]

# Pixel mean/std used for standardisation. These are the usual MNIST constants,
# presumably measured on the training images on a [0, 1] scale.
mean = 0.1307
std = 0.3081


def _normalize(images):
    """Map raw 0-255 pixel values to [0, 1], then standardise with `mean`/`std`."""
    return (images / 255.0 - mean) / std


X_train = _normalize(train_dataset.data[train_idx])
Y_train = train_dataset.targets[train_idx]

X_val = _normalize(train_dataset.data[val_idx])
Y_val = train_dataset.targets[val_idx]

X_test = _normalize(test_dataset.data)
Y_test = test_dataset.targets

# Cheap invariants: the split covers the training set exactly and every
# image keeps its label.
assert len(X_train) + len(X_val) == len(train_dataset)
assert len(X_train) == len(Y_train) and len(X_val) == len(Y_val)
assert len(X_test) == len(Y_test)

In [3]:
# PyroModule-wrapped torch.nn.Linear. Its `weight`/`bias` attributes can be
# reassigned to PyroSample priors (as BNN.__init__ does), which makes the layer Bayesian.
PyroLinear = pyro.nn.PyroModule[torch.nn.Linear]
    
class BNN(pyro.nn.PyroModule):
    """Two-layer Bayesian MLP classifier with standard-normal priors on all weights and biases.

    The `weight` and `bias` of both `fc1` and `fc2` are replaced by `PyroSample`
    sites drawn from Normal(0, 1), so they are random variables in the Pyro model
    rather than fixed parameters.

    Args:
        input_size: Number of input features; inputs are flattened to this width.
        hidden_size: Width of the single ReLU hidden layer.
        output_size: Number of classes.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Remembered so forward() flattens to the configured width instead of
        # a hard-coded 28*28.
        self.input_size = input_size

        # TODO add gamma priors to precision terms
        self.fc1 = PyroLinear(input_size, hidden_size)
        self.fc1.weight = pyro.nn.PyroSample(dist.Normal(0., 1.).expand([hidden_size, input_size]).to_event(2))
        self.fc1.bias   = pyro.nn.PyroSample(dist.Normal(0., 1.).expand([hidden_size]).to_event(1))

        self.fc2 = PyroLinear(hidden_size, output_size)
        self.fc2.weight = pyro.nn.PyroSample(dist.Normal(0., 1.).expand([output_size, hidden_size]).to_event(2))
        self.fc2.bias   = pyro.nn.PyroSample(dist.Normal(0., 1.).expand([output_size]).to_event(1))

        self.relu = torch.nn.ReLU()
        self.log_softmax = torch.nn.LogSoftmax(dim=1)

    def forward(self, x, y=None):
        """Compute class log-probabilities for `x` and score them against `y`.

        Args:
            x: Input batch; reshaped to `(-1, input_size)`. `reshape` also
                accepts non-contiguous tensors, unlike `view`.
            y: Optional class-index labels, one per row of the flattened batch.
                When None, the "obs" site is sampled instead of observed.

        Returns:
            Tensor of shape `(batch, output_size)` holding log-softmax class scores.
        """
        x = x.reshape(-1, self.input_size)
        hidden = self.relu(self.fc1(x))
        log_probs = self.log_softmax(self.fc2(hidden))

        # NOTE(review): the plate size is the number of rows passed in. If SGHMC
        # feeds minibatches here, the likelihood only reflects the full dataset
        # if kernel.sghmc rescales it (e.g. by N / batch_size). Confirm there.
        with pyro.plate("data", log_probs.shape[0]):
            # Categorical re-normalises `logits`, so passing log-probabilities
            # is equivalent to passing the raw fc2 scores.
            pyro.sample("obs", dist.Categorical(logits=log_probs), obs=y)

        return log_probs
            
# Network dimensions: flattened 28x28 images -> 100 hidden ReLU units -> 10 classes.
INPUT_SIZE = 28 * 28
HIDDEN_SIZE = 100
NUM_CLASSES = 10

bnn = BNN(INPUT_SIZE, HIDDEN_SIZE, NUM_CLASSES)

In [4]:
# SGHMC sampler settings, all forwarded to the SGHMC kernel constructor below.
NUM_STEPS = 10          # forwarded as `num_steps` (presumably integration steps per sample; confirm in kernel.sghmc)
MOMENTUM_DECAY = 1e-2   # forwarded as `momentum_decay`
LR = 2.0e-6             # forwarded as `learning_rate`
BATCH_SIZE = 500        # forwarded as `batch_size`, the minibatch size for the subsampled arguments

In [5]:
# Draw posterior samples of the BNN weights with SGHMC on the training set.
# Runtime is roughly 1.5 minutes (see the progress bar below).
NUM_SAMPLES = 800   # posterior draws kept after warmup
WARMUP_STEPS = 50   # warmup iterations, which MCMC does not return

# subsample_positions=[0, 1] presumably marks the positional arguments of
# `run` (X_train, Y_train) that the kernel minibatches. Confirm in kernel.sghmc.
sghmc = SGHMC(bnn,
              subsample_positions=[0, 1],
              batch_size=BATCH_SIZE,
              learning_rate=LR,
              momentum_decay=MOMENTUM_DECAY,
              num_steps=NUM_STEPS)

sghmc_mcmc = MCMC(sghmc, num_samples=NUM_SAMPLES, warmup_steps=WARMUP_STEPS)
sghmc_mcmc.run(X_train, Y_train)
# Dict mapping each sample-site name to a tensor of stacked draws.
sghmc_samples = sghmc_mcmc.get_samples()

Sample: 100%|█████████████████████████████████████████████████████████████████| 850/850 [01:24, 10.11it/s, lr=2.00e-06]
