# SMMALA sampling of MLP weights using XOR data

Learn the XOR function by sampling the weights of a multi-layer perceptron (MLP) with the simplified manifold Metropolis-adjusted Langevin algorithm (SMMALA).

In [1]:
## Import packages

import os
import csv

import numpy as np

import torch

from torch.utils.data import DataLoader
from torch.distributions import Normal

from eeyore.data import XOR
from eeyore.stats import softabs
from eeyore.models import mlp
from eeyore.kernels import NormalTransitionKernel
from eeyore.mcmc import SMMALA

from timeit import default_timer as timer
from datetime import timedelta

import matplotlib.pyplot as plt
import seaborn as sns

In [7]:
## Load XOR data

# Double-precision XOR dataset; batch_size=4 presumably puts all four XOR
# input pairs into a single batch (TODO confirm against eeyore.data.XOR).
xor = XOR(dtype=torch.float64)
dataloader = DataLoader(dataset=xor, batch_size=4)

In [8]:
## Setup MLP model

# Layer widths of the MLP (presumably input -> hidden -> output).
hparams = mlp.Hyperparameters(dims=[2, 2, 1])
model = mlp.MLP(hparams=hparams, dtype=torch.float64)

# Independent zero-mean normal prior on every parameter, each with scale sqrt(3).
model.prior = Normal(
    loc=torch.zeros(model.num_params(), dtype=model.dtype),
    scale=torch.full((model.num_params(),), np.sqrt(3), dtype=model.dtype),
)

In [9]:
## Setup SMMALA sampler

def hessian_transform(hessian):
    """Pass the Hessian through eeyore's softabs with a=1000.

    Handed to SMMALA as its ``transform`` hook (presumably to obtain a
    positive-definite metric — see eeyore.stats.softabs).
    """
    return softabs(hessian, a=1000.)

# Start the chain from a single draw of the prior.
theta0 = model.prior.sample()
sampler = SMMALA(model, theta0, dataloader, step=0.2, transform=hessian_transform)

In [10]:
## Run SMMALA sampler

# Wall-clock the full run (burn-in included); start_time/end_time are reused
# later when the runtime is written to disk.
start_time = timer()
sampler.run(num_iterations=1100, num_burnin=100)
end_time = timer()

elapsed = timedelta(seconds=end_time-start_time)
print("Time taken: {}".format(elapsed))

Time taken: 0:00:08.512275


In [11]:
## Compute acceptance rate

# Report the acceptance rate computed by the sampler's chain.
rate = sampler.chain.acceptance_rate()
print(f"Acceptance rate: {rate}")

Acceptance rate: 0.5630000233650208


In [None]:
## Plot traces of simulated Markov chain

# One figure per parameter: sampled value against iteration index.
for i in range(model.num_params()):
    print("Generating trace plot of parameter {}".format(i+1))

    chain = sampler.chain.get_theta(i)
    plt.figure()
    # seaborn >= 0.12 only accepts x/y as keyword arguments.
    sns.lineplot(x=range(len(chain)), y=chain)
    plt.xlabel('Iteration')
    plt.ylabel('Parameter value')
    # Escaped braces yield $\theta_{12}$ so multi-digit indices render as a
    # single subscript rather than theta_1 followed by a plain "2".
    plt.title(r'Traceplot of parameter $\theta_{{{}}}$'.format(i+1))

In [None]:
## Plot running means of simulated Markov chain

# Running mean m_j = (x_0 + ... + x_j) / (j + 1) for each parameter's chain.
for i in range(model.num_params()):
    print("Generating running mean plot of parameter {}".format(i+1))

    # Cast to the model's dtype so the averaging is done at the chain's
    # precision instead of torch's default dtype (float32 unless overridden).
    chain = torch.as_tensor(sampler.chain.get_theta(i), dtype=model.dtype)
    # Vectorized cumulative mean; also well-defined for an empty chain.
    counts = torch.arange(1, len(chain) + 1, dtype=chain.dtype)
    chain_mean = torch.cumsum(chain, dim=0) / counts

    plt.figure()
    # seaborn >= 0.12 only accepts x/y as keyword arguments.
    sns.lineplot(x=range(len(chain)), y=chain_mean)
    plt.xlabel('Iteration')
    plt.ylabel('Parameter value')
    # Escaped braces keep multi-digit indices inside the mathtext subscript.
    plt.title(r'Running mean of parameter $\theta_{{{}}}$'.format(i+1))

In [None]:
## Plot histograms of simulated Markov chain

# Normalized histogram (20 bins) with a KDE overlay for each parameter.
for i in range(model.num_params()):
    print("Generating histogram of parameter {}".format(i+1))

    plt.figure()
    # sns.distplot is deprecated; histplot with stat='density' and kde=True is
    # its documented replacement for a normalized histogram plus KDE.
    sns.histplot(x=sampler.chain.get_theta(i), bins=20, stat='density', kde=True)
    plt.xlabel('Value range')
    # The bars are density-normalized, not relative frequencies.
    plt.ylabel('Density')
    # Escaped braces keep multi-digit indices inside the mathtext subscript.
    plt.title(r'Histogram of parameter $\theta_{{{}}}$'.format(i+1))

In [None]:
## Save simulated Markov chain in file

# open(..., 'w') does not create missing directories.
os.makedirs("output", exist_ok=True)

# One single-column file per parameter: output/smmala_chainNN.txt.
for i in range(model.num_params()):
    print("Saving chain of parameter {}".format(i+1))

    chain = sampler.chain.get_theta(i)
    path = os.path.join("output", "smmala_chain{:02d}.txt".format(i+1))
    # The csv module requires newline='' to avoid blank lines on Windows.
    with open(path, 'w', newline='') as file:
        writer = csv.writer(file)
        # float() keeps rows numeric even if the chain yields 0-dim tensors,
        # whose str() would otherwise be written as "tensor(...)".
        writer.writerows([float(state)] for state in chain)

In [None]:
## Save acceptance diagnostic for simulated Markov chain

# open(..., 'w') does not create missing directories.
os.makedirs("output", exist_ok=True)

# One row per recorded entry of the chain's 'accepted' diagnostic.
# The csv module requires newline='' to avoid blank lines on Windows.
with open(os.path.join("output", "smmala_accepted.txt"), 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerows([a] for a in sampler.chain.vals['accepted'])

In [None]:
## Save runtime of MC simulation

# open(..., 'w') does not create missing directories.
os.makedirs("output", exist_ok=True)

# Single line "Runtime: H:MM:SS.ffffff" measured around sampler.run().
with open(os.path.join("output", "smmala_runtime.txt"), 'w') as file:
    file.write("Runtime: {}\n".format(timedelta(seconds=end_time-start_time)))