# There and Back Again: A Collaborative Learning Demonstration on MNIST and rotated-MNIST

## The Basic Ingredients

- Loading the dataset
- Visualizing the images
- Defining the model
- Running the training loop
- Evaluating the model

## Loading the dataset

In [97]:
import torch
import torchvision
import torchvision.transforms as tf

transform = tf.Compose([
    tf.ToTensor(),
    tf.Normalize(0.5, 0.5),
])

rotated_transform = tf.Compose([
    tf.ToTensor(),
    tf.Normalize(0.5, 0.5),
    tf.RandomRotation(45)
])

batch_size = 16

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size = batch_size, shuffle=True)
train_iter = iter(trainloader)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size = batch_size, shuffle=True)
test_iter = iter(testloader)


rotated_trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=rotated_transform)
rotated_trainloader = torch.utils.data.DataLoader(rotated_trainset, batch_size = batch_size, shuffle=True)
rotated_train_iter = iter(rotated_trainloader)

rotated_testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=rotated_transform)
rotated_testloader = torch.utils.data.DataLoader(rotated_testset, batch_size = batch_size, shuffle=True)
rotated_test_iter = iter(rotated_testloader)

## Visualizing the dataset

In [75]:
import plotly.express as px

def imshow(im):
    if torch.min(im) < 0:
        im = im - torch.min(im)
        im = im / torch.max(im)


    ndims = len(im.shape)
    im = torch.squeeze(im)
    if len(im.shape) == ndims:
        im = im.permute(1,2,0)
    
    fig = px.imshow(im)

    fig.update_layout(margin={'t': 0, 'r': 0, 'b': 0, 'l': 0})

    if im.shape[0] == im.shape[1]:
        fig.update_layout(width=400, height=400)
    fig.update_xaxes(showticklabels=False).update_yaxes(showticklabels=False)
    fig.show()

batch, labels = next(rotated_train_iter)
N = 16
imshow(torchvision.utils.make_grid(batch[:N, :, :, :], nrow = 4, normalize=True))
print([l.item() for l in labels])

[8, 7, 3, 5, 1, 1, 7, 7, 9, 4, 3, 2, 6, 2, 4, 8]


## Defining the Model

In [99]:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module): # Input size = 32
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 3, 5)
        self.conv2 = nn.Conv2d(3, 6, 5)
        self.max_pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(6 * 4 * 4, 80)
        self.fc2 = nn.Linear(80, 40)
        self.fc3 = nn.Linear(40, 10)

    def forward(self, x):
        x = self.max_pool(F.relu(self.conv1(x)))
        x = self.max_pool(F.relu(self.conv2(x)))

        x = torch.flatten(x, 1)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x 
        

net = Net()

## Initial Model Training

In [123]:
def train(net, optimizer, dataloader, num_epochs=2):
    loss_fn = torch.nn.CrossEntropyLoss()

    for e in range(num_epochs):
        cumulative_loss = 0
        loss_window = len(dataloader) // 10
        for i, (batch, labels) in enumerate(dataloader, 1):
            optimizer.zero_grad()

            preds = net(batch)
            loss = loss_fn(preds, labels)
            loss.backward()
            optimizer.step()

            cumulative_loss += loss

            if i % loss_window == 0:
                print(f"Loss at epoch {e+1}, batch {i}: {cumulative_loss / loss_window :.2f}")
                cumulative_loss = 0

In [100]:
optimizer = torch.optim.Adam(net.parameters(), lr = 0.001)
num_epochs = 2

train(net, optimizer, trainloader, num_epochs)

