# ResNet for MNIST in PyTorch

In [None]:
from torchvision.models.resnet import ResNet, BasicBlock
from torchvision.datasets import MNIST
from tqdm.autonotebook import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import inspect
import time

import torch

from torch import nn, optim
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from torch.utils.data import DataLoader
from copy import copy, deepcopy
import numpy as np


In [None]:
# Colab-only: mount Google Drive at /content/gdrive so that paths under
# "/content/gdrive/My Drive/" can be used by the rest of the notebook.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
def save_last_model(input_model):
  """Write ``input_model`` to the "last epoch" checkpoint on the mounted Drive.

  The object is passed as-is to ``torch.save`` and stored at
  ``/content/gdrive/My Drive/resnet_adadelta_last.pkl``; nothing is returned.
  """
  checkpoint_name = 'resnet_adadelta_last.pkl'
  destination = f"/content/gdrive/My Drive/{checkpoint_name}"
  torch.save(input_model, destination)

In [None]:
class MnistResNet(ResNet):
    """ResNet built from ``BasicBlock`` with layout ``[2, 2, 2, 2]`` for 1-channel, 10-class input.

    The stock first convolution (which expects 3 input channels) is replaced
    by an otherwise identical layer that accepts a single channel.
    """

    def __init__(self):
        super().__init__(BasicBlock, [2, 2, 2, 2], num_classes=10)
        # Same kernel/stride/padding as the layer it replaces, but 1 input channel.
        self.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

    def forward(self, x):
        """Return the raw class scores (logits) for the batch ``x``.

        No softmax is applied here: ``nn.CrossEntropyLoss`` expects
        unnormalized logits and performs log-softmax internally, so
        normalizing in ``forward`` would squash the scores into [0, 1] and
        flatten the gradients. ``argmax`` over the logits selects the same
        class as ``argmax`` over the softmax probabilities.
        """
        return super().forward(x)

    def predict_proba(self, x):
        """Return class probabilities: softmax of the logits over the last dimension."""
        return torch.softmax(self.forward(x), dim=-1)


In [None]:
# def get_data_loaders(train_batch_size, val_batch_size):
#     mnist = MNIST(download=True, train=True, root=".").train_data.float()
    
#     # add gaussian noise maybe
#     data_transform = Compose([ Resize((224, 224)),ToTensor(), Normalize((mnist.mean()/255,), (mnist.std()/255,))])

#     train_loader = DataLoader(MNIST(download=True, root=".", transform=data_transform, train=True),
#                               batch_size=train_batch_size, shuffle=True)

