# ResNet for MNIST in PyTorch

In [1]:
from torchvision.models.resnet import ResNet, BasicBlock
from torchvision.datasets import MNIST
from tqdm.autonotebook import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import inspect
import time

import torch

from torch import nn, optim
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from torch.utils.data import DataLoader
from copy import copy, deepcopy
import numpy as np


In [2]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [3]:
def save_last_model(input_model):
  model_save_name = 'resnet_adam_last.pkl'
  path = F"/content/gdrive/My Drive/{model_save_name}" 
  torch.save(input_model, path)

In [4]:
class MnistResNet(ResNet):
    def __init__(self):
        super(MnistResNet, self).__init__(BasicBlock, [2, 2, 2, 2], num_classes=10)
        self.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        
    def forward(self, x):
        return torch.softmax(super(MnistResNet, self).forward(x), dim=-1)


In [5]:
# def get_data_loaders(train_batch_size, val_batch_size):
#     mnist = MNIST(download=True, train=True, root=".").train_data.float()
    
#     # add gaussian noise maybe
#     data_transform = Compose([ Resize((224, 224)),ToTensor(), Normalize((mnist.mean()/255,), (mnist.std()/255,))])

#     train_loader = DataLoader(MNIST(download=True, root=".", transform=data_transform, train=True),
#                               batch_size=train_batch_size, shuffle=True)

