In [209]:
import numpy as np

import torch
import torch.optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets

from sklearn.model_selection import StratifiedKFold

from tqdm import tqdm_notebook as tqdm

mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)
x_train, y_train = mnist_trainset.data.float(), mnist_trainset.targets

mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=None)
x_test, y_test = mnist_testset.data.float(), mnist_testset.targets

#x_mean, x_std = x_train.mean(dim=[1, 2]), x_train.var(dim=[1, 2])**0.5
x_train, x_test = (x_train - 128)/256, (x_test - 128)/256
#y_train = torch.FloatTensor(60000, 10).scatter_(1, y_train, 1)
#y_test = torch.FloatTensor(10000, 10).scatter_(1, y_test, 1)

In [230]:
def accuracy(y, target):
    _, predicted = torch.max(y, 1)
    accuracy = (predicted == target).sum().float() / target.size(0)

    return accuracy

In [239]:
class HmcNet(nn.Module):
    
    def __init__(self, input_dim, num_classes, num_units, learn_rate, prior_sigma):
        super(HmcNet, self).__init__()
        self.input_dim = input_dim
        self.num_units = num_units
        self.num_classes = num_classes
        self.prior_sigma = prior_sigma
        
        self.layer1 = nn.Linear(input_dim, num_units)
        self.layer2 = nn.Linear(num_units, num_units)
        self.layer3 = nn.Linear(num_units, num_classes)
        
        self.loss = nn.CrossEntropyLoss()
        #self.optimiser = torch.optim.SGD(self.parameters(), lr=learn_rate, momentum=0.9)
        self.optimiser = torch.optim.Adam(self.parameters(), lr=learn_rate)
        
        
    def run_dynamics(self, momenta, x_train, y_train, epsilon, num_steps):
        """
        Leapfrog integration of the dynamics for

        time = num_steps * epsilon
        """
        
        # Computing dE_dz
        hmc_net.potential_energy(x_train, y_train).backward()
        
        # Sample integration direction
        epsilon = epsilon * np.random.choice([-1, 1])
        
        # First leapfrog update
        for param in self.parameters():
            momenta[param].data = momenta[param].data + epsilon / 2. * param.grad

        for i in range(num_steps - 1):

            # Middle leapfrog steps
            # z = z + epsilon * r
            for param in self.parameters():
                param.data = param.data + epsilon * momenta[param].data #param = param + epsilon * momenta[param]
                
            # Computing dE_dz
            hmc_net.potential_energy(x_train, y_train).backward()
            
            # r = r - epsilon * dE_dz(z)
            for param in self.parameters():
                momenta[param].data = momenta[param].data - epsilon * param.grad.data

        # Final leapfrog steps
        # z = z + epsilon * r
        for param in self.parameters():
            param.data = param.data + epsilon * momenta[param].data
    
        # Computing dE_dz
        hmc_net.potential_energy(x_train, y_train).backward()

        # r = r - epsilon / 2. * dE_dz(z)
        for param in self.parameters():
            momenta[param].data = momenta[param].data - epsilon / 2. * param.grad.data
        
        return momenta
    
        
    def forward(self, x):
        
        x = x.view(x.shape[0], -1)
        
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        
        y = F.softmax(self.layer3(x), dim=1)
        
        return y
    
    
    def fit_map(self, x_train, y_train, num_iterations, log_every):
        
        for i in range(num_iterations):
            
            self.optimiser.zero_grad()
            
            loss = self.potential_energy(x_train, y_train)
            loss.backward()
            
            self.optimiser.step()
            
            if i % log_every == 0:
                
                acc = accuracy(self.forward(x_train), y_train)
                print('Loss {:.3f}, accuracy {:.3f}'.format(loss, acc))
            
            
    def potential_energy(self, x, labels):
        
        x = x.view(x.shape[0], -1)
        
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        
        y = F.softmax(self.layer3(x), dim=1)
        
        s = 0
        for param in self.parameters():
            s = s + torch.sum(param**2)
        
        return self.loss(y, labels) + (0.5 / self.prior_sigma**2) * s
    
    
    def hamiltonian(self, potential, momenta):
        
        s = 0
        for k, v in momenta.items():
            s = s + torch.sum(v**2)
        
        return potential + 0.5 * s
    
    
    def hmc_sample(self, x_train, y_train, mixing_time=50, burn_in_time=10,
                   num_integral_steps=10, num_samples=100, epsilon=1e-3, log_every=1):
        
        num_iters = num_samples * mixing_time + burn_in_time
        num_accepted = 0
        
        for i in tqdm(range(num_iters)):
            
            params_0 = {p: p.clone() for p in self.parameters()}
            momenta_0 = {p: torch.randn(p.shape) for p in self.parameters()}
            
            H_0 = self.hamiltonian(self.potential_energy(x_train, y_train), momenta_0)
                
            momenta = self.run_dynamics(momenta=momenta_0,
                                        x_train=x_train,
                                        y_train=y_train,
                                        epsilon=epsilon,
                                        num_steps=num_integral_steps)
            
            H = self.hamiltonian(self.potential_energy(x_train, y_train), momenta)
            
            threshold = np.minimum(1, torch.exp(H_0 - H).item())

            u = np.random.uniform(low=0., high=1.)
            is_accepted = threshold >= u

            if not is_accepted:
                for param in self.parameters():
                    param.data = params_0[param].data
            else:
                num_accepted += 1
            
            
            if i % log_every == 0:
                _, predicted = torch.max(self.forward(x_train.data), 1)
                accuracy = (predicted == y_train).sum().float() / y_train.size(0)

                print('Train accuracy {:.3}, accepted {} out of {}'.format(accuracy, num_accepted, i + 1))