#     val_loader = DataLoader(MNIST(download=False, root=".", transform=data_transform, train=False),
#                             batch_size=val_batch_size, shuffle=False)
#     return train_loader, val_loader
def getData(name='cifar10', train_bs=128, test_bs=1000):
    """Build the (train, test) ``DataLoader`` pair for a named dataset.

    Args:
        name: One of ``'svhn'``, ``'mnist'``, ``'emnist'``, ``'cifar10'``,
            ``'cifar100'`` or ``'tinyimagenet'``.
        train_bs: Batch size of the training loader (shuffled).
        test_bs: Batch size of the test loader (not shuffled).

    Returns:
        A ``(train_loader, test_loader)`` tuple.

    Raises:
        ValueError: If ``name`` is not one of the supported dataset names.
    """
    supported = ('svhn', 'mnist', 'emnist', 'cifar10', 'cifar100', 'tinyimagenet')
    # Fail early with a clear message instead of reaching the final `return`
    # with `train_loader` / `test_loader` never assigned.
    if name not in supported:
        raise ValueError(
            f"Unknown dataset name {name!r}; expected one of {', '.join(supported)}"
        )

    make_loader = torch.utils.data.DataLoader

    if name == 'svhn':
        to_tensor = transforms.Compose([transforms.ToTensor()])
        # Training uses the 'extra' split, evaluation the 'test' split.
        train_loader = make_loader(
            datasets.SVHN('../data', split='extra', download=True, transform=to_tensor),
            batch_size=train_bs, shuffle=True)
        test_loader = make_loader(
            datasets.SVHN('../data', split='test', download=True, transform=to_tensor),
            batch_size=test_bs, shuffle=False)

    elif name == 'mnist':
        preprocess = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        train_loader = make_loader(
            datasets.MNIST('../data', train=True, download=True, transform=preprocess),
            batch_size=train_bs, shuffle=True)
        # No download flag here: the training dataset above was created with
        # download=True against the same root.
        test_loader = make_loader(
            datasets.MNIST('../data', train=False, transform=preprocess),
            batch_size=test_bs, shuffle=False)

    elif name == 'emnist':
        preprocess = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1751,), (0.3267,)),
        ])
        train_loader = make_loader(
            datasets.EMNIST('../data', train=True, download=True, split='balanced',
                            transform=preprocess),
            batch_size=train_bs, shuffle=True)
        test_loader = make_loader(
            datasets.EMNIST('../data', train=False, split='balanced', transform=preprocess),
            batch_size=test_bs, shuffle=False)

    elif name in ('cifar10', 'cifar100'):
        # Both CIFAR variants share the same augmentation and the same
        # normalization constants; only the dataset class differs.
        dataset_cls = datasets.CIFAR10 if name == 'cifar10' else datasets.CIFAR100
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        trainset = dataset_cls(root='../data', train=True, download=True, transform=transform_train)
        train_loader = make_loader(trainset, batch_size=train_bs, shuffle=True)
        testset = dataset_cls(root='../data', train=False, download=False, transform=transform_test)
        test_loader = make_loader(testset, batch_size=test_bs, shuffle=False)

    else:  # name == 'tinyimagenet'
        normalize = transforms.Normalize(
            mean=[0.44785526394844055, 0.41693055629730225, 0.36942949891090393],
            std=[0.2928885519504547, 0.28230994939804077, 0.2889912724494934])
        # Read from image folders on disk; nothing is downloaded in this branch.
        train_dataset = datasets.ImageFolder(
            '../data/tiny-imagenet-200/train',
            transforms.Compose([
                transforms.RandomCrop(64, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_loader = make_loader(train_dataset, batch_size=train_bs, shuffle=True,
                                   num_workers=4, pin_memory=False)
        test_dataset = datasets.ImageFolder(
            '../data/tiny-imagenet-200/val',
            transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ]))
        test_loader = make_loader(test_dataset, batch_size=test_bs, shuffle=False)

    return train_loader, test_loader

In [None]:
def calculate_metric(metric_fn, true_y, pred_y):
    """Call ``metric_fn(true_y, pred_y)``, requesting macro averaging when supported.

    Args:
        metric_fn: Metric callable taking ``(true_y, pred_y)`` and, optionally,
            an ``average`` parameter.
        true_y: Ground-truth labels.
        pred_y: Predicted labels.

    Returns:
        Whatever ``metric_fn`` returns.
    """
    # inspect.signature sees keyword-only parameters and follows
    # functools.wraps-style decorators; getfullargspec(...).args does neither,
    # which silently skipped `average="macro"` for such metrics.
    try:
        parameters = inspect.signature(metric_fn).parameters
    except (TypeError, ValueError):
        # Signature not introspectable: fall back to the plain two-argument call.
        parameters = {}

    if "average" in parameters:
        return metric_fn(true_y, pred_y, average="macro")
    return metric_fn(true_y, pred_y)
    
def print_scores(p, r, f1, a, batch_size):
    """Print the mean of each score list, one right-aligned label per line.

    ``p``, ``r``, ``f1`` and ``a`` are sequences of per-batch precision,
    recall, F1 and accuracy values. Each is summed and divided by
    ``batch_size`` (the divisor the caller supplies) and shown with four
    decimals.
    """
    labelled = {"precision": p, "recall": r, "F1": f1, "accuracy": a}
    for label, values in labelled.items():
        mean_value = sum(values) / batch_size
        padded = label.rjust(14, ' ')
        print(f"\t{padded}: {mean_value:.4f}")

In [None]:
def exp_lr_scheduler(epoch, optimizer, strategy='normal', decay_eff=0.1, decayEpoch=()):
    """Multiply every param group's learning rate by ``decay_eff`` on chosen epochs.

    Args:
        epoch: Index of the current epoch.
        optimizer: Object exposing ``param_groups``, an iterable of dicts that
            each hold an ``'lr'`` entry.
        strategy: Only ``'normal'`` is supported.
        decay_eff: Factor each learning rate is multiplied by.
        decayEpoch: Collection of epoch indices at which the decay is applied.
            Defaults to an empty (immutable) tuple, i.e. never decay.

    Returns:
        The same ``optimizer`` object, modified in place when ``epoch`` is in
        ``decayEpoch``.

    Raises:
        ValueError: If ``strategy`` is not ``'normal'``.
    """
    if strategy != 'normal':
        raise ValueError(
            f"Unsupported learning-rate strategy {strategy!r}; expected 'normal'"
        )

    if epoch in decayEpoch:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= decay_eff
            # Reported inside the loop so each group's new rate is shown and
            # nothing is read from an unbound loop variable when there are no groups.
            print('New learning rate is: ', param_group['lr'])
    return optimizer

In [None]:
# Wall-clock start, used for the "Training time" line printed at the end.
start_ts = time.time()

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
epochs = 20

model = MnistResNet().to(device)
# train_loader, val_loader = get_data_loaders(256, 256)
train_loader, val_loader = getData(name='mnist', train_bs=128, test_bs=1000)


losses = []  # mean training loss of each epoch
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adadelta(model.parameters())  # no hyperparameters overridden

batches = len(train_loader)
val_batches = len(val_loader)
# keep best model
accuracies = []  # mean per-batch validation accuracy of each epoch
best_accuracy = 0
best_model = deepcopy(model)

# training loop + eval loop
for epoch in range(epochs):
    total_loss = 0
    progress = tqdm(enumerate(train_loader), desc="Loss: ", total=batches)
    model.train()
    # # lr decay
    # optimizer = exp_lr_scheduler(epoch, optimizer, decay_eff=0.1, decayEpoch=[15])

    for i, data in progress:
        X, y = data[0].to(device), data[1].to(device)

        model.zero_grad()
        outputs = model(X)
        loss = loss_function(outputs, y)

        loss.backward()
        optimizer.step()
        current_loss = loss.item()
        total_loss += current_loss
        # Show the running mean loss over the batches seen so far this epoch.
        progress.set_description("Loss: {:.4f}".format(total_loss/(i+1)))

    torch.cuda.empty_cache()

    val_losses = 0
    # Per-batch metric values for this epoch; later cells read `accuracy`.
    precision, recall, f1, accuracy = [], [], [], []

    model.eval()
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            X, y = data[0].to(device), data[1].to(device)
            outputs = model(X)
            # .item() keeps the running total a plain Python float instead of
            # accumulating a tensor on the compute device.
            val_losses += loss_function(outputs, y).item()

            predicted_classes = torch.max(outputs, 1)[1]

            # Copy to host once per batch rather than once per metric.
            y_cpu = y.cpu()
            predicted_cpu = predicted_classes.cpu()
            for acc, metric in zip((precision, recall, f1, accuracy),
                                   (precision_score, recall_score, f1_score, accuracy_score)):
                acc.append(
                    calculate_metric(metric, y_cpu, predicted_cpu)
                )

    current_model_accuracy = sum(accuracy)/val_batches
    accuracies.append(current_model_accuracy)
    # Snapshot the model whenever validation accuracy strictly improves.
    if current_model_accuracy > best_accuracy:
        best_model = deepcopy(model)
        best_accuracy = current_model_accuracy

    print(f"Epoch {epoch+1}/{epochs}, training loss: {total_loss/batches}, validation loss: {val_losses/val_batches}")
    print_scores(precision, recall, f1, accuracy, val_batches)
    losses.append(total_loss/batches)
    print('current_model_accuracy: ',current_model_accuracy)
    print('best_accuracy: ',best_accuracy)

# Persist both the final-epoch model and the best-accuracy snapshot.
save_last_model(model)
model_save_name = 'resnet_adadelta_best.pkl'
path = F"/content/gdrive/My Drive/{model_save_name}" 
torch.save(best_model, path)

print(losses)
print(f"Training time: {time.time()-start_ts}s")

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw
Processing...



Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 1/20, training loss: 1.5248583870401768, validation loss: 1.4958022832870483
	     precision: 0.9652
	        recall: 0.9647
	            F1: 0.9643
	      accuracy: 0.9647
current_model_accuracy:  0.9647
best_accuracy:  0.9647


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 2/20, training loss: 1.4879047982474125, validation loss: 1.4998568296432495
	     precision: 0.9638
	        recall: 0.9610
	            F1: 0.9611
	      accuracy: 0.9613
current_model_accuracy:  0.9612999999999999
best_accuracy:  0.9647


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 3/20, training loss: 1.481949538564377, validation loss: 1.4782482385635376
	     precision: 0.9829
	        recall: 0.9829
	            F1: 0.9827
	      accuracy: 0.9829
current_model_accuracy:  0.9828999999999999
best_accuracy:  0.9828999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 4/20, training loss: 1.4778690002620347, validation loss: 1.4750404357910156
	     precision: 0.9859
	        recall: 0.9854
	            F1: 0.9855
	      accuracy: 0.9856
current_model_accuracy:  0.9856000000000001
best_accuracy:  0.9856000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 5/20, training loss: 1.4763343402825948, validation loss: 1.4860590696334839
	     precision: 0.9769
	        recall: 0.9752
	            F1: 0.9753
	      accuracy: 0.9754
current_model_accuracy:  0.9753999999999999
best_accuracy:  0.9856000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 6/20, training loss: 1.4751628738984879, validation loss: 1.472697377204895
	     precision: 0.9887
	        recall: 0.9885
	            F1: 0.9885
	      accuracy: 0.9885
current_model_accuracy:  0.9884999999999999
best_accuracy:  0.9884999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 7/20, training loss: 1.473674733755685, validation loss: 1.474717140197754
	     precision: 0.9867
	        recall: 0.9866
	            F1: 0.9865
	      accuracy: 0.9866
current_model_accuracy:  0.9865999999999999
best_accuracy:  0.9884999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 8/20, training loss: 1.4725936864739033, validation loss: 1.474112868309021
	     precision: 0.9870
	        recall: 0.9868
	            F1: 0.9866
	      accuracy: 0.9869
current_model_accuracy:  0.9869
best_accuracy:  0.9884999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 9/20, training loss: 1.4717849782789185, validation loss: 1.4730228185653687
	     precision: 0.9882
	        recall: 0.9878
	            F1: 0.9879
	      accuracy: 0.9880
current_model_accuracy:  0.9879999999999999
best_accuracy:  0.9884999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 10/20, training loss: 1.4712313253488114, validation loss: 1.472001075744629
	     precision: 0.9897
	        recall: 0.9891
	            F1: 0.9893
	      accuracy: 0.9892
current_model_accuracy:  0.9892
best_accuracy:  0.9892


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 11/20, training loss: 1.4711609384906825, validation loss: 1.4703222513198853
	     precision: 0.9910
	        recall: 0.9909
	            F1: 0.9909
	      accuracy: 0.9909
current_model_accuracy:  0.9908999999999999
best_accuracy:  0.9908999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 12/20, training loss: 1.4709634651253218, validation loss: 1.4719109535217285
	     precision: 0.9896
	        recall: 0.9894
	            F1: 0.9894
	      accuracy: 0.9894
current_model_accuracy:  0.9894000000000001
best_accuracy:  0.9908999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 13/20, training loss: 1.4702969359944877, validation loss: 1.4714761972427368
	     precision: 0.9897
	        recall: 0.9897
	            F1: 0.9897
	      accuracy: 0.9897
current_model_accuracy:  0.9897
best_accuracy:  0.9908999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 14/20, training loss: 1.4694321821493381, validation loss: 1.4707406759262085
	     precision: 0.9906
	        recall: 0.9904
	            F1: 0.9905
	      accuracy: 0.9904
current_model_accuracy:  0.9904
best_accuracy:  0.9908999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 15/20, training loss: 1.470377054549992, validation loss: 1.472696304321289
	     precision: 0.9886
	        recall: 0.9883
	            F1: 0.9883
	      accuracy: 0.9884
current_model_accuracy:  0.9884000000000001
best_accuracy:  0.9908999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 16/20, training loss: 1.4686134024215405, validation loss: 1.4744398593902588
	     precision: 0.9871
	        recall: 0.9864
	            F1: 0.9865
	      accuracy: 0.9866
current_model_accuracy:  0.9866000000000001
best_accuracy:  0.9908999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 17/20, training loss: 1.4675631157116595, validation loss: 1.4692840576171875
	     precision: 0.9919
	        recall: 0.9918
	            F1: 0.9918
	      accuracy: 0.9918
current_model_accuracy:  0.9917999999999999
best_accuracy:  0.9917999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 18/20, training loss: 1.4679682475925764, validation loss: 1.4685872793197632
	     precision: 0.9926
	        recall: 0.9925
	            F1: 0.9925
	      accuracy: 0.9925
current_model_accuracy:  0.9925
best_accuracy:  0.9925


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 19/20, training loss: 1.468034388922425, validation loss: 1.4724019765853882
	     precision: 0.9887
	        recall: 0.9886
	            F1: 0.9886
	      accuracy: 0.9885
current_model_accuracy:  0.9884999999999999
best_accuracy:  0.9925


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 20/20, training loss: 1.468048906021281, validation loss: 1.4703518152236938
	     precision: 0.9910
	        recall: 0.9909
	            F1: 0.9909
	      accuracy: 0.9909
current_model_accuracy:  0.9908999999999999
best_accuracy:  0.9925
[1.5248583870401768, 1.4879047982474125, 1.481949538564377, 1.4778690002620347, 1.4763343402825948, 1.4751628738984879, 1.473674733755685, 1.4725936864739033, 1.4717849782789185, 1.4712313253488114, 1.4711609384906825, 1.4709634651253218, 1.4702969359944877, 1.4694321821493381, 1.470377054549992, 1.4686134024215405, 1.4675631157116595, 1.4679682475925764, 1.468034388922425, 1.468048906021281]
Training time: 695.1465153694153s


In [None]:
print('best_accuracy: ',best_accuracy)

best_accuracy:  0.9925


In [None]:
# Mean of the per-batch accuracies currently held in `accuracy`
# (the training cell refills that list on every validation pass).
print(f"\t accuracy: {sum(accuracy)/val_batches:.4f}")

	 accuracy: 0.9909


In [None]:
# Recompute the mean per-batch validation accuracy from the current `accuracy` list.
current_model_accuracy = sum(accuracy)/val_batches

In [None]:
# Reload whatever was saved at `path` (the training cell last set it to the
# best-model checkpoint on Drive).
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
mod1 = torch.load(path)


In [None]:
val_batches

10

In [None]:
# NOTE: ("accuracy") is a plain string, not a 1-tuple, so zip pairs each
# *character* of the word with one entry of `accuracy`, which is why single
# letters are printed. Pairing the whole label with the list would need
# trailing commas: zip(("accuracy",), (accuracy,)).
for name, scores in zip(("accuracy"), (accuracy)):
  print(name)


a
c
c
u
r
a
c
y


In [None]:
accuracy

[0.987, 0.99, 0.989, 0.986, 0.989, 0.991, 0.994, 0.996, 0.994, 0.993]