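# SMMALA on an MLP with XOR data: comparing PyTorch and NumPy outputs

This notebook checks that the PyTorch sampler (`SMMALA`) and the NumPy sampler (`SMMALANP`) agree. Both start from the same state and are fed the same fixed random inputs. The check covers the proposed parameters, the inverse metric tensor, the log acceptance rate and the accept decision.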

In [1]:
## Import packages

import numpy as np

import torch

from torch.utils.data import DataLoader

from eeyore.data import XOR
from eeyore.models import mlp
from eeyore.stats import softabs, softabs_np
from eeyore.mcmc import SMMALA, SMMALANP

import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
## Load XOR data
# Wrap the XOR data set in a DataLoader with batch size 4
# (presumably the whole XOR set in one batch — confirm against eeyore.data.XOR).

xor = XOR()
dataloader = DataLoader(dataset=xor, batch_size=4)

In [3]:
## Initialize hyperparameters and seed random number generators

# Seed before any random draw further down. Those draws include randn_val,
# threshold, current_before_draw and (presumably) the prior sample.
# Seeding lets Restart & Run All reproduce the same comparison.
# NumPy is seeded too, in case any eeyore NumPy code uses the global NumPy RNG.
SEED = 1
torch.manual_seed(SEED)
np.random.seed(SEED)

# MLP layer widths (presumably input=2, hidden=2, output=1 for XOR).
hparams = mlp.Hyperparameters(dims=[2, 2, 1])

In [4]:
## Setup MLP model using PyTorch
# Build the network from the hyperparameters defined above.
# Its `prior` is sampled below to obtain the initial parameters.

model = mlp.MLP(hparams=hparams)

In [5]:
## Initialize parameters
# Draw an initial parameter vector from the model's prior. Both samplers are
# constructed from it. Before the comparison draws, the samplers are `reset`
# to `current_before_draw`.

theta0 = model.prior.sample()

In [6]:
## Setup SMMALA sampler using PyTorch
# The Hessian is passed through `softabs` with parameter 1000 to form the metric.
# The NumPy sampler uses the same step size and softabs parameter so the two
# implementations can be compared.

def torch_metric_transform(hessian):
    return softabs(hessian, 1000)

sampler = SMMALA(model, theta0, dataloader, step=0.25, transform=torch_metric_transform)

In [7]:
## Setup SMMALA sampler using numpy
# This mirrors the PyTorch sampler: same model, data loader, step size and
# softabs parameter. The initial parameters are converted to a NumPy array.

theta0_np = theta0.detach().cpu().numpy()

def numpy_metric_transform(hessian):
    return softabs_np(hessian, 1000)

sampler_np = SMMALANP(model, theta0_np, dataloader, step=0.25, transform=numpy_metric_transform)

In [8]:
## Fix stochastic components of a draw via SMMALA
# Pre-draw the Gaussian noise (one float64 value per model parameter) and a
# uniform `threshold`, so both samplers can be handed identical randomness.

n_params = sampler.model.num_params()
randn_val = torch.randn(n_params, dtype=torch.double)
threshold = torch.rand(1, dtype=torch.double)

In [9]:
## Set starting state before draw
# A random float64 starting state with one entry per model parameter; shown below.

n_params = sampler.model.num_params()
current_before_draw = torch.randn(n_params, dtype=torch.double)
current_before_draw

tensor([ 1.6612,  0.0945,  0.8177,  1.6956,  0.4440, -1.9902, -0.4069,  0.1285,
         1.2720], dtype=torch.float64)

In [10]:
## Set starting state in SMMALA and draw a new state using PyTorch, twice from
## the same state, then once using numpy from the same state.
# Every draw uses the same fixed randn_val/threshold, so all three results
# can be compared directly.

sampler.reset(current_before_draw)

proposed_torch_draw1, log_rate_torch_draw1, accept_torch_draw1 = sampler.draw(randn_val, threshold)

sampler.reset(current_before_draw)

proposed_torch_draw2, log_rate_torch_draw2, accept_torch_draw2 = sampler.draw(randn_val, threshold)

# The numpy sampler is constructed and driven with NumPy inputs. Hand it a NumPy
# starting state too, instead of the torch tensor used for the PyTorch sampler.
sampler_np.reset(current_before_draw.detach().cpu().numpy())

proposed_np_draw, log_rate_np_draw, accept_np_draw = sampler_np.draw(randn_val.detach().cpu().numpy(), threshold.item())

In [11]:
## Display proposed parameter values
# Check that both PyTorch draws agree and that the NumPy draw matches them.
# Tolerances allow only floating-point round-off between the two implementations.

theta_torch_draw1 = proposed_torch_draw1['theta'].detach().cpu().numpy()
np.testing.assert_allclose(
    theta_torch_draw1, proposed_torch_draw2['theta'].detach().cpu().numpy(), rtol=1e-6, atol=1e-8
)
np.testing.assert_allclose(theta_torch_draw1, proposed_np_draw['theta'], rtol=1e-6, atol=1e-8)

proposed_torch_draw1['theta'], \
proposed_torch_draw2['theta'], \
proposed_np_draw['theta']

(tensor([ 2.1631,  0.9101,  0.7747,  1.7720,  0.9561, -2.4636, -0.5775,  0.4754,
          1.3983], dtype=torch.float64, grad_fn=<AddBackward0>),
 tensor([ 2.1631,  0.9101,  0.7747,  1.7720,  0.9561, -2.4636, -0.5775,  0.4754,
          1.3983], dtype=torch.float64, grad_fn=<AddBackward0>),
 array([ 2.16309067,  0.91007237,  0.77470738,  1.77202912,  0.95606433,
        -2.46359143, -0.57752035,  0.47542532,  1.39825027]))

In [12]:
## Display inverse metric tensor of proposed parameter values
# Check that both PyTorch draws agree and that the NumPy inverse metric matches them.
# Tolerances allow only floating-point round-off between the two implementations.

inv_metric_torch_draw1 = proposed_torch_draw1['inv_metric_val'].detach().cpu().numpy()
np.testing.assert_allclose(
    inv_metric_torch_draw1,
    proposed_torch_draw2['inv_metric_val'].detach().cpu().numpy(),
    rtol=1e-6,
    atol=1e-8,
)
np.testing.assert_allclose(
    inv_metric_torch_draw1, proposed_np_draw['inv_metric_val'], rtol=1e-6, atol=1e-8
)

proposed_torch_draw1['inv_metric_val'], \
proposed_torch_draw2['inv_metric_val'], \
proposed_np_draw['inv_metric_val']

(tensor([[ 9.9892e-01, -7.0904e-03,  2.5933e-04, -7.5525e-05, -1.1878e-03,
           1.2520e-04,  1.8186e-03,  6.2095e-04,  2.8620e-03],
         [-7.0904e-03,  1.0053e+00, -6.7229e-04,  6.2455e-04,  3.1299e-03,
           2.6072e-04,  1.7795e-02,  1.4401e-03,  1.2957e-03],
         [ 2.5933e-04, -6.7229e-04,  1.0406e+00,  2.6184e-02, -3.6211e-04,
           4.0809e-02, -3.5442e-03, -1.6549e-01, -1.8052e-03],
         [-7.5525e-05,  6.2455e-04,  2.6184e-02,  1.0307e+00,  8.4859e-04,
           3.3610e-02, -6.5975e-03, -1.4740e-01, -6.4924e-03],
         [-1.1878e-03,  3.1299e-03, -3.6211e-04,  8.4859e-04,  9.7932e-01,
           1.9740e-03, -7.4242e-02,  1.0210e-02,  5.0577e-02],
         [ 1.2520e-04,  2.6072e-04,  4.0809e-02,  3.3610e-02,  1.9740e-03,
           1.0272e+00, -1.2538e-02, -1.6339e-01, -1.3356e-02],
         [ 1.8186e-03,  1.7795e-02, -3.5442e-03, -6.5975e-03, -7.4242e-02,
          -1.2538e-02,  7.5404e-01, -7.7444e-02, -2.8327e-01],
         [ 6.2095e-04,  1.4401e-03

In [13]:
## Display logarithm of acceptance rates
# Check that both PyTorch log rates agree and that the NumPy log rate matches them.

np.testing.assert_allclose(log_rate_torch_draw1.item(), log_rate_torch_draw2.item(), rtol=1e-6, atol=1e-8)
np.testing.assert_allclose(log_rate_torch_draw1.item(), log_rate_np_draw, rtol=1e-6, atol=1e-8)

log_rate_torch_draw1, \
log_rate_torch_draw2, \
log_rate_np_draw

(tensor(-0.1510, dtype=torch.float64, grad_fn=<SubBackward0>),
 tensor(-0.1510, dtype=torch.float64, grad_fn=<SubBackward0>),
 -0.15103053448700798)

In [14]:
## Display the accept flags returned by each draw
# With identical state and randomness, all three draws should make the same decision.

assert accept_torch_draw1 == accept_torch_draw2 == accept_np_draw, (
    accept_torch_draw1, accept_torch_draw2, accept_np_draw
)

accept_torch_draw1, \
accept_torch_draw2, \
accept_np_draw

(1, 1, 1)