### Import required libraries

In [None]:
import matplotlib.pyplot as plt
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms


# function to count number of parameters
def get_n_params(model):
    """Return the total number of scalar parameters held by ``model``.

    Args:
        model: object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).

    Returns:
        int: sum of ``nelement()`` over every tensor from ``model.parameters()``;
        0 when the model has no parameters.
    """
    # A generator sum replaces the manual accumulator, whose previous name
    # (`np`) collided with the conventional alias for numpy.
    return sum(param.nelement() for param in model.parameters())


# Dark plotting theme: apply the "dark_background" and "bmh" style sheets in
# that order, then force black axes/figure backgrounds and a large default canvas.
plt.style.use(["dark_background", "bmh"])
plt.rc("axes", facecolor="k")
plt.rc("figure", facecolor="k", figsize=(10, 10), dpi=100)

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

In [None]:
# NOTE: `model_fnn` is only created in the fully connected training cell further
# down, so evaluating `model_fnn.parameters()` here raised a NameError on a fresh
# kernel (Restart & Run All). Inspect the parameters after that cell has run.

### The Dataset

In [None]:
# Root folder handed to torchvision's MNIST dataset below (a relative path).
data_dir = "../../data"

In [None]:
input_size = 28 * 28  # images are 28x28 pixels
output_size = 10  # there are 10 classes

# Both splits share one preprocessing pipeline: convert to a tensor, then
# standardise with the mean and std of the MNIST training set.
mnist_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Training batches are reshuffled on every pass over the data.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(data_dir, train=True, download=True, transform=mnist_transform),
    batch_size=64,
    shuffle=True,
)

# The evaluation metrics (summed loss, number of correct predictions) do not
# depend on sample order, so the test split is not shuffled.
# download=True keeps this loader usable when the files are not on disk yet.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(data_dir, train=False, download=True, transform=mnist_transform),
    batch_size=1000,
    shuffle=False,
)

In [None]:
# Preview the first ten training samples on a 2x5 grid, axes hidden.
plt.figure(figsize=(16, 6))
for idx in range(10):
    image, _label = train_loader.dataset[idx]
    plt.subplot(2, 5, idx + 1)
    plt.imshow(image.squeeze().numpy())
    plt.axis("off")

### Modelling

To compare the differences we are going to create two models:
- fully connected model only using linear layers
- CNN model which has convolutions and max pooling 

Cool animations:
- https://github.com/vdumoulin/conv_arithmetic
- https://www.reddit.com/r/manim/comments/ge19xj/a_simple_animation_to_show_how_max_pooling_works/

In [None]:
class FC2Layer(nn.Module):
    """Fully connected classifier with two hidden layers.

    Each sample is flattened to ``input_size`` features and passed through
    Linear -> ReLU -> Linear -> ReLU -> Linear, followed by a log-softmax
    over the class dimension.
    """

    def __init__(self, input_size, n_hidden, output_size):
        """Build the layer stack.

        Args:
            input_size: number of features per flattened sample.
            n_hidden: width of both hidden layers.
            output_size: number of output classes.
        """
        super().__init__()
        self.input_size = input_size

        layers = [
            nn.Linear(input_size, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, output_size),
            nn.LogSoftmax(dim=1),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, input_size)`` and return per-class log-probabilities."""
        flat = x.view(-1, self.input_size)
        return self.network(flat)


class CNN(nn.Module):
    """Small ConvNet: two (5x5 conv -> ReLU -> 2x2 max-pool) stages, then two linear layers.

    The first linear layer expects ``n_feature * 4 * 4`` inputs, which is what
    the two conv/pool stages produce for single-channel 28x28 images
    (28 -> 24 -> 12 -> 8 -> 4 per spatial side).
    """

    def __init__(self, input_size, n_feature, output_size):
        """Create the layers.

        Args:
            input_size: accepted for signature parity with ``FC2Layer``; not
                used by the convolutional layers.
            n_feature: number of feature maps produced by each convolution.
            output_size: number of output classes.
        """
        super().__init__()
        self.n_feature = n_feature
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=n_feature, kernel_size=5)
        self.conv2 = nn.Conv2d(n_feature, n_feature, kernel_size=5)
        self.fc1 = nn.Linear(n_feature * 4 * 4, 50)
        # Honour the requested class count (this was hard-coded to 10 before).
        self.fc2 = nn.Linear(50, output_size)

    def forward(self, x, verbose=False):
        """Return per-class log-probabilities of shape ``(N, output_size)``.

        Args:
            x: image batch of shape ``(N, 1, 28, 28)``.
            verbose: unused; kept so existing callers keep working.
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=2)
        x = F.max_pool2d(F.relu(self.conv2(x)), kernel_size=2)
        # Flatten everything except the batch dimension. Unlike view(-1, ...),
        # an unexpected spatial size now fails at fc1 instead of silently
        # changing the batch dimension.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)

### Train and Test

We are using PyTorch as our neural network framework. Before training we need to set up the data, the model, the loss function and the optimizer.