#     val_loader = DataLoader(MNIST(download=False, root=".", transform=data_transform, train=False),
#                             batch_size=val_batch_size, shuffle=False)
#     return train_loader, val_loader
def getData(name='cifar10', train_bs=128, test_bs=1000):    
    if name == 'svhn':
        train_loader = torch.utils.data.DataLoader(
    datasets.SVHN('../data', split='extra', download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor()
                   ])),
    batch_size=train_bs, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
    datasets.SVHN('../data', split='test', download=True,transform=transforms.Compose([
                       transforms.ToTensor()
                   ])),
    batch_size=test_bs, shuffle=False)
    if name == 'mnist':
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST('../data', train=True, download=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=train_bs, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST('../data', train=False, transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=test_bs, shuffle=False)
    if name == 'emnist':
        train_loader = torch.utils.data.DataLoader(
            datasets.EMNIST('../data', train=True, download=True, split='balanced',
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1751,), (0.3267,))
                           ])),
            batch_size=train_bs, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            datasets.EMNIST('../data', train=False, split='balanced', transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1751,), (0.3267,))
                           ])),
            batch_size=test_bs, shuffle=False)
    if name == 'cifar10':
        transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
        transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
        trainset = datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=train_bs, shuffle=True)
        testset = datasets.CIFAR10(root='../data', train=False, download=False, transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=test_bs, shuffle=False)
    if name == 'cifar100':
        transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
        transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
        trainset = datasets.CIFAR100(root='../data', train=True, download=True, transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=train_bs, shuffle=True)
        testset = datasets.CIFAR100(root='../data', train=False, download=False, transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=test_bs, shuffle=False)
    if name == 'tinyimagenet':      
        normalize = transforms.Normalize(mean=[0.44785526394844055, 0.41693055629730225, 0.36942949891090393],
                                     std=[0.2928885519504547, 0.28230994939804077, 0.2889912724494934])
        train_dataset = datasets.ImageFolder(
        '../data/tiny-imagenet-200/train',
        transforms.Compose([
            transforms.RandomCrop(64, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=train_bs, shuffle=True, num_workers=4, pin_memory=False)
        test_dataset = datasets.ImageFolder(
        '../data/tiny-imagenet-200/val',
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_bs, shuffle=False)
    return train_loader, test_loader

In [6]:
def calculate_metric(metric_fn, true_y, pred_y):
    if "average" in inspect.getfullargspec(metric_fn).args:
        return metric_fn(true_y, pred_y, average="macro")
    else:
        return metric_fn(true_y, pred_y)
    
def print_scores(p, r, f1, a, batch_size):
    for name, scores in zip(("precision", "recall", "F1", "accuracy"), (p, r, f1, a)):
        print(f"\t{name.rjust(14, ' ')}: {sum(scores)/batch_size:.4f}")

In [7]:
# optimizer = optim.Adam(model.parameters())


In [8]:
start_ts = time.time()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
epochs = 20

model = MnistResNet().to(device)
train_loader, val_loader = getData(name='mnist', train_bs=128, test_bs=1000)

losses = []
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

batches = len(train_loader)
val_batches = len(val_loader)
# keep best model
accuracies=[]
best_accuracy = 0
best_model = deepcopy(model)

# training loop + eval loop
for epoch in range(epochs):
    total_loss = 0
    progress = tqdm(enumerate(train_loader), desc="Loss: ", total=batches)
    model.train()
    # # lr decay
    # optimizer = exp_lr_scheduler(epoch, optimizer, decay_eff=0.1, decayEpoch=[15])

    for i, data in progress:
        X, y = data[0].to(device), data[1].to(device)
        
        model.zero_grad()
        outputs = model(X)
        loss = loss_function(outputs, y)

        loss.backward()
        optimizer.step()
        current_loss = loss.item()
        total_loss += current_loss
        progress.set_description("Loss: {:.4f}".format(total_loss/(i+1)))
        
    torch.cuda.empty_cache()
    
    val_losses = 0
    precision, recall, f1, accuracy = [], [], [], []
    
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            X, y = data[0].to(device), data[1].to(device)
            outputs = model(X)
            val_losses += loss_function(outputs, y)

            predicted_classes = torch.max(outputs, 1)[1]
            
            for acc, metric in zip((precision, recall, f1, accuracy), 
                                   (precision_score, recall_score, f1_score, accuracy_score)):
                acc.append(
                    calculate_metric(metric, y.cpu(), predicted_classes.cpu())
                )
    
    current_model_accuracy = sum(accuracy)/val_batches
    accuracies.append(current_model_accuracy)
    if current_model_accuracy > best_accuracy:
        best_model = deepcopy(model)
        best_accuracy=current_model_accuracy
        
    print(f"Epoch {epoch+1}/{epochs}, training loss: {total_loss/batches}, validation loss: {val_losses/val_batches}")
    print_scores(precision, recall, f1, accuracy, val_batches)
    losses.append(total_loss/batches)
    print('current_model_accuracy: ',current_model_accuracy)
    print('best_accuracy: ',best_accuracy)

save_last_model(model)
model_save_name = 'resnet_adam_best.pkl'
path = F"/content/gdrive/My Drive/{model_save_name}" 
torch.save(best_model, path)

print(losses)
print(f"Training time: {time.time()-start_ts}s")

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)






HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 1/20, training loss: 1.5193396206857808, validation loss: 1.4956550598144531
	     precision: 0.9669
	        recall: 0.9656
	            F1: 0.9655
	      accuracy: 0.9659
current_model_accuracy:  0.9658999999999999
best_accuracy:  0.9658999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 2/20, training loss: 1.4916059721761674, validation loss: 1.5014790296554565
	     precision: 0.9629
	        recall: 0.9583
	            F1: 0.9588
	      accuracy: 0.9590
current_model_accuracy:  0.959
best_accuracy:  0.9658999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 3/20, training loss: 1.487216704435694, validation loss: 1.4855890274047852
	     precision: 0.9759
	        recall: 0.9754
	            F1: 0.9754
	      accuracy: 0.9757
current_model_accuracy:  0.9757
best_accuracy:  0.9757


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 4/20, training loss: 1.484091753390298, validation loss: 1.4831688404083252
	     precision: 0.9782
	        recall: 0.9775
	            F1: 0.9775
	      accuracy: 0.9775
current_model_accuracy:  0.9774999999999998
best_accuracy:  0.9774999999999998


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 5/20, training loss: 1.4823016291742386, validation loss: 1.4831833839416504
	     precision: 0.9785
	        recall: 0.9768
	            F1: 0.9772
	      accuracy: 0.9779
current_model_accuracy:  0.9778999999999998
best_accuracy:  0.9778999999999998


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 6/20, training loss: 1.4812402397330635, validation loss: 1.47854745388031
	     precision: 0.9824
	        recall: 0.9823
	            F1: 0.9821
	      accuracy: 0.9825
current_model_accuracy:  0.9825000000000002
best_accuracy:  0.9825000000000002


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 7/20, training loss: 1.4799750044401774, validation loss: 1.4803045988082886
	     precision: 0.9811
	        recall: 0.9807
	            F1: 0.9805
	      accuracy: 0.9808
current_model_accuracy:  0.9808
best_accuracy:  0.9825000000000002


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 8/20, training loss: 1.4780956036500585, validation loss: 1.4825654029846191
	     precision: 0.9797
	        recall: 0.9786
	            F1: 0.9785
	      accuracy: 0.9785
current_model_accuracy:  0.9785000000000001
best_accuracy:  0.9825000000000002


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 9/20, training loss: 1.4787105156668723, validation loss: 1.4774277210235596
	     precision: 0.9839
	        recall: 0.9839
	            F1: 0.9837
	      accuracy: 0.9838
current_model_accuracy:  0.9837999999999999
best_accuracy:  0.9837999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 10/20, training loss: 1.4785218617555176, validation loss: 1.4807907342910767
	     precision: 0.9806
	        recall: 0.9800
	            F1: 0.9799
	      accuracy: 0.9801
current_model_accuracy:  0.9801000000000002
best_accuracy:  0.9837999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 11/20, training loss: 1.4762756516938524, validation loss: 1.478922724723816
	     precision: 0.9827
	        recall: 0.9823
	            F1: 0.9822
	      accuracy: 0.9822
current_model_accuracy:  0.9822000000000001
best_accuracy:  0.9837999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 12/20, training loss: 1.4780448255762617, validation loss: 1.4800320863723755
	     precision: 0.9820
	        recall: 0.9809
	            F1: 0.9812
	      accuracy: 0.9813
current_model_accuracy:  0.9813000000000001
best_accuracy:  0.9837999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 13/20, training loss: 1.4770838591589857, validation loss: 1.4760349988937378
	     precision: 0.9852
	        recall: 0.9852
	            F1: 0.9850
	      accuracy: 0.9850
current_model_accuracy:  0.985
best_accuracy:  0.985


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 14/20, training loss: 1.4758679821038805, validation loss: 1.4749728441238403
	     precision: 0.9866
	        recall: 0.9866
	            F1: 0.9864
	      accuracy: 0.9862
current_model_accuracy:  0.9862
best_accuracy:  0.9862


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 15/20, training loss: 1.4742069564648528, validation loss: 1.4726499319076538
	     precision: 0.9886
	        recall: 0.9888
	            F1: 0.9886
	      accuracy: 0.9886
current_model_accuracy:  0.9885999999999999
best_accuracy:  0.9885999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 16/20, training loss: 1.4755148386904426, validation loss: 1.4798601865768433
	     precision: 0.9815
	        recall: 0.9811
	            F1: 0.9810
	      accuracy: 0.9814
current_model_accuracy:  0.9814
best_accuracy:  0.9885999999999999


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 17/20, training loss: 1.4742986098535533, validation loss: 1.4719055891036987
	     precision: 0.9891
	        recall: 0.9889
	            F1: 0.9889
	      accuracy: 0.9890
current_model_accuracy:  0.9890000000000001
best_accuracy:  0.9890000000000001


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 18/20, training loss: 1.4743892994008339, validation loss: 1.471240520477295
	     precision: 0.9902
	        recall: 0.9902
	            F1: 0.9901
	      accuracy: 0.9901
current_model_accuracy:  0.9900999999999998
best_accuracy:  0.9900999999999998


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 19/20, training loss: 1.4730010535925435, validation loss: 1.472740888595581
	     precision: 0.9885
	        recall: 0.9881
	            F1: 0.9882
	      accuracy: 0.9882
current_model_accuracy:  0.9882
best_accuracy:  0.9900999999999998


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 20/20, training loss: 1.4744583116665577, validation loss: 1.4747133255004883
	     precision: 0.9868
	        recall: 0.9867
	            F1: 0.9866
	      accuracy: 0.9866
current_model_accuracy:  0.9865999999999999
best_accuracy:  0.9900999999999998
[1.5193396206857808, 1.4916059721761674, 1.487216704435694, 1.484091753390298, 1.4823016291742386, 1.4812402397330635, 1.4799750044401774, 1.4780956036500585, 1.4787105156668723, 1.4785218617555176, 1.4762756516938524, 1.4780448255762617, 1.4770838591589857, 1.4758679821038805, 1.4742069564648528, 1.4755148386904426, 1.4742986098535533, 1.4743892994008339, 1.4730010535925435, 1.4744583116665577]
Training time: 696.9808831214905s


In [9]:
print('best_accuracy: ',best_accuracy)

best_accuracy:  0.9900999999999998


In [10]:
best_accuracy

0.9900999999999998