# ResNet for MNIST in PyTorch

In [1]:
from torchvision.models.resnet import ResNet, BasicBlock
from torchvision.datasets import MNIST
from tqdm.autonotebook import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import inspect
import time

import torch

from torch import nn, optim
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from torch.utils.data import DataLoader
from copy import copy, deepcopy
import numpy as np


In [2]:
# model = models.resnet152(pretrained=True)
# for param in model.parameters():
#     param.requires_grad = False

In [3]:

from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [4]:
def save_last_model(input_model):
  model_save_name = 'resnet_sgd_last.pkl'
  path = F"/content/gdrive/My Drive/{model_save_name}" 
  torch.save(input_model, path)

In [5]:
class MnistResNet(ResNet):
    def __init__(self):
        super(MnistResNet, self).__init__(BasicBlock, [2, 2, 2, 2], num_classes=10)
        self.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        
    def forward(self, x):
        return torch.softmax(super(MnistResNet, self).forward(x), dim=-1)


In [6]:
# def get_data_loaders(train_batch_size, val_batch_size):
#     mnist = MNIST(download=True, train=True, root=".").train_data.float()
    
#     # add gaussian noise maybe
#     data_transform = Compose([ Resize((224, 224)),ToTensor(), Normalize((mnist.mean()/255,), (mnist.std()/255,))])

#     train_loader = DataLoader(MNIST(download=True, root=".", transform=data_transform, train=True),
#                               batch_size=train_batch_size, shuffle=True)