Afterwards, each epoch iterates over the training batches, and every batch goes through 5 steps:
1. get the data --> every data instance should be a pair of input + label/target
2. zero the gradients for every batch --> PyTorch accumulates the gradients on subsequent backward passes [^1]
3. make predictions --> pass the input data to the model and get the predictions
4. compute the loss and gradients --> with the new predictions we compute the loss and backpropagate to get the gradients
5. adjust the weights --> once the gradients are computed, the optimizer uses them to update the weights

[Training Pytorch documentation](https://pytorch.org/tutorials/beginner/introyt/trainingyt.html)

[^1]: https://stackoverflow.com/questions/48001598/why-do-we-need-to-call-zero-grad-in-pytorch 

In [None]:
def train_one_epoch(model, optimizer, loss_fn):
    """Run one pass over ``train_loader`` and update ``model`` in place.

    Reads the notebook-level globals ``train_loader``, ``device`` and ``epoch``
    (``epoch`` is only used in the progress message).

    Args:
        model: network to train; called as ``model(data)``.
        optimizer: optimizer whose ``zero_grad()`` / ``step()`` are called per batch.
        loss_fn: callable ``loss_fn(output, target)`` returning a scalar tensor.

    Returns:
        None. Progress is printed every 100 batches.
    """
    # The evaluation helpers switch the model to eval mode; switch back so
    # that further training never runs in eval mode.
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # get data
        data, target = data.to(device), target.to(device)
        # zero the gradients
        optimizer.zero_grad()
        # get predictions
        output = model(data)
        # compute the loss and its gradients
        loss = loss_fn(output, target)
        loss.backward()
        # adjust the weights
        optimizer.step()

        if batch_idx % 100 == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                )
            )

In [None]:
# Test accuracies in percent; `test` and `test_perm` append one entry per call,
# so the order of entries follows the order in which the cells are executed.
accuracy_list = []

In [None]:
def test(model):
    """Evaluate ``model`` on ``test_loader`` and report loss and accuracy.

    Reads the notebook-level globals ``test_loader`` and ``device`` and
    appends the accuracy (in percent) to the global ``accuracy_list``.

    Args:
        model: network returning per-class log-probabilities for a batch.

    Returns:
        None. The average loss and the accuracy are printed.
    """
    model.eval()
    test_loss = 0
    correct = 0
    # Evaluation needs no gradients; disabling autograd saves memory and time.
    with torch.no_grad():
        for data, target in test_loader:
            # send to device
            data, target = data.to(device), target.to(device)

            output = model(data)

            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction="sum").item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    n_samples = len(test_loader.dataset)
    test_loss /= n_samples
    accuracy = 100.0 * correct / n_samples
    accuracy_list.append(accuracy)
    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss, correct, n_samples, accuracy
        )
    )

In [None]:
# NOTE: `model_fnn` is not defined until the fully connected training cell
# below, so referencing `model_fnn.parameters` here raised a NameError on a
# fresh kernel (Restart & Run All). Inspect the model after that cell has run.

#### Train/Test Fully connected

In [None]:
n_hidden = 8  # number of hidden units

# Build the fully connected model and place it on the selected device.
model_fnn = FC2Layer(input_size, n_hidden, output_size).to(device)

# Negative log-likelihood pairs with the log-softmax output of the model.
loss_fn = torch.nn.NLLLoss()
optimizer = optim.SGD(model_fnn.parameters(), lr=0.01, momentum=0.5)

print(f"Number of parameters: {get_n_params(model_fnn)}")

# A single training epoch, followed by one evaluation on the test split.
for epoch in range(1):
    train_one_epoch(model_fnn, optimizer, loss_fn)

test(model_fnn)

#### Train/Test CNN

In [None]:
# Training settings
n_features = 6  # number of feature maps

# Build the ConvNet and place it on the selected device.
model_cnn = CNN(input_size, n_features, output_size).to(device)

# Negative log-likelihood pairs with the log-softmax output of the model.
loss_fn = torch.nn.NLLLoss()
# Plain SGD with momentum, same settings as for the fully connected model.
optimizer = optim.SGD(model_cnn.parameters(), lr=0.01, momentum=0.5)

print(f"Number of parameters: {get_n_params(model_cnn)}")

# A single training epoch, followed by one evaluation on the test split.
for epoch in range(1):
    train_one_epoch(model_cnn, optimizer, loss_fn)

test(model_cnn)

#### Some experiments
Let's try to understand why the CNN performs better than the fully connected network.

In [None]:
# One fixed random ordering of the 784 pixel positions, reused for every image.
perm = torch.randperm(784)
plt.figure(figsize=(16, 12))
for idx in range(10):
    image, _label = train_loader.dataset[idx]
    # flatten, reorder the pixels with `perm`, then restore the image shape
    image_perm = image.view(-1, 28 * 28).clone()[:, perm].view(-1, 1, 28, 28)

    # rows 1-2 of the 4x5 grid: original digits
    plt.subplot(4, 5, idx + 1)
    plt.imshow(image.squeeze().numpy())
    plt.axis("off")

    # rows 3-4: the same digits with permuted pixels
    plt.subplot(4, 5, idx + 11)
    plt.imshow(image_perm.squeeze().numpy())
    plt.axis("off")

