# MNIST example for LLVI
Code partly taken from https://nextjournal.com/gkoehler/pytorch-mnist

In [1]:
import torch
import torchvision

In [2]:
n_epochs = 3
batch_size_train = 32
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10

### Load the data

In [3]:
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('../files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [4]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.kl import kl_divergence
from torch.distributions.categorical import Categorical

### Define the model

In [5]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2_mu = torch.randn(50, 10, requires_grad=True)
        self.fc2_log_var = torch.randn_like(self.fc2_mu, requires_grad=True)

    def forward(self, x, train=True):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        # x = F.dropout(x, training=self.training)
        x = x @ self.sample_fc_2(train = True)
        return x, self.fc2_mu, self.fc2_log_var

    def sample_fc_2(self, train):
        if train:
            return self.fc2_mu + torch.exp(0.5 * self.fc2_log_var) * torch.randn_like(self.fc2_mu)
        else:
            return self.fc2_mu


class CNN(nn.Module):
    # https://medium.com/@nutanbhogendrasharma/pytorch-convolutional-neural-network-with-mnist-dataset-4e8a4265e118
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(         
            nn.Conv2d(
                in_channels=1,              
                out_channels=16,            
                kernel_size=5,              
                stride=1,                   
                padding=2,                  
            ),                              
            nn.ReLU(),                      
            nn.MaxPool2d(kernel_size=2),    
        )
        self.conv2 = nn.Sequential(         
            nn.Conv2d(16, 32, 5, 1, 2),     
            nn.ReLU(),                      
            nn.MaxPool2d(2),                
        )        # fully connected layer, output 10 classes
        self.fc_mu = torch.randn(32 * 7 * 7, 10, requires_grad=True)
        self.fc_log_var = torch.randn_like(self.fc_mu, requires_grad=True)
        # self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x, train=True):
        x = self.conv1(x)
        x = self.conv2(x)        # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        x = x.view(x.size(0), -1)       
        output = x @ self.sample_fc(train=train)
        output = F.log_softmax(output)
        return output, self.fc_mu, self.fc_log_var

    def sample_fc(self, train):
        if train:
            return self.fc_mu + torch.exp(0.5 * self.fc_log_var) * torch.randn_like(self.fc_mu)
        else:
            return self.fc_mu

In [6]:
network = Net()
optimizer = optim.SGD([{"params": network.parameters()}, {"params": network.fc2_mu}, {"params": network.fc2_log_var}], lr=learning_rate,
                      momentum=momentum)

In [7]:
print("Model's state_dict:")
for param_tensor in network.state_dict():
    print(param_tensor, "\t", network.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Model's state_dict:
conv1.weight 	 torch.Size([10, 1, 5, 5])
conv1.bias 	 torch.Size([10])
conv2.weight 	 torch.Size([20, 10, 5, 5])
conv2.bias 	 torch.Size([20])
fc1.weight 	 torch.Size([50, 320])
fc1.bias 	 torch.Size([50])
Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.01, 'momentum': 0.5, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3, 4, 5]}, {'lr': 0.01, 'momentum': 0.5, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [6]}, {'lr': 0.01, 'momentum': 0.5, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [7]}]


### Define loss and train function

In [8]:
def KL_div_gaussian_diagonal(mu1, log_var1, mu2, log_var2):
    return 0.5 * (torch.sum(log_var2) - torch.sum(log_var1) - mu1.shape[0] + torch.sum(torch.exp(log_var1 - log_var2)) + torch.sum(torch.div(torch.square(mu2 - mu1), torch.exp(log_var2))))

In [9]:
def NLL_loss(pred, target):
    return - torch.mean(F.log_softmax(pred, dim=-1)[target])