#     val_loader = DataLoader(MNIST(download=False, root=".", transform=data_transform, train=False),
#                             batch_size=val_batch_size, shuffle=False)
#     return train_loader, val_loader
def getData(name='cifar10', train_bs=128, test_bs=1000):    
    if name == 'svhn':
        train_loader = torch.utils.data.DataLoader(
    datasets.SVHN('../data', split='extra', download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor()
                   ])),
    batch_size=train_bs, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
    datasets.SVHN('../data', split='test', download=True,transform=transforms.Compose([
                       transforms.ToTensor()
                   ])),
    batch_size=test_bs, shuffle=False)
    if name == 'mnist':
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST('../data', train=True, download=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=train_bs, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST('../data', train=False, transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=test_bs, shuffle=False)
    if name == 'emnist':
        train_loader = torch.utils.data.DataLoader(
            datasets.EMNIST('../data', train=True, download=True, split='balanced',
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1751,), (0.3267,))
                           ])),
            batch_size=train_bs, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            datasets.EMNIST('../data', train=False, split='balanced', transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1751,), (0.3267,))
                           ])),
            batch_size=test_bs, shuffle=False)
    if name == 'cifar10':
        transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
        transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
        trainset = datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=train_bs, shuffle=True)
        testset = datasets.CIFAR10(root='../data', train=False, download=False, transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=test_bs, shuffle=False)
    if name == 'cifar100':
        transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
        transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
        trainset = datasets.CIFAR100(root='../data', train=True, download=True, transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=train_bs, shuffle=True)
        testset = datasets.CIFAR100(root='../data', train=False, download=False, transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=test_bs, shuffle=False)
    if name == 'tinyimagenet':      
        normalize = transforms.Normalize(mean=[0.44785526394844055, 0.41693055629730225, 0.36942949891090393],
                                     std=[0.2928885519504547, 0.28230994939804077, 0.2889912724494934])
        train_dataset = datasets.ImageFolder(
        '../data/tiny-imagenet-200/train',
        transforms.Compose([
            transforms.RandomCrop(64, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=train_bs, shuffle=True, num_workers=4, pin_memory=False)
        test_dataset = datasets.ImageFolder(
        '../data/tiny-imagenet-200/val',
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_bs, shuffle=False)
    return train_loader, test_loader

In [7]:
def calculate_metric(metric_fn, true_y, pred_y):
    if "average" in inspect.getfullargspec(metric_fn).args:
        return metric_fn(true_y, pred_y, average="macro")
    else:
        return metric_fn(true_y, pred_y)
    
def print_scores(p, r, f1, a, batch_size):
    for name, scores in zip(("precision", "recall", "F1", "accuracy"), (p, r, f1, a)):
        print(f"\t{name.rjust(14, ' ')}: {sum(scores)/batch_size:.4f}")

In [8]:
def exp_lr_scheduler(epoch, optimizer, strategy='normal', decay_eff=0.1, decayEpoch=[]):
    """Decay learning rate by a factor of lr_decay every lr_decay_epoch epochs"""
    if strategy == 'normal':
        if epoch in decayEpoch:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= decay_eff
            print('New learning rate is: ', param_group['lr'])
    else:
        print('wrong strategy')
        raise ValueError('A very specific bad thing happened.')
    return optimizer

In [11]:
start_ts = time.time()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
epochs = 90

model = MnistResNet().to(device)
train_loader, val_loader = getData(name='mnist', train_bs=128, test_bs=1000)

losses = []
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

batches = len(train_loader)
val_batches = len(val_loader)

# keep best model
accuracies=[]
best_accuracy = 0
best_model = deepcopy(model)

# training loop + eval loop
for epoch in range(epochs):
    total_loss = 0
    progress = tqdm(enumerate(train_loader), desc="Loss: ", total=batches)
    model.train()
    # lr decay
    optimizer = exp_lr_scheduler(epoch, optimizer, decay_eff=0.1, decayEpoch=[30,60,80])


    for i, data in progress:
        X, y = data[0].to(device), data[1].to(device)
        
        model.zero_grad()
        outputs = model(X)
        loss = loss_function(outputs, y)

        loss.backward()
        optimizer.step()
        current_loss = loss.item()
        total_loss += current_loss
        progress.set_description("Loss: {:.4f}".format(total_loss/(i+1)))
        
    torch.cuda.empty_cache()
    
    val_losses = 0
    precision, recall, f1, accuracy = [], [], [], []
    
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            X, y = data[0].to(device), data[1].to(device)
            outputs = model(X)
            val_losses += loss_function(outputs, y)

            predicted_classes = torch.max(outputs, 1)[1]
            
            for acc, metric in zip((precision, recall, f1, accuracy), 
                                   (precision_score, recall_score, f1_score, accuracy_score)):
                acc.append(
                    calculate_metric(metric, y.cpu(), predicted_classes.cpu())
                )
    
    current_model_accuracy = sum(accuracy)/val_batches
    accuracies.append(current_model_accuracy)
    if current_model_accuracy > best_accuracy:
        best_model = deepcopy(model)
        best_accuracy=current_model_accuracy
        
    print(f"Epoch {epoch+1}/{epochs}, training loss: {total_loss/batches}, validation loss: {val_losses/val_batches}")
    print_scores(precision, recall, f1, accuracy, val_batches)
    losses.append(total_loss/batches)
save_last_model(model)
model_save_name = 'resnet_sgd_best.pkl'
path = F"/content/gdrive/My Drive/{model_save_name}" 
torch.save(best_model, path)

print(losses)
print(f"Training time: {time.time()-start_ts}s")

HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 1/90, training loss: 1.5392260045639232, validation loss: 1.4858535528182983
	     precision: 0.9786
	        recall: 0.9775
	            F1: 0.9778
	      accuracy: 0.9781


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 2/90, training loss: 1.4791650932226608, validation loss: 1.4763234853744507
	     precision: 0.9864
	        recall: 0.9861
	            F1: 0.9861
	      accuracy: 0.9862


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 3/90, training loss: 1.4729896313600195, validation loss: 1.4754246473312378
	     precision: 0.9865
	        recall: 0.9865
	            F1: 0.9864
	      accuracy: 0.9865


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 4/90, training loss: 1.4709592025671432, validation loss: 1.47457754611969
	     precision: 0.9869
	        recall: 0.9869
	            F1: 0.9868
	      accuracy: 0.9869


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 5/90, training loss: 1.4686612203431282, validation loss: 1.4750540256500244
	     precision: 0.9871
	        recall: 0.9868
	            F1: 0.9868
	      accuracy: 0.9869


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 6/90, training loss: 1.4674271462060242, validation loss: 1.4744255542755127
	     precision: 0.9873
	        recall: 0.9871
	            F1: 0.9871
	      accuracy: 0.9872


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 7/90, training loss: 1.4666151914006866, validation loss: 1.472842812538147
	     precision: 0.9887
	        recall: 0.9884
	            F1: 0.9885
	      accuracy: 0.9886


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 8/90, training loss: 1.4659180026064549, validation loss: 1.4727015495300293
	     precision: 0.9889
	        recall: 0.9885
	            F1: 0.9886
	      accuracy: 0.9887


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 9/90, training loss: 1.4653810857455614, validation loss: 1.4713109731674194
	     precision: 0.9909
	        recall: 0.9906
	            F1: 0.9907
	      accuracy: 0.9907


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 10/90, training loss: 1.4646790678313037, validation loss: 1.4728161096572876
	     precision: 0.9884
	        recall: 0.9882
	            F1: 0.9882
	      accuracy: 0.9883


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 11/90, training loss: 1.4649907843644685, validation loss: 1.4723763465881348
	     precision: 0.9888
	        recall: 0.9888
	            F1: 0.9887
	      accuracy: 0.9888


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 12/90, training loss: 1.4641684446253502, validation loss: 1.470502257347107
	     precision: 0.9911
	        recall: 0.9910
	            F1: 0.9910
	      accuracy: 0.9910


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 13/90, training loss: 1.4636683212414479, validation loss: 1.472190260887146
	     precision: 0.9892
	        recall: 0.9890
	            F1: 0.9890
	      accuracy: 0.9891


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 14/90, training loss: 1.4632786545417964, validation loss: 1.471271276473999
	     precision: 0.9903
	        recall: 0.9902
	            F1: 0.9901
	      accuracy: 0.9902


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 15/90, training loss: 1.4631749101793334, validation loss: 1.4709981679916382
	     precision: 0.9900
	        recall: 0.9900
	            F1: 0.9899
	      accuracy: 0.9901


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 16/90, training loss: 1.463330558368138, validation loss: 1.4698328971862793
	     precision: 0.9917
	        recall: 0.9916
	            F1: 0.9916
	      accuracy: 0.9917


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 17/90, training loss: 1.462880331824329, validation loss: 1.4701682329177856
	     precision: 0.9912
	        recall: 0.9912
	            F1: 0.9912
	      accuracy: 0.9912


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 18/90, training loss: 1.4628208179209532, validation loss: 1.4699972867965698
	     precision: 0.9916
	        recall: 0.9914
	            F1: 0.9914
	      accuracy: 0.9915


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 19/90, training loss: 1.4628090777122644, validation loss: 1.4692349433898926
	     precision: 0.9923
	        recall: 0.9922
	            F1: 0.9922
	      accuracy: 0.9923


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 20/90, training loss: 1.4626344261901465, validation loss: 1.4687113761901855
	     precision: 0.9930
	        recall: 0.9929
	            F1: 0.9929
	      accuracy: 0.9930


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 21/90, training loss: 1.462607650360319, validation loss: 1.468794822692871
	     precision: 0.9926
	        recall: 0.9925
	            F1: 0.9925
	      accuracy: 0.9926


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 22/90, training loss: 1.4625243362841576, validation loss: 1.4691849946975708
	     precision: 0.9924
	        recall: 0.9923
	            F1: 0.9923
	      accuracy: 0.9924


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 23/90, training loss: 1.4625405258715534, validation loss: 1.4692082405090332
	     precision: 0.9924
	        recall: 0.9923
	            F1: 0.9923
	      accuracy: 0.9924


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 24/90, training loss: 1.4625472956374763, validation loss: 1.4709042310714722
	     precision: 0.9905
	        recall: 0.9904
	            F1: 0.9904
	      accuracy: 0.9905


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 25/90, training loss: 1.4624905293938448, validation loss: 1.469589114189148
	     precision: 0.9921
	        recall: 0.9920
	            F1: 0.9920
	      accuracy: 0.9921


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 26/90, training loss: 1.4623249881048954, validation loss: 1.468685269355774
	     precision: 0.9927
	        recall: 0.9927
	            F1: 0.9927
	      accuracy: 0.9928


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 27/90, training loss: 1.4625074672800646, validation loss: 1.4698225259780884
	     precision: 0.9914
	        recall: 0.9912
	            F1: 0.9912
	      accuracy: 0.9914


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 28/90, training loss: 1.4622607653074935, validation loss: 1.4692376852035522
	     precision: 0.9923
	        recall: 0.9921
	            F1: 0.9922
	      accuracy: 0.9923


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 29/90, training loss: 1.46216947183426, validation loss: 1.4693608283996582
	     precision: 0.9920
	        recall: 0.9918
	            F1: 0.9919
	      accuracy: 0.9920


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 30/90, training loss: 1.4621133852615031, validation loss: 1.469665765762329
	     precision: 0.9918
	        recall: 0.9917
	            F1: 0.9917
	      accuracy: 0.9918


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…

New learning rate is:  0.001

Epoch 31/90, training loss: 1.4620711572134673, validation loss: 1.4693423509597778
	     precision: 0.9921
	        recall: 0.9920
	            F1: 0.9920
	      accuracy: 0.9921


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 32/90, training loss: 1.4620731756377068, validation loss: 1.4691877365112305
	     precision: 0.9919
	        recall: 0.9918
	            F1: 0.9918
	      accuracy: 0.9919


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 33/90, training loss: 1.462037784712655, validation loss: 1.4690334796905518
	     precision: 0.9924
	        recall: 0.9923
	            F1: 0.9923
	      accuracy: 0.9924


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 34/90, training loss: 1.4620120962545562, validation loss: 1.4687904119491577
	     precision: 0.9923
	        recall: 0.9922
	            F1: 0.9922
	      accuracy: 0.9923


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 35/90, training loss: 1.4620212237718009, validation loss: 1.4687341451644897
	     precision: 0.9923
	        recall: 0.9923
	            F1: 0.9923
	      accuracy: 0.9924


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 36/90, training loss: 1.462016803369339, validation loss: 1.468915343284607
	     precision: 0.9925
	        recall: 0.9925
	            F1: 0.9925
	      accuracy: 0.9926


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 37/90, training loss: 1.4619982623850614, validation loss: 1.4688791036605835
	     precision: 0.9920
	        recall: 0.9919
	            F1: 0.9919
	      accuracy: 0.9920


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 38/90, training loss: 1.4619991568359993, validation loss: 1.4686816930770874
	     precision: 0.9927
	        recall: 0.9925
	            F1: 0.9925
	      accuracy: 0.9927


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 39/90, training loss: 1.4619889297465016, validation loss: 1.468626618385315
	     precision: 0.9927
	        recall: 0.9927
	            F1: 0.9927
	      accuracy: 0.9928


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 40/90, training loss: 1.4619848987441073, validation loss: 1.4688142538070679
	     precision: 0.9921
	        recall: 0.9920
	            F1: 0.9920
	      accuracy: 0.9921


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 41/90, training loss: 1.4619933018552216, validation loss: 1.46855628490448
	     precision: 0.9927
	        recall: 0.9925
	            F1: 0.9926
	      accuracy: 0.9927


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 42/90, training loss: 1.4619962454858872, validation loss: 1.4687303304672241
	     precision: 0.9926
	        recall: 0.9925
	            F1: 0.9925
	      accuracy: 0.9926


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 43/90, training loss: 1.4619800152300773, validation loss: 1.4684480428695679
	     precision: 0.9932
	        recall: 0.9931
	            F1: 0.9931
	      accuracy: 0.9932


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 44/90, training loss: 1.4619673574402896, validation loss: 1.4686148166656494
	     precision: 0.9928
	        recall: 0.9928
	            F1: 0.9927
	      accuracy: 0.9929


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 45/90, training loss: 1.461963925534474, validation loss: 1.468634009361267
	     precision: 0.9928
	        recall: 0.9927
	            F1: 0.9927
	      accuracy: 0.9929


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 46/90, training loss: 1.4619814655674037, validation loss: 1.4687540531158447
	     precision: 0.9929
	        recall: 0.9928
	            F1: 0.9928
	      accuracy: 0.9929


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 47/90, training loss: 1.4619728583516851, validation loss: 1.4686139822006226
	     precision: 0.9926
	        recall: 0.9924
	            F1: 0.9925
	      accuracy: 0.9926


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 48/90, training loss: 1.4619723001776983, validation loss: 1.468520998954773
	     precision: 0.9928
	        recall: 0.9926
	            F1: 0.9927
	      accuracy: 0.9928


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 49/90, training loss: 1.4619734732072744, validation loss: 1.4685828685760498
	     precision: 0.9931
	        recall: 0.9930
	            F1: 0.9930
	      accuracy: 0.9931


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 50/90, training loss: 1.4619729788318625, validation loss: 1.468392252922058
	     precision: 0.9931
	        recall: 0.9930
	            F1: 0.9930
	      accuracy: 0.9931


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 51/90, training loss: 1.4619642790955012, validation loss: 1.4687391519546509
	     precision: 0.9927
	        recall: 0.9925
	            F1: 0.9926
	      accuracy: 0.9927


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 52/90, training loss: 1.461966802824789, validation loss: 1.468517780303955
	     precision: 0.9929
	        recall: 0.9929
	            F1: 0.9929
	      accuracy: 0.9930


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 53/90, training loss: 1.461959911053623, validation loss: 1.4687594175338745
	     precision: 0.9927
	        recall: 0.9926
	            F1: 0.9926
	      accuracy: 0.9927


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 54/90, training loss: 1.4619636955037554, validation loss: 1.4686394929885864
	     precision: 0.9924
	        recall: 0.9923
	            F1: 0.9923
	      accuracy: 0.9924


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 55/90, training loss: 1.4619591917310442, validation loss: 1.4684504270553589
	     precision: 0.9929
	        recall: 0.9927
	            F1: 0.9927
	      accuracy: 0.9929


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 56/90, training loss: 1.4619487371526039, validation loss: 1.4686020612716675
	     precision: 0.9929
	        recall: 0.9928
	            F1: 0.9928
	      accuracy: 0.9929


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 57/90, training loss: 1.4619537121705664, validation loss: 1.468637228012085
	     precision: 0.9927
	        recall: 0.9926
	            F1: 0.9926
	      accuracy: 0.9928


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 58/90, training loss: 1.461942107946888, validation loss: 1.4685815572738647
	     precision: 0.9927
	        recall: 0.9926
	            F1: 0.9926
	      accuracy: 0.9927


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 59/90, training loss: 1.4619440877361338, validation loss: 1.4685124158859253
	     precision: 0.9927
	        recall: 0.9926
	            F1: 0.9926
	      accuracy: 0.9927


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 60/90, training loss: 1.4619394553495622, validation loss: 1.4685746431350708
	     precision: 0.9926
	        recall: 0.9925
	            F1: 0.9925
	      accuracy: 0.9926


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…

New learning rate is:  0.0001

Epoch 61/90, training loss: 1.4619463049272485, validation loss: 1.468462347984314
	     precision: 0.9928
	        recall: 0.9927
	            F1: 0.9927
	      accuracy: 0.9928


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 62/90, training loss: 1.4619456501657775, validation loss: 1.468567132949829
	     precision: 0.9927
	        recall: 0.9925
	            F1: 0.9925
	      accuracy: 0.9927


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 63/90, training loss: 1.4619488680540627, validation loss: 1.4687273502349854
	     precision: 0.9926
	        recall: 0.9924
	            F1: 0.9925
	      accuracy: 0.9926


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 64/90, training loss: 1.4619455886548007, validation loss: 1.4686399698257446
	     precision: 0.9926
	        recall: 0.9924
	            F1: 0.9925
	      accuracy: 0.9926


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 65/90, training loss: 1.461957247526661, validation loss: 1.4686129093170166
	     precision: 0.9929
	        recall: 0.9927
	            F1: 0.9927
	      accuracy: 0.9928


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 66/90, training loss: 1.4619421254851417, validation loss: 1.4686460494995117
	     precision: 0.9925
	        recall: 0.9925
	            F1: 0.9924
	      accuracy: 0.9926


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 67/90, training loss: 1.4619458862967583, validation loss: 1.4686719179153442
	     precision: 0.9927
	        recall: 0.9925
	            F1: 0.9926
	      accuracy: 0.9927


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 68/90, training loss: 1.4619463979562461, validation loss: 1.468431830406189
	     precision: 0.9929
	        recall: 0.9928
	            F1: 0.9928
	      accuracy: 0.9929


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 69/90, training loss: 1.461955077104223, validation loss: 1.468503713607788
	     precision: 0.9927
	        recall: 0.9925
	            F1: 0.9926
	      accuracy: 0.9927


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 70/90, training loss: 1.4619442069454234, validation loss: 1.4685190916061401
	     precision: 0.9928
	        recall: 0.9926
	            F1: 0.9926
	      accuracy: 0.9928


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 71/90, training loss: 1.4619486245519318, validation loss: 1.4687089920043945
	     precision: 0.9925
	        recall: 0.9923
	            F1: 0.9924
	      accuracy: 0.9925


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 72/90, training loss: 1.4619458270733798, validation loss: 1.468590497970581
	     precision: 0.9926
	        recall: 0.9924
	            F1: 0.9925
	      accuracy: 0.9926


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 73/90, training loss: 1.4619522148103856, validation loss: 1.4687755107879639
	     precision: 0.9924
	        recall: 0.9922
	            F1: 0.9923
	      accuracy: 0.9924


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 74/90, training loss: 1.4619571374677647, validation loss: 1.4685239791870117
	     precision: 0.9930
	        recall: 0.9928
	            F1: 0.9929
	      accuracy: 0.9930


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 75/90, training loss: 1.4619690233202123, validation loss: 1.4685403108596802
	     precision: 0.9928
	        recall: 0.9927
	            F1: 0.9927
	      accuracy: 0.9928


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 76/90, training loss: 1.46195362320841, validation loss: 1.4686039686203003
	     precision: 0.9925
	        recall: 0.9924
	            F1: 0.9924
	      accuracy: 0.9925


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 77/90, training loss: 1.4619427731296402, validation loss: 1.468528389930725
	     precision: 0.9928
	        recall: 0.9926
	            F1: 0.9927
	      accuracy: 0.9928


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 78/90, training loss: 1.4619479710613486, validation loss: 1.4684346914291382
	     precision: 0.9930
	        recall: 0.9928
	            F1: 0.9929
	      accuracy: 0.9930


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 79/90, training loss: 1.461941755402571, validation loss: 1.4685280323028564
	     precision: 0.9928
	        recall: 0.9928
	            F1: 0.9928
	      accuracy: 0.9929


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 80/90, training loss: 1.4619548714745527, validation loss: 1.4685192108154297
	     precision: 0.9928
	        recall: 0.9926
	            F1: 0.9927
	      accuracy: 0.9928


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…

New learning rate is:  1e-05

Epoch 81/90, training loss: 1.461948112638266, validation loss: 1.468453049659729
	     precision: 0.9929
	        recall: 0.9928
	            F1: 0.9928
	      accuracy: 0.9929


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 82/90, training loss: 1.4619433043608026, validation loss: 1.468884825706482
	     precision: 0.9924
	        recall: 0.9922
	            F1: 0.9923
	      accuracy: 0.9924


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 83/90, training loss: 1.461944129675436, validation loss: 1.4684653282165527
	     precision: 0.9930
	        recall: 0.9929
	            F1: 0.9929
	      accuracy: 0.9930


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 84/90, training loss: 1.4619406586262718, validation loss: 1.468748927116394
	     precision: 0.9927
	        recall: 0.9925
	            F1: 0.9926
	      accuracy: 0.9927


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 85/90, training loss: 1.4619436480089036, validation loss: 1.4686955213546753
	     precision: 0.9926
	        recall: 0.9924
	            F1: 0.9925
	      accuracy: 0.9926


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 86/90, training loss: 1.4619477588230614, validation loss: 1.468808650970459
	     precision: 0.9924
	        recall: 0.9922
	            F1: 0.9923
	      accuracy: 0.9924


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 87/90, training loss: 1.461950037270975, validation loss: 1.468656301498413
	     precision: 0.9929
	        recall: 0.9928
	            F1: 0.9928
	      accuracy: 0.9929


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 88/90, training loss: 1.4619400572420946, validation loss: 1.4685810804367065
	     precision: 0.9925
	        recall: 0.9924
	            F1: 0.9924
	      accuracy: 0.9925


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 89/90, training loss: 1.461947482786199, validation loss: 1.4685943126678467
	     precision: 0.9930
	        recall: 0.9929
	            F1: 0.9929
	      accuracy: 0.9930


HBox(children=(FloatProgress(value=0.0, description='Loss: ', max=469.0, style=ProgressStyle(description_width…


Epoch 90/90, training loss: 1.4619468404794298, validation loss: 1.468474268913269
	     precision: 0.9932
	        recall: 0.9930
	            F1: 0.9931
	      accuracy: 0.9932
[1.5392260045639232, 1.4791650932226608, 1.4729896313600195, 1.4709592025671432, 1.4686612203431282, 1.4674271462060242, 1.4666151914006866, 1.4659180026064549, 1.4653810857455614, 1.4646790678313037, 1.4649907843644685, 1.4641684446253502, 1.4636683212414479, 1.4632786545417964, 1.4631749101793334, 1.463330558368138, 1.462880331824329, 1.4628208179209532, 1.4628090777122644, 1.4626344261901465, 1.462607650360319, 1.4625243362841576, 1.4625405258715534, 1.4625472956374763, 1.4624905293938448, 1.4623249881048954, 1.4625074672800646, 1.4622607653074935, 1.46216947183426, 1.4621133852615031, 1.4620711572134673, 1.4620731756377068, 1.462037784712655, 1.4620120962545562, 1.4620212237718009, 1.462016803369339, 1.4619982623850614, 1.4619991568359993, 1.4619889297465016, 1.4619848987441073, 1.4619933018552216, 1.4619

In [13]:
best_accuracy

0.9931999999999999

In [14]:
accuracies

[0.9781000000000001,
 0.9862,
 0.9865000000000002,
 0.9869,
 0.9869,
 0.9872,
 0.9885999999999999,
 0.9886999999999999,
 0.9907,
 0.9883,
 0.9888000000000001,
 0.991,
 0.9891,
 0.9902000000000001,
 0.9901,
 0.9917,
 0.9911999999999999,
 0.9915,
 0.9923000000000002,
 0.9930000000000001,
 0.9926,
 0.9924,
 0.9924,
 0.9905000000000002,
 0.9921000000000001,
 0.9928000000000001,
 0.9914,
 0.9923000000000002,
 0.992,
 0.9917999999999999,
 0.9921,
 0.9918999999999999,
 0.9924,
 0.9923,
 0.9924,
 0.9926,
 0.992,
 0.9927000000000001,
 0.9928000000000001,
 0.9921000000000001,
 0.9927000000000001,
 0.9926,
 0.9931999999999999,
 0.9929,
 0.9928999999999999,
 0.9929,
 0.9926,
 0.9928000000000001,
 0.9931000000000001,
 0.9931000000000001,
 0.9926999999999999,
 0.993,
 0.9926999999999999,
 0.9924,
 0.9928999999999999,
 0.9928999999999999,
 0.9927999999999999,
 0.9926999999999999,
 0.9927000000000001,
 0.9925999999999998,
 0.9928000000000001,
 0.9926999999999999,
 0.9926,
 0.9926,
 0.9928000000000001,

In [12]:
model

MnistResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=Tru

In [None]:
print(model)