In [None]:
# train the model with permutation
# A newly drawn random ordering of the 784 pixel positions. It overwrites any
# `perm` assigned earlier and is the one read (as a global) by
# `train_one_epoch_perm` and `test_perm` below.
perm = torch.randperm(784)


def train_one_epoch_perm(model, optimizer, loss_fn):
    """Run one training pass over ``train_loader`` with permuted pixels.

    Every image is flattened to 784 values, reordered with the global ``perm``
    and reshaped back to ``(N, 1, 28, 28)`` before it is fed to the model.
    Also reads the notebook-level globals ``train_loader``, ``device`` and
    ``epoch`` (``epoch`` is only used in the progress message).

    Args:
        model: network to train; called as ``model(data)``.
        optimizer: optimizer whose ``zero_grad()`` / ``step()`` are called per batch.
        loss_fn: callable ``loss_fn(output, target)`` returning a scalar tensor.

    Returns:
        None. Progress is printed every 100 batches.
    """
    # The evaluation helpers switch the model to eval mode; switch back so
    # that further training never runs in eval mode.
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # get data
        data, target = data.to(device), target.to(device)
        # permute pixels: flatten, reorder with `perm`, restore the image shape
        data = data.view(-1, 28 * 28)
        data = data[:, perm]
        data = data.view(-1, 1, 28, 28)

        # zero the gradients
        optimizer.zero_grad()
        # get predictions
        output = model(data)
        # compute the loss and its gradients
        loss = loss_fn(output, target)
        loss.backward()
        # adjust the weights
        optimizer.step()

        if batch_idx % 100 == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                )
            )


# test the model with permutation
def test_perm(model):
    """Evaluate ``model`` on ``test_loader`` with permuted pixels.

    Every image is flattened to 784 values, reordered with the global ``perm``
    and reshaped back to ``(N, 1, 28, 28)`` before it is fed to the model.
    Also reads the globals ``test_loader`` and ``device`` and appends the
    accuracy (in percent) to the global ``accuracy_list``.

    Args:
        model: network returning per-class log-probabilities for a batch.

    Returns:
        None. The average loss and the accuracy are printed.
    """
    model.eval()
    test_loss = 0
    correct = 0
    # Evaluation needs no gradients; disabling autograd saves memory and time.
    with torch.no_grad():
        for data, target in test_loader:
            # send to device
            data, target = data.to(device), target.to(device)

            # permute pixels
            data = data.view(-1, 28 * 28)
            data = data[:, perm]
            data = data.view(-1, 1, 28, 28)

            output = model(data)

            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction="sum").item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    n_samples = len(test_loader.dataset)
    test_loss /= n_samples
    accuracy = 100.0 * correct / n_samples
    accuracy_list.append(accuracy)

    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss, correct, n_samples, accuracy
        )
    )

#### Train/Test the fully connected network and the CNN with permutation

In [None]:
n_hidden = 8  # number of hidden units

model_fnn_perm = FC2Layer(input_size, n_hidden, output_size)  # define the model
model_fnn_perm.to(device)  # move it to the correct device

loss_fn = torch.nn.NLLLoss()  # select the loss according to our problem
optimizer = optim.SGD(model_fnn_perm.parameters(), lr=0.01, momentum=0.5)

print(f"Number of parameters: {get_n_params(model_fnn_perm)}")

for epoch in range(0, 1):
    # Train the model the optimizer is bound to. Passing `model_fnn` here
    # meant the optimizer never updated any weights, because the gradients
    # landed on a model it does not own.
    train_one_epoch_perm(model_fnn_perm, optimizer, loss_fn)

# Evaluate the model that was actually trained on permuted pixels.
test_perm(model_fnn_perm)

In [None]:
# Training settings
n_features = 6  # number of feature maps

# A fresh ConvNet for the permuted-pixel experiment. This rebinds `model_cnn`,
# so any previously trained CNN under that name is no longer reachable.
model_cnn = CNN(input_size, n_features, output_size)
model_cnn.to(device)

# Define the loss here, like the other training cells do, instead of relying
# on a `loss_fn` left over from an earlier cell.
loss_fn = torch.nn.NLLLoss()
optimizer = optim.SGD(model_cnn.parameters(), lr=0.01, momentum=0.5)
print(f"Number of parameters: {get_n_params(model_cnn)}")

for epoch in range(0, 1):
    train_one_epoch_perm(model_cnn, optimizer, loss_fn)

test_perm(model_cnn)

In [None]:
# `accuracy_list` is filled in execution order: fully connected net on images,
# CNN on images, fully connected net on permuted pixels, CNN on permuted pixels.
# The labels must follow that order (the last two were swapped before).
bar_labels = ("NN image", "CNN image", "NN scrambled", "CNN scrambled")
bar_values = accuracy_list[: len(bar_labels)]

plt.bar(bar_labels, bar_values, width=0.4)
# Start the y axis just below the weakest model so the differences are visible.
plt.ylim((min(bar_values) - 5, 96))
plt.ylabel("Accuracy [%]");

##### The ConvNet's performance drops when we permute the pixels, but the Fully-Connected Network's performance stays the same. WHY?