# ResNet for MNIST in PyTorch

In [1]:
from torchvision.models.resnet import ResNet, BasicBlock
from torchvision.datasets import MNIST
from tqdm.autonotebook import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import inspect
import time

import torch

from torch import nn, optim
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from torch.utils.data import DataLoader
from copy import copy, deepcopy
import numpy as np


In [2]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [3]:
def save_last_model(input_model):
  model_save_name = 'resnet_adam_last.pkl'
  path = F"/content/gdrive/My Drive/{model_save_name}" 
  torch.save(input_model, path)

In [4]:
class MnistResNet(ResNet):
    def __init__(self):
        super(MnistResNet, self).__init__(BasicBlock, [2, 2, 2, 2], num_classes=10)
        self.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        
    def forward(self, x):
        return torch.softmax(super(MnistResNet, self).forward(x), dim=-1)


In [5]:
# def get_data_loaders(train_batch_size, val_batch_size):
#     mnist = MNIST(download=True, train=True, root=".").train_data.float()
    
#     # add gaussian noise maybe
#     data_transform = Compose([ Resize((224, 224)),ToTensor(), Normalize((mnist.mean()/255,), (mnist.std()/255,))])

#     train_loader = DataLoader(MNIST(download=True, root=".", transform=data_transform, train=True),
#                               batch_size=train_batch_size, shuffle=True)

