# Testing the effect of incorporating observed information

We test the effect on SGHMC of using the observed information to estimate the noise model. We compare four configurations:
- Not using observed information.
- Only calculating the estimate at setup time.
- Recalculating every sample.
- Recalculating before every step while simulating the dynamics in a sample.

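For every epoch after warm-up we record two quantities. The first is the error of the accumulated posterior predictive on the held-out validation split (labelled "test" error in the code and plots). The second is the wall-clock time taken.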

In [None]:
%reload_ext autoreload
%autoreload 2

import sys
# Put the parent directory on the import path; presumably that is where the
# local `kernel` package (SGHMC, SGLD, ...) lives.
sys.path.append("..")

import os.path
import json

import time
import numpy as np
import matplotlib.pyplot as plt

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

import seaborn as sns # conda install seaborn
import pandas as pd # ^^ this will automatically install pandas

import pyro
from pyro.infer.mcmc import MCMC
import pyro.distributions as dist

# Only SGHMC is used in this notebook; the other kernels are imported but unused here.
from kernel.sghmc import SGHMC
from kernel.sgld import SGLD
from kernel.sgd import SGD
from kernel.sgnuts import NUTS as SGNUTS

# Fixed seed so the runs below are reproducible.
pyro.set_rng_seed(101)

# High-resolution figures for the plots at the end of the notebook.
plt.rcParams['figure.dpi'] = 300

In [None]:
# Use the GPU when one is available and fall back to the CPU otherwise.
# (To require a GPU instead, enable: assert torch.cuda.is_available())
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# One JSON file per configuration, each holding the pair [errors, times]
# written at the end of the corresponding training cell.
RESULTS_DIR = os.path.join("results", "obs-info")
RESULTS_NOINFO = os.path.join(RESULTS_DIR, "noinfo.json")
RESULTS_START = os.path.join(RESULTS_DIR, "start.json")
RESULTS_EVERY_SAMPLE = os.path.join(RESULTS_DIR, "every-sample.json")
RESULTS_EVERY_STEP = os.path.join(RESULTS_DIR, "every-step.json")

# open(..., "w") does not create missing parent directories. Without this,
# saving would fail only after a complete (long) training run.
os.makedirs(RESULTS_DIR, exist_ok=True)

In [None]:
# Simple dataset wrapper class

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs ``data[i]`` with ``targets[i]``.

    Both sequences are stored as given, with no copying or transformation.
    Indexing returns an ``(input, target)`` tuple, as ``DataLoader`` expects.
    """

    def __init__(self, data, targets):
        self.data, self.targets = data, targets

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.targets[idx]
        return sample, label

## Hyperparams

These hyperparameters were fixed during the hyperparameter search. All other hyperparameters in this notebook are the best ones we found during the hyperparameter search.

In [None]:
# Fixed (not tuned) during the hyperparameter search.
BATCH_SIZE: int = 500     # minibatch size for SGHMC and the DataLoaders
NUM_EPOCHS: int = 800     # number of post-warm-up epochs that are evaluated
WARMUP_EPOCHS: int = 50   # burn-in epochs run before evaluation starts
HIDDEN_SIZE: int = 100    # width of the BNN hidden layer

## Download MNIST and setup datasets / dataloaders

In [None]:
train_dataset = datasets.MNIST('./data', train=True, download=True)
test_dataset = datasets.MNIST('./data', train=False, download=True)

# Deterministic split (no shuffling): the first `nvalid` training images form
# the validation set and the remaining ones are used for training.
nvalid = 10000
perm = torch.arange(len(train_dataset))
val_idx, train_idx = perm[:nvalid], perm[nvalid:]

# NOTE(review): these normalization constants are not used anywhere in this
# notebook; inputs are only divided by 255 below.
mean, std = 0.1307, 0.3081

# Scale raw pixel values into [0, 1] by dividing by 255; labels are unchanged.
X_train, Y_train = train_dataset.data[train_idx] / 255.0, train_dataset.targets[train_idx]
X_val, Y_val = train_dataset.data[val_idx] / 255.0, train_dataset.targets[val_idx]
X_test, Y_test = test_dataset.data / 255.0, test_dataset.targets

# Replace the torchvision datasets with plain tensor-backed wrappers.
train_dataset, val_dataset, test_dataset = (
    Dataset(X_train, Y_train),
    Dataset(X_val, Y_val),
    Dataset(X_test, Y_test),
)

# Only the training loader shuffles; evaluation order stays fixed.
train_loader, val_loader, test_loader = (
    DataLoader(ds, batch_size=BATCH_SIZE, shuffle=shuffle)
    for ds, shuffle in ((train_dataset, True), (val_dataset, False), (test_dataset, False))
)

## Define the Bayesian neural network model

In [None]:
PyroLinear = pyro.nn.PyroModule[torch.nn.Linear]


class BNN(pyro.nn.PyroModule):
    """One-hidden-layer Bayesian MLP classifier with Normal priors on all parameters.

    The weights and biases of both linear layers are ``PyroSample`` sites with
    zero-mean Normal priors, so an MCMC kernel samples the network parameters.

    Args:
        input_size: number of input features. ``forward`` flattens each example
            to this many features.
        hidden_size: width of the ReLU hidden layer.
        output_size: number of classes.
        prec: value multiplied into the prior *scale* (standard deviation) of
            every Normal prior. Despite the name, it is not inverted into a
            precision here. It should only be set by SGD: a Gaussian prior over
            the weights is equivalent to L2 regularization in the non-Bayesian
            setting.
        device: device on which the prior tensors are created and to which
            inputs/targets are moved in ``forward``.
    """

    def __init__(self, input_size, hidden_size, output_size, prec=1., device='cpu'):
        super().__init__()
        self.device = device
        # Used by forward() to flatten inputs instead of a hard-coded 28*28.
        self.input_size = input_size

        # TODO add gamma priors to precision terms

        self.fc1 = PyroLinear(input_size, hidden_size)
        self.fc1.weight = self._normal_prior((hidden_size, input_size), prec)
        self.fc1.bias = self._normal_prior((hidden_size,), prec)

        self.fc2 = PyroLinear(hidden_size, output_size)
        self.fc2.weight = self._normal_prior((output_size, hidden_size), prec)
        self.fc2.bias = self._normal_prior((output_size,), prec)

        self.relu = torch.nn.ReLU()
        self.log_softmax = torch.nn.LogSoftmax(dim=1)

    def _normal_prior(self, shape, scale):
        """Return a PyroSample with an i.i.d. Normal(0, scale) prior over a tensor of ``shape``."""
        loc = torch.zeros(shape, device=self.device)
        std = torch.ones(shape, device=self.device) * scale
        # to_event(len(shape)) declares the whole tensor as one event:
        # to_event(2) for weight matrices, to_event(1) for bias vectors.
        return pyro.nn.PyroSample(dist.Normal(loc, std).to_event(len(shape)))

    def forward(self, x, y=None):
        """Run the network and register the Categorical likelihood site ``"obs"``.

        Args:
            x: batch of inputs, reshaped to ``(-1, input_size)``.
            y: optional class labels. When given, ``"obs"`` is conditioned on
                them. When ``None``, ``"obs"`` is sampled, which is how
                ``pyro.infer.Predictive`` uses this model in this notebook.
        """
        x = x.reshape(-1, self.input_size).to(self.device)
        x = self.relu(self.fc1(x))
        # Log-softmax outputs serve as the logits of the Categorical likelihood.
        log_probs = self.log_softmax(self.fc2(x))

        if y is not None:
            y = y.to(self.device)

        with pyro.plate("data", log_probs.shape[0]):
            pyro.sample("obs", dist.Categorical(logits=log_probs), obs=y)

## Not using observed information

In [None]:
# Configuration 1: SGHMC without observed information (obs_info_noise=False).
LR = 2e-6
MOMENTUM_DECAY = 0.01
RESAMPLE_EVERY_N = 0
NUM_STEPS = 1 # fixed during hyperparameter search

pyro.clear_param_store()

bnn = BNN(28*28, HIDDEN_SIZE, 10, device=device).to(device)

sghmc = SGHMC(bnn,
              subsample_positions=[0, 1],
              batch_size=BATCH_SIZE,
              learning_rate=LR,
              momentum_decay=MOMENTUM_DECAY,
              num_steps=NUM_STEPS,
              resample_every_n=RESAMPLE_EVERY_N,
              obs_info_noise=False,
              device=device)

# num_samples = number of minibatches in the training set.
sghmc_mcmc = MCMC(sghmc, num_samples=len(train_dataset)//BATCH_SIZE, warmup_steps=0)

noinfo_test_errs = []
noinfo_times = []

# Running vote counts of the full posterior predictive over all evaluated
# epochs: full_predictive[i, c] counts samples predicting class c for example i.
full_predictive = torch.zeros(len(Y_val), 10)

for epoch in range(1, 1 + NUM_EPOCHS + WARMUP_EPOCHS):
    # Time the whole epoch (sampling + evaluation), as in every configuration.
    start = time.time()

    sghmc_mcmc.run(X_train, Y_train)

    # Epochs 1..WARMUP_EPOCHS are burn-in. Evaluating only afterwards records
    # exactly NUM_EPOCHS entries, numbered 1..NUM_EPOCHS to match the plots.
    if epoch <= WARMUP_EPOCHS:
        continue

    sghmc_samples = sghmc_mcmc.get_samples()
    predictive = pyro.infer.Predictive(bnn, posterior_samples=sghmc_samples)

    with torch.no_grad():
        # Each batch prediction is indexed [sample, example]. Concatenating
        # along dim=1 gives one row per posterior sample covering the split.
        epoch_predictive = torch.cat(
            [predictive(x)['obs'].to(torch.int64).to("cpu") for x, _ in val_loader],
            dim=1,
        )

        # Add this epoch's votes to the running posterior-predictive counts.
        for sample in epoch_predictive:
            full_predictive += F.one_hot(sample, num_classes=10)

        full_y_hat = torch.argmax(full_predictive, dim=1)
        total = Y_val.shape[0]
        correct = int((full_y_hat == Y_val).sum())

    end = time.time()

    # NOTE: the "test" error is measured on the held-out validation split.
    noinfo_test_errs.append(1.0 - correct/total)
    noinfo_times.append(end - start)

    print("Epoch [{}/{}] test accuracy: {:.4f} time: {:.2f}".format(epoch-WARMUP_EPOCHS, NUM_EPOCHS, correct/total, end - start))

# Save the errors and times to a file
with open(RESULTS_NOINFO, "w") as f:
    json.dump((noinfo_test_errs, noinfo_times), f)

## Estimating at setup time

In [None]:
# Configuration 2: observed information estimated once, at setup time
# (compute_obs_info="start").
LR = 2e-6
MOMENTUM_DECAY = 0.01
RESAMPLE_EVERY_N = 0
NUM_STEPS = 1 # fixed during hyperparameter search

pyro.clear_param_store()

bnn = BNN(28*28, HIDDEN_SIZE, 10, device=device).to(device)

sghmc = SGHMC(bnn,
              subsample_positions=[0, 1],
              batch_size=BATCH_SIZE,
              learning_rate=LR,
              momentum_decay=MOMENTUM_DECAY,
              num_steps=NUM_STEPS,
              resample_every_n=RESAMPLE_EVERY_N,
              obs_info_noise=True,
              compute_obs_info="start",
              device=device)

# num_samples = number of minibatches in the training set.
sghmc_mcmc = MCMC(sghmc, num_samples=len(train_dataset)//BATCH_SIZE, warmup_steps=0)

start_test_errs = []
start_times = []

# Running vote counts of the full posterior predictive over all evaluated
# epochs: full_predictive[i, c] counts samples predicting class c for example i.
full_predictive = torch.zeros(len(Y_val), 10)

for epoch in range(1, 1 + NUM_EPOCHS + WARMUP_EPOCHS):
    # Time the whole epoch (sampling + evaluation), as in every configuration.
    start = time.time()

    sghmc_mcmc.run(X_train, Y_train)

    # Epochs 1..WARMUP_EPOCHS are burn-in. Evaluating only afterwards records
    # exactly NUM_EPOCHS entries, numbered 1..NUM_EPOCHS to match the plots.
    if epoch <= WARMUP_EPOCHS:
        continue

    sghmc_samples = sghmc_mcmc.get_samples()
    predictive = pyro.infer.Predictive(bnn, posterior_samples=sghmc_samples)

    with torch.no_grad():
        # Each batch prediction is indexed [sample, example]. Concatenating
        # along dim=1 gives one row per posterior sample covering the split.
        epoch_predictive = torch.cat(
            [predictive(x)['obs'].to(torch.int64).to("cpu") for x, _ in val_loader],
            dim=1,
        )

        # Add this epoch's votes to the running posterior-predictive counts.
        for sample in epoch_predictive:
            full_predictive += F.one_hot(sample, num_classes=10)

        full_y_hat = torch.argmax(full_predictive, dim=1)
        total = Y_val.shape[0]
        correct = int((full_y_hat == Y_val).sum())

    end = time.time()

    # NOTE: the "test" error is measured on the held-out validation split.
    start_test_errs.append(1.0 - correct/total)
    start_times.append(end - start)

    print("Epoch [{}/{}] test accuracy: {:.4f} time: {:.2f}".format(epoch-WARMUP_EPOCHS, NUM_EPOCHS, correct/total, end - start))

# Save the errors and times to a file
with open(RESULTS_START, "w") as f:
    json.dump((start_test_errs, start_times), f)

## Recalculating every sample

In [None]:
# Configuration 3: observed information recalculated every sample
# (compute_obs_info="every_sample").
LR = 2e-6
MOMENTUM_DECAY = 0.01
RESAMPLE_EVERY_N = 0
NUM_STEPS = 1 # fixed during hyperparameter search

pyro.clear_param_store()

bnn = BNN(28*28, HIDDEN_SIZE, 10, device=device).to(device)

sghmc = SGHMC(bnn,
              subsample_positions=[0, 1],
              batch_size=BATCH_SIZE,
              learning_rate=LR,
              momentum_decay=MOMENTUM_DECAY,
              num_steps=NUM_STEPS,
              resample_every_n=RESAMPLE_EVERY_N,
              obs_info_noise=True,
              compute_obs_info="every_sample",
              device=device)

# num_samples = number of minibatches in the training set.
sghmc_mcmc = MCMC(sghmc, num_samples=len(train_dataset)//BATCH_SIZE, warmup_steps=0)

every_sample_test_errs = []
every_sample_times = []

# Running vote counts of the full posterior predictive over all evaluated
# epochs: full_predictive[i, c] counts samples predicting class c for example i.
full_predictive = torch.zeros(len(Y_val), 10)

for epoch in range(1, 1 + NUM_EPOCHS + WARMUP_EPOCHS):
    # Time the whole epoch (sampling + evaluation). The timer is deliberately
    # not restarted after sampling: the cost of recomputing the observed
    # information is exactly what this experiment compares.
    start = time.time()

    sghmc_mcmc.run(X_train, Y_train)

    # Epochs 1..WARMUP_EPOCHS are burn-in. Evaluating only afterwards records
    # exactly NUM_EPOCHS entries, numbered 1..NUM_EPOCHS to match the plots.
    if epoch <= WARMUP_EPOCHS:
        continue

    sghmc_samples = sghmc_mcmc.get_samples()
    predictive = pyro.infer.Predictive(bnn, posterior_samples=sghmc_samples)

    with torch.no_grad():
        # Each batch prediction is indexed [sample, example]. Concatenating
        # along dim=1 gives one row per posterior sample covering the split.
        epoch_predictive = torch.cat(
            [predictive(x)['obs'].to(torch.int64).to("cpu") for x, _ in val_loader],
            dim=1,
        )

        # Add this epoch's votes to the running posterior-predictive counts.
        for sample in epoch_predictive:
            full_predictive += F.one_hot(sample, num_classes=10)

        full_y_hat = torch.argmax(full_predictive, dim=1)
        total = Y_val.shape[0]
        correct = int((full_y_hat == Y_val).sum())

    end = time.time()

    # NOTE: the "test" error is measured on the held-out validation split.
    every_sample_test_errs.append(1.0 - correct/total)
    every_sample_times.append(end - start)

    print("Epoch [{}/{}] test accuracy: {:.4f} time: {:.2f}".format(epoch-WARMUP_EPOCHS, NUM_EPOCHS, correct/total, end - start))

# Save the errors and times to a file
with open(RESULTS_EVERY_SAMPLE, "w") as f:
    json.dump((every_sample_test_errs, every_sample_times), f)

## Recalculating every step

In [None]:
# Configuration 4: observed information recalculated before every dynamics
# step (compute_obs_info="every_step").
LR = 2e-6
MOMENTUM_DECAY = 0.01
RESAMPLE_EVERY_N = 0
NUM_STEPS = 1 # fixed during hyperparameter search

pyro.clear_param_store()

bnn = BNN(28*28, HIDDEN_SIZE, 10, device=device).to(device)

sghmc = SGHMC(bnn,
              subsample_positions=[0, 1],
              batch_size=BATCH_SIZE,
              learning_rate=LR,
              momentum_decay=MOMENTUM_DECAY,
              num_steps=NUM_STEPS,
              resample_every_n=RESAMPLE_EVERY_N,
              obs_info_noise=True,
              compute_obs_info="every_step",
              device=device)

# num_samples = number of minibatches in the training set.
sghmc_mcmc = MCMC(sghmc, num_samples=len(train_dataset)//BATCH_SIZE, warmup_steps=0)

every_step_test_errs = []
every_step_times = []

# Running vote counts of the full posterior predictive over all evaluated
# epochs: full_predictive[i, c] counts samples predicting class c for example i.
full_predictive = torch.zeros(len(Y_val), 10)

for epoch in range(1, 1 + NUM_EPOCHS + WARMUP_EPOCHS):
    # Time the whole epoch (sampling + evaluation). The timer is deliberately
    # not restarted after sampling: the cost of recomputing the observed
    # information is exactly what this experiment compares.
    start = time.time()

    sghmc_mcmc.run(X_train, Y_train)

    # Epochs 1..WARMUP_EPOCHS are burn-in. Evaluating only afterwards records
    # exactly NUM_EPOCHS entries, numbered 1..NUM_EPOCHS to match the plots.
    if epoch <= WARMUP_EPOCHS:
        continue

    sghmc_samples = sghmc_mcmc.get_samples()
    predictive = pyro.infer.Predictive(bnn, posterior_samples=sghmc_samples)

    with torch.no_grad():
        # Each batch prediction is indexed [sample, example]. Concatenating
        # along dim=1 gives one row per posterior sample covering the split.
        epoch_predictive = torch.cat(
            [predictive(x)['obs'].to(torch.int64).to("cpu") for x, _ in val_loader],
            dim=1,
        )

        # Add this epoch's votes to the running posterior-predictive counts.
        for sample in epoch_predictive:
            full_predictive += F.one_hot(sample, num_classes=10)

        full_y_hat = torch.argmax(full_predictive, dim=1)
        total = Y_val.shape[0]
        correct = int((full_y_hat == Y_val).sum())

    end = time.time()

    # NOTE: the "test" error is measured on the held-out validation split.
    every_step_test_errs.append(1.0 - correct/total)
    every_step_times.append(end - start)

    print("Epoch [{}/{}] test accuracy: {:.4f} time: {:.2f}".format(epoch-WARMUP_EPOCHS, NUM_EPOCHS, correct/total, end - start))

# Save the errors and times to a file
with open(RESULTS_EVERY_STEP, "w") as f:
    json.dump((every_step_test_errs, every_step_times), f)

### Plot the convergence dynamics and times

In [None]:
sns.set_style("dark")


def _load_results(path):
    """Read an ``[errors, times]`` JSON pair saved by a training cell as two NumPy arrays."""
    with open(path, "r") as f:
        errs, times = json.load(f)
    return np.array(errs), np.array(times)


# Load the previous results from the files
noinfo_test_errs, noinfo_times = _load_results(RESULTS_NOINFO)
start_test_errs, start_times = _load_results(RESULTS_START)
every_sample_test_errs, every_sample_times = _load_results(RESULTS_EVERY_SAMPLE)
every_step_test_errs, every_step_times = _load_results(RESULTS_EVERY_STEP)

In [None]:
# Test error of the accumulated posterior predictive per evaluated epoch.
err_dict = {
    'No obs info' : noinfo_test_errs,
    'Start' : start_test_errs,
    'Every sample' : every_sample_test_errs,
    'Every step' : every_step_test_errs
}
x = np.arange(1, NUM_EPOCHS+1)

# Long-form rows: one [epoch, updater, error] entry per (epoch, updater) pair.
lst = [
    [epoch, updater, errs[i]]
    for i, epoch in enumerate(x)
    for updater, errs in err_dict.items()
]

df = pd.DataFrame(lst, columns=['iterations', 'updater', 'test error'])
# DataFrame.pivot arguments are keyword-only since pandas 2.0; positional
# arguments raise TypeError there.
sns.lineplot(data=df.pivot(index="iterations", columns="updater", values="test error"))
plt.ylabel("test error")
plt.show() #dpi=300

In [None]:
# Wall-clock time per evaluated epoch for each configuration.
time_dict = {
    'No obs info' : noinfo_times,
    'Start' : start_times,
    'Every sample' : every_sample_times,
    'Every step' : every_step_times
}
x = np.arange(1, NUM_EPOCHS+1)

# Long-form rows built from time_dict itself, so this cell does not depend on
# the previous cell's err_dict.
lst = [
    [epoch, updater, times[i]]
    for i, epoch in enumerate(x)
    for updater, times in time_dict.items()
]

df = pd.DataFrame(lst, columns=['iterations', 'updater', 'sample time'])
# DataFrame.pivot arguments are keyword-only since pandas 2.0.
sns.lineplot(data=df.pivot(index="iterations", columns="updater", values="sample time"))
plt.ylabel("sample time")
plt.show() #dpi=300