In [1]:
# Import libraries

import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
import pathlib
import sys
import torch
import torchvision

from io import BytesIO
from urllib.request import urlopen
from zipfile import ZipFile

# PyTorch imports
import torch.nn.functional as F

from torch import nn
from torch.utils.data import RandomSampler
from torchvision import datasets, transforms

# Check what device to use: prefer CUDA, then Apple's Metal backend, then the CPU.
# `torch.backends.mps` only exists from PyTorch 1.12 onwards, so it is looked up
# defensively; older installs fall through to the CPU instead of raising
# AttributeError.
if torch.cuda.is_available():
    device = 'cuda'
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = 'mps'  # Apple Silicon GPU
else:
    device = 'cpu'

print(f'PyTorch version: {torch.__version__}')
print(f'Using {device = }')

PyTorch version: 1.13.0.dev20220922
Using device = 'mps'


## Model architecture
We have to assume that the model architecture is provided; we will train a separate model with the same architecture for comparison.

In this analysis, we will just use the architecture to train two similar models to compare the weights, to determine if the approach is feasible.

In [3]:
# Model architecture
class CIFAR10Net(nn.Module):
    """CNN classifier: three conv-conv-pool stages followed by a three-layer MLP.

    Architecture from
    https://www.kaggle.com/code/shadabhussain/cifar-10-cnn-using-pytorch

    The shape comments assume 3 x 32 x 32 inputs, which is what the
    ``256 * 4 * 4`` input width of ``fc1`` requires after three 2x2 poolings.
    ``forward`` returns the raw 10-way outputs of ``fc3`` (no softmax applied).
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 3 -> 32 -> 64 channels
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # -> 64 x 16 x 16

        # Stage 2: 64 -> 128 -> 128 channels
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # -> 128 x 8 x 8

        # Stage 3: 128 -> 256 -> 256 channels
        self.conv5 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.conv6 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)  # -> 256 x 4 x 4

        # Classifier head: 4096 -> 1024 -> 512 -> 10
        self.fc1 = nn.Linear(in_features=256 * 4 * 4, out_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=512)
        self.fc3 = nn.Linear(in_features=512, out_features=10)

    def forward(self, x):
        # Every stage is conv -> ReLU -> conv -> ReLU -> max-pool.
        stages = (
            (self.conv1, self.conv2, self.pool1),
            (self.conv3, self.conv4, self.pool2),
            (self.conv5, self.conv6, self.pool3),
        )
        for first_conv, second_conv, pool in stages:
            x = F.relu(first_conv(x))
            x = F.relu(second_conv(x))
            x = pool(x)

        # Collapse everything except the batch dimension before the dense layers.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

In [4]:
# Common functions
def load_model(model_class, name):
    """Build a ``model_class`` instance and fill it with the weights stored at ``name``.

    Args:
        model_class: Callable invoked with no arguments to create the model.
        name: Checkpoint location handed to ``torch.load``; it must hold a
            state dict accepted by the model's ``load_state_dict``.

    Returns:
        The newly created model with the loaded state applied.
    """
    net = model_class()
    # Remap stored tensors onto the notebook-wide ``device`` while loading.
    target_device = torch.device(device)
    weights = torch.load(name, map_location=target_device)
    net.load_state_dict(weights)
    return net

def download_and_unzip(url, extract_to='.'):
    """Download a ZIP archive from ``url`` and extract it into ``extract_to``.

    The complete response body is read into memory before extraction.

    Args:
        url: Location of the archive, in any form ``urlopen`` accepts.
        extract_to: Directory receiving the extracted files (default: the
            current working directory).

    Raises:
        urllib.error.URLError: If the archive cannot be retrieved.
        zipfile.BadZipFile: If the downloaded payload is not a ZIP archive.
    """
    # NOTE: the archive is extracted as-is; only point this at trusted sources.
    # Release the connection as soon as the payload has been read.
    with urlopen(url) as http_response:
        payload = BytesIO(http_response.read())

    # Close the archive handle once extraction is done.
    with ZipFile(payload) as archive:
        archive.extractall(path=extract_to)

def test(model, dataloader, loss_fn, device):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.to(device)
    model.eval()
    loss, correct = 0.0, 0
    
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)

            pred = model(x)
            loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.int).sum().item()
    
    loss /= num_batches
    accuracy = correct / size
    print('Test Result: Accuracy @ {:.2f}%, Avg loss @ {:.4f}\n'.format(100 * accuracy, loss))

    return loss, accuracy

In [5]:
# Fetch and extract the TriggerImg archive unless a 'TriggerImg' path is already present
if os.path.exists('TriggerImg'):
    print('Folder exists, skipping download')
else:
    TriggerImg_URL = 'https://bit.ly/cs612-TriggerImg'
    download_and_unzip(TriggerImg_URL)

In [6]:
class SaveBestModel:
    """
    Checkpoint helper that tracks the lowest validation loss seen so far.

    Call the instance with the current validation loss. When that loss is
    strictly lower than the lowest one recorded, ``save_model(model, save_file)``
    is invoked and the recorded minimum is updated; otherwise nothing happens.
    """
    def __init__(self, model, best_valid_loss=float('inf'), save_file='model/best_model.pth'):
        self.save_file = save_file
        self.best_valid_loss = best_valid_loss
        self.model = model

    def __call__(self, current_valid_loss):
        improved = current_valid_loss < self.best_valid_loss
        if not improved:
            return

        print(f'Saving best model with val loss = {current_valid_loss:.3f}')
        save_model(self.model, self.save_file)
        self.best_valid_loss = current_valid_loss

In [7]:
# The only preprocessing applied to the images is conversion to tensors.
transform = transforms.ToTensor()

# Keyword arguments forwarded to the training / test DataLoaders.
train_kwargs = {'batch_size': 100, 'shuffle': True}
test_kwargs = {'batch_size': 1000}

# CIFAR-10 is stored under ./data (download=True fetches it when needed).
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
print('Orig. training data size:', len(trainset), '; Orig. test data size', len(testset))

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


22.3%

In [None]:
# Train CIFAR10Net and checkpoint the epoch with the lowest validation loss.
MODELNAME = 'CIFAR10_10BD'
# NOTE(review): `trainset_new` and `train` are not defined in the cells shown
# above; confirm the cells that create them run before this one.
train_loader = torch.utils.data.DataLoader(trainset_new, **train_kwargs)
test_loader = torch.utils.data.DataLoader(testset, **test_kwargs)

model = CIFAR10Net().to(device)

# Use the fully qualified `torch.optim` so this cell does not depend on a
# separate `optim` import.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# One loss module is enough; it is shared by training and evaluation.
loss_fn = nn.CrossEntropyLoss()
num_of_epochs = 30

# initialise save_best_model, making sure the checkpoint directory exists
# before the first save is attempted
checkpoint_path = os.path.join('model', f'best_model_{MODELNAME}.pt')
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
save_best_model = SaveBestModel(model=model, save_file=checkpoint_path)

for epoch in range(num_of_epochs):
    print('\n------------- Epoch {} -------------\n'.format(epoch))
    train(model, train_loader, loss_fn, optimizer, device)
    # test() returns (loss, accuracy); only the scalar loss drives checkpointing.
    val_loss, val_accuracy = test(model, test_loader, loss_fn, device)

    save_best_model(val_loss)