Loss at epoch 1, batch 375: 0.99
Loss at epoch 1, batch 750: 0.35
Loss at epoch 1, batch 1125: 0.26
Loss at epoch 1, batch 1500: 0.22
Loss at epoch 1, batch 1875: 0.20
Loss at epoch 1, batch 2250: 0.19
Loss at epoch 1, batch 2625: 0.18
Loss at epoch 1, batch 3000: 0.15
Loss at epoch 1, batch 3375: 0.14
Loss at epoch 1, batch 3750: 0.14
Loss at epoch 2, batch 375: 0.12
Loss at epoch 2, batch 750: 0.13
Loss at epoch 2, batch 1125: 0.11
Loss at epoch 2, batch 1500: 0.11
Loss at epoch 2, batch 1875: 0.11
Loss at epoch 2, batch 2250: 0.11
Loss at epoch 2, batch 2625: 0.09
Loss at epoch 2, batch 3000: 0.10
Loss at epoch 2, batch 3375: 0.10
Loss at epoch 2, batch 3750: 0.10


In [177]:
torch.onnx.export(net, batch, "mnist_test_pretrained.onnx", do_constant_folding=False)

## Doing some inference

In [20]:
import numpy as np

images, labels = next(test_iter)

imshow(torchvision.utils.make_grid(images))

print(list(torch.argmax(net(images), axis=1).numpy()))
print([l.item() for l in labels])

[0, 0, 3, 2, 6, 5, 9, 8, 1, 0, 6, 1, 1, 3, 2, 3]
[0, 0, 3, 2, 6, 5, 9, 8, 1, 9, 6, 1, 1, 3, 2, 3]


## Evaluating the Model

In [114]:
def evaluate_model(net, dataloader, class_specific_performance=False):
    if not class_specific_performance:
        correct, total = (0, 0)
    else:        
        correct = {digit: 0 for digit in range(10)}
        total = {digit: 0 for digit in range(10)}

    with torch.no_grad():
        for (batch, labels) in dataloader:
            outputs = net(batch)
            preds = torch.argmax(outputs, 1)

            if not class_specific_performance:
                total += labels.shape[0]
                correct += torch.sum(preds == labels).item()
            else:
                for label, pred in zip(labels, preds):                
                    if pred == label:
                        correct[label.item()] += 1
                    total[label.item()] += 1
    return correct, total

In [116]:
correct, total = evaluate_model(net, testloader)

print(f"The model classified {correct} digits correctly, out of {total} total, for an accuracy of {correct * 100 / total :.2f}%")

The model classified 9705 digits correctly, out of 10000 total, for an accuracy of 97.05%


## But how does it perform in the rotated context?

In [119]:
correct, total = evaluate_model(net, rotated_testloader, class_specific_performance=True)
c = sum(correct.values())
t = sum(total.values())
print(f"The model classified {c} rotated digits correctly, out of {t} total, for an accuracy of {c * 100 / t :.2f}%")

The model classified 7968 rotated digits correctly, out of 10000 total, for an accuracy of 79.68%


Open Question: Why does this change on each run?

In [120]:
digits, counts = list(zip(*correct.items()))
fig = px.bar(x = digits, y=np.array(counts)/1000, labels={'x': "digit", 'y': "Classification Success Rate"}, range_y=(0, 1))
fig.update_layout(yaxis={"tickformat": ",.0%"}, xaxis={"tickmode": "linear"})
fig.show()

## Now let's train the model again, this time using rotated data, and store the updates efficiently

In [178]:
import onnx
from onnx2pytorch import ConvertModel

onnx_model = onnx.load("mnist_test_pretrained.onnx")
imported_pytorch_model = ConvertModel(onnx_model, experimental=True)


Using experimental implementation that allows 'batch_size > 1'.Batchnorm layers could potentially produce false outputs.



Annoyingly, there is no native Pytorch implementation to import an ONNX model. However, I found an external library that attempts to do this, though it produces some strange results. Specifically, it adds several parameters to the model's state dict that I don't understand: they concern the weight and bias of the ReLUs, which I don't understand, since ReLUs have neither of those things.

In [112]:
print(imported_pytorch_model.state_dict().keys())
print("\n ----- \n")
print(net.state_dict().keys())