#     val_loader = DataLoader(MNIST(download=False, root=".", transform=data_transform, train=False),
#                             batch_size=val_batch_size, shuffle=False)
#     return train_loader, val_loader
def getData(name='cifar10', train_bs=128, test_bs=1000):    
    if name == 'svhn':
        train_loader = torch.utils.data.DataLoader(
    datasets.SVHN('../data', split='extra', download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor()
                   ])),
    batch_size=train_bs, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
    datasets.SVHN('../data', split='test', download=True,transform=transforms.Compose([
                       transforms.ToTensor()
                   ])),
    batch_size=test_bs, shuffle=False)
    if name == 'mnist':
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST('../data', train=True, download=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=train_bs, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST('../data', train=False, transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=test_bs, shuffle=False)
    if name == 'emnist':
        train_loader = torch.utils.data.DataLoader(
            datasets.EMNIST('../data', train=True, download=True, split='balanced',
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1751,), (0.3267,))
                           ])),
            batch_size=train_bs, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            datasets.EMNIST('../data', train=False, split='balanced', transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1751,), (0.3267,))
                           ])),
            batch_size=test_bs, shuffle=False)
    if name == 'cifar10':
        transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
        transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
        trainset = datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=train_bs, shuffle=True)
        testset = datasets.CIFAR10(root='../data', train=False, download=False, transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=test_bs, shuffle=False)
    if name == 'cifar100':
        transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
        transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
        trainset = datasets.CIFAR100(root='../data', train=True, download=True, transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=train_bs, shuffle=True)
        testset = datasets.CIFAR100(root='../data', train=False, download=False, transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=test_bs, shuffle=False)
    if name == 'tinyimagenet':      
        normalize = transforms.Normalize(mean=[0.44785526394844055, 0.41693055629730225, 0.36942949891090393],
                                     std=[0.2928885519504547, 0.28230994939804077, 0.2889912724494934])
        train_dataset = datasets.ImageFolder(
        '../data/tiny-imagenet-200/train',
        transforms.Compose([
            transforms.RandomCrop(64, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=train_bs, shuffle=True, num_workers=4, pin_memory=False)
        test_dataset = datasets.ImageFolder(
        '../data/tiny-imagenet-200/val',
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_bs, shuffle=False)
    return train_loader, test_loader

In [6]:
def calculate_metric(metric_fn, true_y, pred_y):
    if "average" in inspect.getfullargspec(metric_fn).args:
        return metric_fn(true_y, pred_y, average="macro")
    else:
        return metric_fn(true_y, pred_y)
    
def print_scores(p, r, f1, a, batch_size):
    for name, scores in zip(("precision", "recall", "F1", "accuracy"), (p, r, f1, a)):
        print(f"\t{name.rjust(14, ' ')}: {sum(scores)/batch_size:.4f}")

In [7]:
# optimizer = optim.Adam(model.parameters())


In [8]:
start_ts = time.time()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
epochs = 20

model = MnistResNet().to(device)
train_loader, val_loader = getData(name='mnist', train_bs=128, test_bs=1000)

losses = []
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

batches = len(train_loader)
val_batches = len(val_loader)
# keep best model
best_accuracy = 0
best_model = deepcopy(model)

# training loop + eval loop
for epoch in range(epochs):
    total_loss = 0
    progress = tqdm(enumerate(train_loader), desc="Loss: ", total=batches)
    model.train()
    # # lr decay
    # optimizer = exp_lr_scheduler(epoch, optimizer, decay_eff=0.1, decayEpoch=[15])

    for i, data in progress:
        X, y = data[0].to(device), data[1].to(device)
        
        model.zero_grad()
        outputs = model(X)
        loss = loss_function(outputs, y)

        loss.backward()
        optimizer.step()
        current_loss = loss.item()
        total_loss += current_loss
        progress.set_description("Loss: {:.4f}".format(total_loss/(i+1)))
        
    torch.cuda.empty_cache()
    
    val_losses = 0
    precision, recall, f1, accuracy = [], [], [], []
    
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            X, y = data[0].to(device), data[1].to(device)
            outputs = model(X)
            val_losses += loss_function(outputs, y)

            predicted_classes = torch.max(outputs, 1)[1]
            
            for acc, metric in zip((precision, recall, f1, accuracy), 
                                   (precision_score, recall_score, f1_score, accuracy_score)):
                acc.append(
                    calculate_metric(metric, y.cpu(), predicted_classes.cpu())
                )
    
    current_model_accuracy = sum(accuracy)/val_batches
    if current_model_accuracy > best_accuracy:
        best_model = deepcopy(model)
        best_accuracy=current_model_accuracy
        
    print(f"Epoch {epoch+1}/{epochs}, training loss: {total_loss/batches}, validation loss: {val_losses/val_batches}")
    print_scores(precision, recall, f1, accuracy, val_batches)
    losses.append(total_loss/batches)
    print('current_model_accuracy: ',current_model_accuracy)
    print('best_accuracy: ',best_accuracy)

save_last_model(model)
model_save_name = 'resnet_adam_best.pkl'
path = F"/content/gdrive/My Drive/{model_save_name}" 
torch.save(best_model, path)

print(losses)
print(f"Training time: {time.time()-start_ts}s")

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)






HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 1/20, training loss: 1.5151911262255997, validation loss: 1.5034693479537964
	     precision: 0.9603
	        recall: 0.9574
	            F1: 0.9575
	      accuracy: 0.9574
current_model_accuracy:  0.9574
best_accuracy:  0.9574


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 2/20, training loss: 1.4911684026596135, validation loss: 1.484732985496521
	     precision: 0.9764
	        recall: 0.9763
	            F1: 0.9761
	      accuracy: 0.9762
current_model_accuracy:  0.9762000000000001
best_accuracy:  0.9762000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 3/20, training loss: 1.4854977270687566, validation loss: 1.4854204654693604
	     precision: 0.9768
	        recall: 0.9754
	            F1: 0.9755
	      accuracy: 0.9755
current_model_accuracy:  0.9755
best_accuracy:  0.9762000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 4/20, training loss: 1.4816188108183936, validation loss: 1.4774001836776733
	     precision: 0.9837
	        recall: 0.9836
	            F1: 0.9834
	      accuracy: 0.9835
current_model_accuracy:  0.9834999999999999
best_accuracy:  0.9834999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 5/20, training loss: 1.4817188785020223, validation loss: 1.4799529314041138
	     precision: 0.9823
	        recall: 0.9814
	            F1: 0.9815
	      accuracy: 0.9814
current_model_accuracy:  0.9814
best_accuracy:  0.9834999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 6/20, training loss: 1.4803039336255364, validation loss: 1.4790834188461304
	     precision: 0.9823
	        recall: 0.9818
	            F1: 0.9818
	      accuracy: 0.9820
current_model_accuracy:  0.9819999999999999
best_accuracy:  0.9834999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 7/20, training loss: 1.4792204012494605, validation loss: 1.478491187095642
	     precision: 0.9827
	        recall: 0.9827
	            F1: 0.9825
	      accuracy: 0.9827
current_model_accuracy:  0.9827
best_accuracy:  0.9834999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 8/20, training loss: 1.4784836095533391, validation loss: 1.4752000570297241
	     precision: 0.9863
	        recall: 0.9862
	            F1: 0.9860
	      accuracy: 0.9860
current_model_accuracy:  0.986
best_accuracy:  0.986


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 9/20, training loss: 1.4774120258115757, validation loss: 1.4775526523590088
	     precision: 0.9838
	        recall: 0.9837
	            F1: 0.9836
	      accuracy: 0.9837
current_model_accuracy:  0.9837
best_accuracy:  0.986


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 10/20, training loss: 1.4772230935757602, validation loss: 1.4758200645446777
	     precision: 0.9856
	        recall: 0.9856
	            F1: 0.9855
	      accuracy: 0.9854
current_model_accuracy:  0.9854
best_accuracy:  0.986


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 11/20, training loss: 1.4758020331864672, validation loss: 1.471098780632019
	     precision: 0.9901
	        recall: 0.9901
	            F1: 0.9900
	      accuracy: 0.9900
current_model_accuracy:  0.9899999999999999
best_accuracy:  0.9899999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 12/20, training loss: 1.4774872613614047, validation loss: 1.4770170450210571
	     precision: 0.9846
	        recall: 0.9841
	            F1: 0.9842
	      accuracy: 0.9842
current_model_accuracy:  0.9842000000000001
best_accuracy:  0.9899999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 13/20, training loss: 1.4751913326381365, validation loss: 1.4779428243637085
	     precision: 0.9837
	        recall: 0.9831
	            F1: 0.9832
	      accuracy: 0.9832
current_model_accuracy:  0.9831999999999999
best_accuracy:  0.9899999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 14/20, training loss: 1.476731881912329, validation loss: 1.4747847318649292
	     precision: 0.9862
	        recall: 0.9860
	            F1: 0.9859
	      accuracy: 0.9863
current_model_accuracy:  0.9863
best_accuracy:  0.9899999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 15/20, training loss: 1.4756032098839278, validation loss: 1.473901629447937
	     precision: 0.9874
	        recall: 0.9868
	            F1: 0.9870
	      accuracy: 0.9872
current_model_accuracy:  0.9872
best_accuracy:  0.9899999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 16/20, training loss: 1.4733608013022936, validation loss: 1.4764769077301025
	     precision: 0.9849
	        recall: 0.9845
	            F1: 0.9845
	      accuracy: 0.9846
current_model_accuracy:  0.9846
best_accuracy:  0.9899999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 17/20, training loss: 1.4737816759264037, validation loss: 1.4770770072937012
	     precision: 0.9839
	        recall: 0.9837
	            F1: 0.9836
	      accuracy: 0.9839
current_model_accuracy:  0.9839
best_accuracy:  0.9899999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 18/20, training loss: 1.4742182856683792, validation loss: 1.4756799936294556
	     precision: 0.9852
	        recall: 0.9854
	            F1: 0.9852
	      accuracy: 0.9853
current_model_accuracy:  0.9853000000000002
best_accuracy:  0.9899999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 19/20, training loss: 1.4717551190207507, validation loss: 1.4716333150863647
	     precision: 0.9892
	        recall: 0.9892
	            F1: 0.9891
	      accuracy: 0.9892
current_model_accuracy:  0.9892
best_accuracy:  0.9899999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 20/20, training loss: 1.4740877489545452, validation loss: 1.4738938808441162
	     precision: 0.9873
	        recall: 0.9874
	            F1: 0.9873
	      accuracy: 0.9872
current_model_accuracy:  0.9872000000000002
best_accuracy:  0.9899999999999999
[1.5151911262255997, 1.4911684026596135, 1.4854977270687566, 1.4816188108183936, 1.4817188785020223, 1.4803039336255364, 1.4792204012494605, 1.4784836095533391, 1.4774120258115757, 1.4772230935757602, 1.4758020331864672, 1.4774872613614047, 1.4751913326381365, 1.476731881912329, 1.4756032098839278, 1.4733608013022936, 1.4737816759264037, 1.4742182856683792, 1.4717551190207507, 1.4740877489545452]
Training time: 723.2948718070984s


In [9]:
print('best_accuracy: ',best_accuracy)

best_accuracy:  0.9899999999999999


In [10]:
best_accuracy

0.9899999999999999