In [10]:
def loss_function(output, target, fc2_mu, fc2_log_var):
    fc2_mu = torch.flatten(fc2_mu)
    fc2_log_var = torch.flatten(fc2_log_var)
    # KL_Div = kl_divergence(MultivariateNormal(fc2_mu, torch.diag(torch.exp(fc2_log_var))), MultivariateNormal(torch.ones_like(fc2_mu), 0.5*torch.eye(fc2_mu.shape[0]))) # from pytorch
    # KL_Div = KL_div_gaussian_diagonal(fc2_mu, fc2_log_var, torch.zeros_like(fc2_mu),  - torch.ones_like(fc2_log_var)) # prior of mu = 0, log_var=0 -> var=1
    likelihood = F.cross_entropy(output, target, reduction="mean")
    # likelihood = NLL_loss(output, target)
    return likelihood #+ KL_Div 

In [11]:
def train(epochs):
  network.train()
  for epoch in range(epochs):
    episode_loss = []
    for batch_idx, (data, target) in enumerate(train_loader):
      optimizer.zero_grad()
      output, fc2_mu, fc2_log_var = network(data)
      loss = loss_function(output, target, fc2_mu, fc2_log_var)
      loss.backward()
      with torch.no_grad():
        episode_loss += [loss]
      optimizer.step()
    
    print(f"Epoch {epoch} loss", sum(episode_loss)/len(episode_loss))


In [12]:
network.fc2_mu[0]

tensor([ 0.0993, -0.7085,  0.4544, -0.6551,  1.0850,  0.2790,  0.0604, -0.2557,
         0.0298, -1.0999], grad_fn=<SelectBackward>)

In [13]:
train(5)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Epoch 0 loss tensor(0.7843, grad_fn=<DivBackward0>)
Epoch 1 loss tensor(0.3485, grad_fn=<DivBackward0>)
Epoch 2 loss tensor(0.2509, grad_fn=<DivBackward0>)
Epoch 3 loss tensor(0.2115, grad_fn=<DivBackward0>)
Epoch 4 loss tensor(0.1807, grad_fn=<DivBackward0>)


In [14]:
network.fc2_mu[0]

tensor([-0.1245, -0.8598,  0.5264, -0.6987,  1.8208,  0.1970, -0.0564, -0.3261,
         0.0217, -1.2116], grad_fn=<SelectBackward>)

In [15]:
network.fc2_log_var[:10]

tensor([[ 0.6453, -0.4081, -0.8679,  0.4652, -0.2863,  0.6638, -1.7010,  0.4917,
         -0.9439, -0.7694],
        [ 0.5918, -0.7719, -0.9773,  0.9432,  0.5641, -0.7648, -1.3679,  0.0279,
         -1.0697, -0.8418],
        [-0.4528,  0.5255,  0.1406,  0.7993,  0.2916, -0.6659,  0.7117,  1.7143,
         -0.5608, -0.0209],
        [ 1.2192, -2.1407, -0.1426,  0.5612,  0.4068, -0.9461, -1.3702,  0.6492,
          0.9687,  0.7044],
        [ 1.3438, -1.0579,  0.8754, -0.7553, -1.3512,  0.7080,  0.3267,  0.0069,
         -0.1236,  0.2380],
        [ 0.4323, -0.8001, -0.5390,  0.1284,  0.4230,  0.4783, -0.1827,  0.7006,
          1.3789,  0.3679],
        [ 0.3035, -0.4362, -1.7527, -0.6813, -0.0329, -1.2927, -0.6030, -0.3196,
          0.4540, -0.0481],
        [-0.1625,  0.5907, -0.2997,  1.2326, -2.0085, -1.2568, -0.8514, -1.4775,
          1.4476, -2.3986],
        [ 0.4222,  0.1706,  0.4013,  0.9086, -1.8218, -0.4644, -1.3920,  0.9575,
         -0.6742, -1.1100],
        [-0.5251, -

### Testing

In [16]:
test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('../files/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=True)

In [17]:
test_losses = []
def test():
  network.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      output, _, _ = network(data, train=False)
      test_loss += F.cross_entropy(output, target, reduction="sum").item()
      pred = output.data.max(1, keepdim=True)[1]
      correct += pred.eq(target.data.view_as(pred)).sum()
  test_loss /= len(test_loader.dataset)
  test_losses.append(test_loss)
  print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

In [18]:
test()


Test set: Avg. loss: 0.0739, Accuracy: 9795/10000 (98%)