odict_keys(['_initializer_conv1_weight', '_initializer_conv1_bias', '_initializer_conv2_weight', '_initializer_conv2_bias', '_initializer_fc1_weight', '_initializer_fc1_bias', '_initializer_fc2_weight', '_initializer_fc2_bias', '_initializer_fc3_weight', '_initializer_fc3_bias', 'Conv_onnx::Relu_11.weight', 'Conv_onnx::Relu_11.bias', 'Conv_onnx::Relu_14.weight', 'Conv_onnx::Relu_14.bias', 'Gemm_onnx::Relu_18.weight', 'Gemm_onnx::Relu_18.bias', 'Gemm_onnx::Relu_20.weight', 'Gemm_onnx::Relu_20.bias', 'Gemm_22.weight', 'Gemm_22.bias'])

 ----- 

odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])


Nevertheless, we persist. Let's evaluate the imported model's performance and verify that it has similar performance to the original.

In [122]:
correct_normal, total_normal = evaluate_model(imported_pytorch_model, testloader)
correct_rotated, total_rotated = evaluate_model(imported_pytorch_model, rotated_testloader)

print(f"The imported model classified {correct_normal} digits correctly, out of {total_normal} total, for an accuracy of {correct_normal * 100 / total_normal :.2f}%")
print(f"The imported model classified {correct_rotated} rotated digits correctly, out of {total_rotated} total, for an accuracy of {correct_rotated * 100 / total_rotated :.2f}%")


Note that order of the arguments: ceil_mode and return_indices will changeto match the args list in nn.MaxPool2d in a future release.



The imported model classified 9705 digits correctly, out of 10000 total, for an accuracy of 97.05%
The imported model classified 8006 rotated digits correctly, out of 10000 total, for an accuracy of 80.06%


It does - great!

Now, there are two approaches we could take to non-dense fine tuning. To clarify terminology, we distinguish between a _parameter_ and a _parameter group_. A _parameter_ is a single scalar value that is trained and used by the model during inference to produce a result. When people refer to the size of large models (e.g. T5 has 11B parameters), they are referring to these scalar values.

A _parameter group_ is collection of parameters that has some conceptual and practical meaning as a unit of computation. Typical examples are the bias vector and the weight matrix of a linear layer, or the kernel of a 2D convolution. 

A typical training run (like the one we did at the start of this notebook) trains all possible parameters at every training step: every parameter within every parameter group. If you wanted to do less exhaustive training (and efficiently store the corresponding update), there are two natural approaches.

- Select some subset of the _parameter groups_, and only train them.
- Within each parameter group, select some subset of _parameters_, and only train them.

(You could, of course, combine both approaches).

To keep this test run implementation simple, we will take the second approach and train only the first convolutional layer, and the final dense layer.

## Unresolved Questions

How can a model's state dict have more keys that the model has parameters? Are the parameters not just the values in the state dict?

In [179]:
l = list(imported_pytorch_model.named_parameters())
[x[0] for x in l]

['Conv_onnx::Relu_11.weight',
 'Conv_onnx::Relu_11.bias',
 'Conv_onnx::Relu_14.weight',
 'Conv_onnx::Relu_14.bias',
 'Gemm_onnx::Relu_18.weight',
 'Gemm_onnx::Relu_18.bias',
 'Gemm_onnx::Relu_20.weight',
 'Gemm_onnx::Relu_20.bias',
 'Gemm_22.weight',
 'Gemm_22.bias']

In [180]:
[x.name for x in onnx_model.graph.initializer]

['conv1.weight',
 'conv1.bias',
 'conv2.weight',
 'conv2.bias',
 'fc1.weight',
 'fc1.bias',
 'fc2.weight',
 'fc2.bias',
 'fc3.weight',
 'fc3.bias']

In [181]:
net.state_dict().keys()

odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])

In [127]:
l = list(imported_pytorch_model.parameters())
len(l)

10

In [128]:
l2 = list(net.parameters())
len(l2)

10

In [129]:
len(net.state_dict().keys())

10

In [130]:
len(imported_pytorch_model.state_dict().keys())

20