In [247]:
hmc_net = HmcNet(input_dim=784,
                 num_classes=10,
                 num_units=100,
                 learn_rate=1e-2,
                 prior_sigma=1e2)

num_epochs = 100

x_train_, y_train_ = x_train[:600], y_train[:600]
x_test_, y_test_ = x_test[:100], y_test[:100]

hmc_net.fit_map(x_train, y_train, num_iterations=100, log_every=10)

hmc_net.hmc_sample(x_train=x_train_, y_train=y_train_, mixing_time=100, burn_in_time=10,
                   num_integral_steps=10, num_samples=5, epsilon=1e-4, log_every=10)

Loss 2.306, accuracy 0.144
Loss 2.015, accuracy 0.549
Loss 1.795, accuracy 0.689
Loss 1.751, accuracy 0.729
Loss 1.734, accuracy 0.746
Loss 1.723, accuracy 0.756
Loss 1.716, accuracy 0.765
Loss 1.708, accuracy 0.771
Loss 1.703, accuracy 0.776
Loss 1.712, accuracy 0.753


HBox(children=(IntProgress(value=0, max=510), HTML(value='')))

Train accuracy 0.787, accepted 1 out of 1
Train accuracy 0.787, accepted 11 out of 11
Train accuracy 0.787, accepted 21 out of 21
Train accuracy 0.785, accepted 31 out of 31
Train accuracy 0.785, accepted 41 out of 41
Train accuracy 0.787, accepted 51 out of 51
Train accuracy 0.785, accepted 61 out of 61
Train accuracy 0.785, accepted 71 out of 71
Train accuracy 0.785, accepted 80 out of 81
Train accuracy 0.783, accepted 89 out of 91
Train accuracy 0.783, accepted 98 out of 101
Train accuracy 0.785, accepted 108 out of 111
Train accuracy 0.785, accepted 118 out of 121
Train accuracy 0.785, accepted 128 out of 131
Train accuracy 0.783, accepted 138 out of 141
Train accuracy 0.785, accepted 148 out of 151
Train accuracy 0.782, accepted 157 out of 161
Train accuracy 0.78, accepted 165 out of 171
Train accuracy 0.782, accepted 175 out of 181
Train accuracy 0.78, accepted 185 out of 191
Train accuracy 0.783, accepted 195 out of 201
Train accuracy 0.783, accepted 205 out of 211
Train accurac