In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

We define transform composition which will be used to preprocess each image.
We will just transform each image from its PIL image type to pytorch Tensor and then normalize it so the pixels with
larger value don't have bigger impact than those with the lower value. Same principle as transforming all attributes to the same scale.

In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
    lambda x: x/255.
])

We are now able to download the data. From the beginning, we will divide the data into two parts: training and testing, in order to evaluate the model's performance on unseen examples. Additionally, we will require a validation set. Typically, the optimal model is determined throughout the epochs using the validation set. Subsequently, we compare the models with varying hyperparameter values or entirely different algorithms, based on their performance on the test set.

In [4]:
full_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

test_dataset = datasets.MNIST(root='./data', train = False, download = True, transform = transform)

train_size = int(0.6 * len(full_dataset))
val_size = len(full_dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



Now we can define dataloaders for our datasets which will handle loading data into the model.
We can try using different batch sizes on the train data loader which can lead to better performance because of the regularization effect that
minibatched gradients provide. On test and validation loaders we can use the biggest batch size that memory allows us because we are only testing the model not training it.

Also we are going to shuffle train data so we can give the model a better chance to not get stuck in the bad local minima. For the same reason as mentioned above we won't shuffle test and validation sets.

num_workers is set to 2 so the model uses multiple processes to load the data into the GPU device and minimzes the time that GPU is not used. Even though 2 is the usual number of workers the ideal value depends on the data we are working with and the batch size so we can look at it as the hyper parameter for the model training duration.

Last attribute pin_memory is set to True so that we remove one step in transitioning data from the host to the device and directly move the data from the pinned memory. This will also save us some time during the model training.




In [5]:
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2, pin_memory = True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=2, pin_memory = True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2, pin_memory = True)







The "go to" architecture for all visual tasks are convolutional neural networks. Convolutional networks generate high quality visual embeddings that can then be fed into fully connected neural network (classification head). Behind the high quality of image embeddings is the CNNs ability to capture spatial relationships between pixels along with pooling layers which remove the impact of different positions of the digit in the image.

Although MNIST task is relatively simple one and we use a small neural network we will also use batch normalization which will result in more stable training and reduce overfitting. Benefits of the batch normalization are usually observered in really deep networks like Resnet or EffiecentNet and also in small GAN networks.

To show the simplicity of the MNIST dataset task we will also use small fully connected neural network to see how high can we get accuracy with a relatively primitive approach.

On the last layer of both networks we are going to left out the activation function because we will use metrics that expect the logits as the input.



In [6]:
import torch.nn as nn
import torch.nn.functional as F



class ConvModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels = 1, out_channels = 16, kernel_size = 3, stride =  1)
        self.bn1 = nn.BatchNorm2d(num_features = 16)
        self.mp1 = nn.MaxPool2d(kernel_size = 2)
        
        self.conv2 = nn.Conv2d(in_channels = 16, out_channels = 32, kernel_size = 3, stride =  1)
        self.bn2 = nn.BatchNorm2d(num_features = 32)
        self.mp2 = nn.MaxPool2d(kernel_size = 2)
        
        self.conv3 = nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size = 3, stride = 1)
        self.mp3 = nn.MaxPool2d(kernel_size = 2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64,100)
        self.fc2 = nn.Linear(100,10)
        
    def forward(self,x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.mp1(x)
        
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.mp2(x)
        
        x = self.conv3(x)
        x = F.relu(x)
        x = self.mp3(x)
        
        x = self.flatten(x)
        
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x   
    
class FCModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28*28,100)
        self.fc2 = nn.Linear(100,10)
        
    def forward(self,x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        
        return x
        
        
        
        
    

In [10]:
conv_model = ConvModel()
fc_model = FCModel()


Now we can set up the optimizer algorithm which will be used during training. We choose widely used ADAM algorithm, best combination of the AdaGrad and RMSProp algorithms. Learning rate will be adapted with the optimizer so we can start relatively small.


We are going to check if the GPU device is available and if it is (it always should be) we are going to transfer our model weight tensors on it for faster training.

Our loss will be CrossEntropy loss which is usually used for classifcation tasks and pytorch CrossEntropy loss expects logits of the model so in a way now we are setting the last activation layer and it will be softmax which is used in cases when only one class is the right prediction. Reasons behind usage of logits is the way Pytorch CrossEntropyLoss utilizes the possiblity of combining loss function with the softmax layer. 

Also for the same reason we are going to use MulticlassAccuracy class. Usually it would be best to follow other metrics also like Precision, Recall and F-score which is more robust to specific tasks and disbalanced datasets but in this case we can rely on accuracy.


Furthermore we are building the loop for the model training.
Each epoch we go through the whole training set each iteration processing one mini batch, forward pass, backpropagation,
weight updates and metric updates. After each epoch we validate model on the validation set and keep the highest performing one.

In [12]:
from tqdm import tqdm
from copy import deepcopy
from torchmetrics.classification import MulticlassAccuracy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train_loop(model):
    
    #optimization algorithm
    LEARNING_RATE = 0.0001
    
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)
    
    #metrics
    criterion = nn.CrossEntropyLoss()
    accuracy = MulticlassAccuracy(10).to(device)
    
    
    NUM_EPOCHS = 50
    
    #metric logging
    train_losses = []
    val_losses = []

    train_accuracies = []
    val_accuracies = []

    best_model_dict = None
    best_val_acc = 0

    for epoch in range(NUM_EPOCHS):
        model.train()
        train_loss = 0
        train_accuracy = 0
        for (inputs,labels) in tqdm(train_dataloader):
            optimizer.zero_grad() #zeros out the gradients accumulated from the last batch
            inputs = inputs.to(device) #place tensors on the GPU
            labels = labels.to(device) #

            logits = model(inputs) #caluclates model outputs for the minibatch
            batch_loss = criterion(logits,labels) #calculates loss
            train_loss += batch_loss.detach() #add epoch loss
            train_accuracy += accuracy(logits,labels) #add epoch accuracy

            batch_loss.backward() #calculates gradients for the whole network
            optimizer.step() #weight updates

        train_loss /= len(train_dataloader)
        train_accuracy /= len(train_dataloader)

        with torch.no_grad():#turns of gradient calculations
            model.eval() #switches models layers (e.g. BatchNorm) to evaluation mode 
            val_loss = 0
            val_accuracy = 0
            for (inputs,labels) in tqdm(test_dataloader):
                labels = labels.to(device)
                inputs = inputs.to(device)

                logits = model(inputs) 
                val_loss += criterion(logits,labels).detach() 
                val_accuracy += accuracy(logits,labels)

            val_loss /= len(test_dataloader)
            val_accuracy /= len(test_dataloader)
            if val_accuracy > best_val_acc:
                best_val_acc = val_accuracy
                best_model_dict = deepcopy(model.state_dict())


        train_losses.append(train_loss)
        val_losses.append(val_loss)

        train_accuracies.append(train_accuracy)
        val_accuracies.append(val_accuracy)

        print(f"Epoch {epoch+1}, Training Loss: {train_loss}, Validation Loss: {val_loss}")
        print(f"Training accuracy: {train_accuracy}, Validation accuracy: {val_accuracy}")
        
    ret_dict = {
        "train_losses":train_losses,
        "val_losses":val_losses,
        "train_accuracies":train_accuracies,
        "val_accuracies":val_accuracies,
        "best_model_dict":best_model_dict,
        "best_val_acc":best_val_acc
    }
    
    return ret_dict
    

    
    

In [13]:
conv_train_dict = train_loop(conv_model)


100%|██████████| 563/563 [00:07<00:00, 77.00it/s]
100%|██████████| 79/79 [00:01<00:00, 59.14it/s]


Epoch 1, Training Loss: 1.0734260082244873, Validation Loss: 0.30852827429771423
Training accuracy: 0.7589402794837952, Validation accuracy: 0.9211059212684631


100%|██████████| 563/563 [00:07<00:00, 71.10it/s]
100%|██████████| 79/79 [00:01<00:00, 60.42it/s]


Epoch 2, Training Loss: 0.2293204963207245, Validation Loss: 0.2056194245815277
Training accuracy: 0.9388031959533691, Validation accuracy: 0.9386386275291443


100%|██████████| 563/563 [00:07<00:00, 78.76it/s]
100%|██████████| 79/79 [00:01<00:00, 54.76it/s]


Epoch 3, Training Loss: 0.14866943657398224, Validation Loss: 0.12511301040649414
Training accuracy: 0.955299973487854, Validation accuracy: 0.9639801979064941


100%|██████████| 563/563 [00:07<00:00, 78.91it/s]
100%|██████████| 79/79 [00:01<00:00, 60.40it/s]


Epoch 4, Training Loss: 0.11733639240264893, Validation Loss: 0.11189431697130203
Training accuracy: 0.9644298553466797, Validation accuracy: 0.9660852551460266


100%|██████████| 563/563 [00:07<00:00, 77.36it/s]
100%|██████████| 79/79 [00:01<00:00, 59.89it/s]


Epoch 5, Training Loss: 0.09853547811508179, Validation Loss: 0.11615388840436935
Training accuracy: 0.9684204459190369, Validation accuracy: 0.9623565673828125


100%|██████████| 563/563 [00:07<00:00, 72.38it/s]
100%|██████████| 79/79 [00:01<00:00, 59.14it/s]


Epoch 6, Training Loss: 0.08661971986293793, Validation Loss: 0.0903484895825386
Training accuracy: 0.9731524586677551, Validation accuracy: 0.9741954803466797


100%|██████████| 563/563 [00:07<00:00, 79.20it/s]
100%|██████████| 79/79 [00:01<00:00, 57.55it/s]


Epoch 7, Training Loss: 0.07767105102539062, Validation Loss: 0.08266489952802658
Training accuracy: 0.9764513969421387, Validation accuracy: 0.9760704040527344


100%|██████████| 563/563 [00:06<00:00, 80.59it/s]
100%|██████████| 79/79 [00:01<00:00, 56.22it/s]


Epoch 8, Training Loss: 0.07035353779792786, Validation Loss: 0.0765974298119545
Training accuracy: 0.9790998697280884, Validation accuracy: 0.9765154123306274


100%|██████████| 563/563 [00:07<00:00, 78.99it/s]
100%|██████████| 79/79 [00:01<00:00, 60.78it/s]


Epoch 9, Training Loss: 0.06423492729663849, Validation Loss: 0.08095762133598328
Training accuracy: 0.9789435267448425, Validation accuracy: 0.9753506779670715


100%|██████████| 563/563 [00:07<00:00, 72.86it/s]
100%|██████████| 79/79 [00:01<00:00, 59.77it/s]


Epoch 10, Training Loss: 0.05943739041686058, Validation Loss: 0.10466526448726654
Training accuracy: 0.9820874333381653, Validation accuracy: 0.9678995013237


100%|██████████| 563/563 [00:07<00:00, 76.67it/s]
100%|██████████| 79/79 [00:01<00:00, 59.11it/s]


Epoch 11, Training Loss: 0.05481399595737457, Validation Loss: 0.09037670493125916
Training accuracy: 0.9824392199516296, Validation accuracy: 0.9713465571403503


100%|██████████| 563/563 [00:07<00:00, 77.34it/s]
100%|██████████| 79/79 [00:01<00:00, 59.44it/s]


Epoch 12, Training Loss: 0.05103837326169014, Validation Loss: 0.07776416093111038
Training accuracy: 0.9840379953384399, Validation accuracy: 0.9771629571914673


100%|██████████| 563/563 [00:07<00:00, 73.65it/s]
100%|██████████| 79/79 [00:01<00:00, 54.75it/s]


Epoch 13, Training Loss: 0.04850563034415245, Validation Loss: 0.1354222446680069
Training accuracy: 0.9843921661376953, Validation accuracy: 0.9566832184791565


100%|██████████| 563/563 [00:07<00:00, 79.93it/s]
100%|██████████| 79/79 [00:01<00:00, 54.13it/s]


Epoch 14, Training Loss: 0.04483005031943321, Validation Loss: 0.11544112116098404
Training accuracy: 0.9855808615684509, Validation accuracy: 0.9645405411720276


100%|██████████| 563/563 [00:07<00:00, 78.32it/s]
100%|██████████| 79/79 [00:01<00:00, 59.65it/s]


Epoch 15, Training Loss: 0.041805632412433624, Validation Loss: 0.06298544257879257
Training accuracy: 0.9864884614944458, Validation accuracy: 0.9813737869262695


100%|██████████| 563/563 [00:07<00:00, 77.08it/s]
100%|██████████| 79/79 [00:01<00:00, 59.64it/s]


Epoch 16, Training Loss: 0.03933555632829666, Validation Loss: 0.08606600761413574
Training accuracy: 0.9874430298805237, Validation accuracy: 0.9736030697822571


100%|██████████| 563/563 [00:07<00:00, 72.80it/s]
100%|██████████| 79/79 [00:01<00:00, 59.75it/s]


Epoch 17, Training Loss: 0.036983802914619446, Validation Loss: 0.09678588062524796
Training accuracy: 0.9887174367904663, Validation accuracy: 0.9695655703544617


100%|██████████| 563/563 [00:07<00:00, 78.19it/s]
100%|██████████| 79/79 [00:01<00:00, 59.85it/s]


Epoch 18, Training Loss: 0.03459439054131508, Validation Loss: 0.06486491858959198
Training accuracy: 0.9887874126434326, Validation accuracy: 0.9802544713020325


100%|██████████| 563/563 [00:07<00:00, 77.03it/s]
100%|██████████| 79/79 [00:01<00:00, 59.15it/s]


Epoch 19, Training Loss: 0.0320957750082016, Validation Loss: 0.0685669332742691
Training accuracy: 0.9897305965423584, Validation accuracy: 0.9793848991394043


100%|██████████| 563/563 [00:07<00:00, 77.17it/s]
100%|██████████| 79/79 [00:01<00:00, 42.45it/s]


Epoch 20, Training Loss: 0.030407816171646118, Validation Loss: 0.07291467487812042
Training accuracy: 0.9893743991851807, Validation accuracy: 0.9776679277420044


100%|██████████| 563/563 [00:07<00:00, 76.28it/s]
100%|██████████| 79/79 [00:01<00:00, 58.55it/s]


Epoch 21, Training Loss: 0.02857227623462677, Validation Loss: 0.13773953914642334
Training accuracy: 0.991454005241394, Validation accuracy: 0.9610252380371094


100%|██████████| 563/563 [00:07<00:00, 77.85it/s]
100%|██████████| 79/79 [00:01<00:00, 54.97it/s]


Epoch 22, Training Loss: 0.02702096849679947, Validation Loss: 0.057618606835603714
Training accuracy: 0.990483820438385, Validation accuracy: 0.9828430414199829


100%|██████████| 563/563 [00:07<00:00, 79.44it/s]
100%|██████████| 79/79 [00:01<00:00, 60.83it/s]


Epoch 23, Training Loss: 0.024753134697675705, Validation Loss: 0.08116889744997025
Training accuracy: 0.9927374720573425, Validation accuracy: 0.9764665365219116


100%|██████████| 563/563 [00:07<00:00, 72.49it/s]
100%|██████████| 79/79 [00:01<00:00, 59.36it/s]


Epoch 24, Training Loss: 0.02306295931339264, Validation Loss: 0.06125537306070328
Training accuracy: 0.9932016730308533, Validation accuracy: 0.9828925132751465


100%|██████████| 563/563 [00:07<00:00, 76.90it/s]
100%|██████████| 79/79 [00:01<00:00, 60.03it/s]


Epoch 25, Training Loss: 0.021854165941476822, Validation Loss: 0.14172761142253876
Training accuracy: 0.9934465885162354, Validation accuracy: 0.9577843546867371


100%|██████████| 563/563 [00:07<00:00, 76.36it/s]
100%|██████████| 79/79 [00:01<00:00, 60.11it/s]


Epoch 26, Training Loss: 0.0203031525015831, Validation Loss: 0.05622456595301628
Training accuracy: 0.9937343001365662, Validation accuracy: 0.9845473170280457


100%|██████████| 563/563 [00:07<00:00, 78.99it/s]
100%|██████████| 79/79 [00:01<00:00, 54.14it/s]


Epoch 27, Training Loss: 0.018720583990216255, Validation Loss: 0.09554556757211685
Training accuracy: 0.9941860437393188, Validation accuracy: 0.9727343320846558


100%|██████████| 563/563 [00:07<00:00, 73.57it/s]
100%|██████████| 79/79 [00:01<00:00, 57.97it/s]


Epoch 28, Training Loss: 0.018155904486775398, Validation Loss: 0.06394782662391663
Training accuracy: 0.9948751330375671, Validation accuracy: 0.9814454317092896


100%|██████████| 563/563 [00:07<00:00, 76.92it/s]
100%|██████████| 79/79 [00:01<00:00, 58.65it/s]


Epoch 29, Training Loss: 0.016293678432703018, Validation Loss: 0.1113656684756279
Training accuracy: 0.9947599768638611, Validation accuracy: 0.9652026295661926


100%|██████████| 563/563 [00:07<00:00, 76.79it/s]
100%|██████████| 79/79 [00:01<00:00, 55.84it/s]


Epoch 30, Training Loss: 0.015324884094297886, Validation Loss: 0.07042928040027618
Training accuracy: 0.9939225316047668, Validation accuracy: 0.9797060489654541


100%|██████████| 563/563 [00:07<00:00, 73.20it/s]
100%|██████████| 79/79 [00:01<00:00, 59.92it/s]


Epoch 31, Training Loss: 0.01425009686499834, Validation Loss: 0.11613054573535919
Training accuracy: 0.9952396154403687, Validation accuracy: 0.9696521162986755


100%|██████████| 563/563 [00:07<00:00, 77.31it/s]
100%|██████████| 79/79 [00:01<00:00, 59.74it/s]


Epoch 32, Training Loss: 0.013191021047532558, Validation Loss: 0.08146017044782639
Training accuracy: 0.9957051277160645, Validation accuracy: 0.9754055738449097


100%|██████████| 563/563 [00:07<00:00, 78.94it/s]
100%|██████████| 79/79 [00:01<00:00, 61.09it/s]


Epoch 33, Training Loss: 0.011953701265156269, Validation Loss: 0.07351865619421005
Training accuracy: 0.9962429404258728, Validation accuracy: 0.9787153601646423


100%|██████████| 563/563 [00:07<00:00, 77.31it/s]
100%|██████████| 79/79 [00:01<00:00, 59.94it/s]


Epoch 34, Training Loss: 0.011743191629648209, Validation Loss: 0.08560046553611755
Training accuracy: 0.9960904121398926, Validation accuracy: 0.9770510792732239


100%|██████████| 563/563 [00:07<00:00, 72.59it/s]
100%|██████████| 79/79 [00:01<00:00, 59.92it/s]


Epoch 35, Training Loss: 0.010821452364325523, Validation Loss: 0.0774853527545929
Training accuracy: 0.9964264035224915, Validation accuracy: 0.9784283638000488


100%|██████████| 563/563 [00:07<00:00, 77.21it/s]
100%|██████████| 79/79 [00:01<00:00, 60.29it/s]


Epoch 36, Training Loss: 0.00961438100785017, Validation Loss: 0.08847031742334366
Training accuracy: 0.9968727827072144, Validation accuracy: 0.9738240838050842


100%|██████████| 563/563 [00:07<00:00, 79.67it/s]
100%|██████████| 79/79 [00:01<00:00, 53.92it/s]


Epoch 37, Training Loss: 0.008619524538516998, Validation Loss: 0.07282202690839767
Training accuracy: 0.9964072108268738, Validation accuracy: 0.9791508316993713


100%|██████████| 563/563 [00:07<00:00, 80.15it/s]
100%|██████████| 79/79 [00:01<00:00, 55.45it/s]


Epoch 38, Training Loss: 0.008012905716896057, Validation Loss: 0.0821664035320282
Training accuracy: 0.9979391098022461, Validation accuracy: 0.9794303774833679


100%|██████████| 563/563 [00:07<00:00, 72.58it/s]
100%|██████████| 79/79 [00:01<00:00, 59.53it/s]


Epoch 39, Training Loss: 0.007759325671941042, Validation Loss: 0.07371427863836288
Training accuracy: 0.9973446130752563, Validation accuracy: 0.9797610640525818


100%|██████████| 563/563 [00:07<00:00, 77.59it/s]
100%|██████████| 79/79 [00:01<00:00, 60.55it/s]


Epoch 40, Training Loss: 0.0069360691122710705, Validation Loss: 0.3226836621761322
Training accuracy: 0.9986631274223328, Validation accuracy: 0.9226508140563965


100%|██████████| 563/563 [00:07<00:00, 77.73it/s]
100%|██████████| 79/79 [00:01<00:00, 60.21it/s]


Epoch 41, Training Loss: 0.005506867542862892, Validation Loss: 0.08460477739572525
Training accuracy: 0.9981319904327393, Validation accuracy: 0.9777036309242249


100%|██████████| 563/563 [00:07<00:00, 73.19it/s]
100%|██████████| 79/79 [00:01<00:00, 55.23it/s]


Epoch 42, Training Loss: 0.005505993962287903, Validation Loss: 0.07772034406661987
Training accuracy: 0.998349130153656, Validation accuracy: 0.9798494577407837


100%|██████████| 563/563 [00:07<00:00, 78.04it/s]
100%|██████████| 79/79 [00:01<00:00, 59.58it/s]


Epoch 43, Training Loss: 0.005223343148827553, Validation Loss: 0.11324374377727509
Training accuracy: 0.9976989030838013, Validation accuracy: 0.9698440432548523


100%|██████████| 563/563 [00:07<00:00, 77.44it/s]
100%|██████████| 79/79 [00:01<00:00, 60.04it/s]


Epoch 44, Training Loss: 0.004195166286081076, Validation Loss: 0.13033536076545715
Training accuracy: 0.9977973103523254, Validation accuracy: 0.967723548412323


100%|██████████| 563/563 [00:07<00:00, 78.36it/s]
100%|██████████| 79/79 [00:01<00:00, 58.19it/s]


Epoch 45, Training Loss: 0.0043319049291312695, Validation Loss: 0.07861307263374329
Training accuracy: 0.9972548484802246, Validation accuracy: 0.9799953699111938


100%|██████████| 563/563 [00:07<00:00, 73.09it/s]
100%|██████████| 79/79 [00:01<00:00, 56.53it/s]


Epoch 46, Training Loss: 0.004103309940546751, Validation Loss: 0.10358893126249313
Training accuracy: 0.9972648024559021, Validation accuracy: 0.9734315872192383


100%|██████████| 563/563 [00:07<00:00, 79.66it/s]
100%|██████████| 79/79 [00:01<00:00, 60.10it/s]


Epoch 47, Training Loss: 0.0038055269978940487, Validation Loss: 0.07678940147161484
Training accuracy: 0.9980406165122986, Validation accuracy: 0.9808520078659058


100%|██████████| 563/563 [00:07<00:00, 76.26it/s]
100%|██████████| 79/79 [00:01<00:00, 58.87it/s]


Epoch 48, Training Loss: 0.0028406032361090183, Validation Loss: 0.07046880573034286
Training accuracy: 0.9986240863800049, Validation accuracy: 0.9825441837310791


100%|██████████| 563/563 [00:07<00:00, 77.14it/s]
100%|██████████| 79/79 [00:01<00:00, 41.07it/s]


Epoch 49, Training Loss: 0.0030905811581760645, Validation Loss: 0.08632968366146088
Training accuracy: 0.9982777237892151, Validation accuracy: 0.979600191116333


100%|██████████| 563/563 [00:07<00:00, 75.85it/s]
100%|██████████| 79/79 [00:01<00:00, 60.08it/s]

Epoch 50, Training Loss: 0.002657767152413726, Validation Loss: 0.08485975116491318
Training accuracy: 0.9987502694129944, Validation accuracy: 0.9808436632156372





In [14]:
fc_train_dict = train_loop(fc_model)

100%|██████████| 563/563 [00:06<00:00, 88.15it/s] 
100%|██████████| 79/79 [00:01<00:00, 61.85it/s]


Epoch 1, Training Loss: 2.293470859527588, Validation Loss: 2.2783117294311523
Training accuracy: 0.18881942331790924, Validation accuracy: 0.21592794358730316


100%|██████████| 563/563 [00:05<00:00, 96.05it/s] 
100%|██████████| 79/79 [00:01<00:00, 62.45it/s]


Epoch 2, Training Loss: 2.25138258934021, Validation Loss: 2.2129342555999756
Training accuracy: 0.3233993649482727, Validation accuracy: 0.43284177780151367


100%|██████████| 563/563 [00:06<00:00, 93.38it/s] 
100%|██████████| 79/79 [00:01<00:00, 62.10it/s]


Epoch 3, Training Loss: 2.162034034729004, Validation Loss: 2.0955398082733154
Training accuracy: 0.4734732508659363, Validation accuracy: 0.5082315802574158


100%|██████████| 563/563 [00:05<00:00, 95.27it/s] 
100%|██████████| 79/79 [00:01<00:00, 60.75it/s]


Epoch 4, Training Loss: 2.0277490615844727, Validation Loss: 1.942260980606079
Training accuracy: 0.5291848182678223, Validation accuracy: 0.5350660085678101


100%|██████████| 563/563 [00:06<00:00, 87.75it/s] 
100%|██████████| 79/79 [00:01<00:00, 60.07it/s]


Epoch 5, Training Loss: 1.8701105117797852, Validation Loss: 1.7774983644485474
Training accuracy: 0.5556100606918335, Validation accuracy: 0.5690993070602417


100%|██████████| 563/563 [00:05<00:00, 94.13it/s] 
100%|██████████| 79/79 [00:01<00:00, 61.85it/s]


Epoch 6, Training Loss: 1.7099254131317139, Validation Loss: 1.619347333908081
Training accuracy: 0.5866497755050659, Validation accuracy: 0.6111337542533875


100%|██████████| 563/563 [00:05<00:00, 95.14it/s] 
100%|██████████| 79/79 [00:01<00:00, 62.72it/s]


Epoch 7, Training Loss: 1.560637354850769, Validation Loss: 1.4768449068069458
Training accuracy: 0.6248822212219238, Validation accuracy: 0.6532741785049438


100%|██████████| 563/563 [00:05<00:00, 95.57it/s] 
100%|██████████| 79/79 [00:01<00:00, 63.14it/s]


Epoch 8, Training Loss: 1.427927851676941, Validation Loss: 1.351967215538025
Training accuracy: 0.6531621217727661, Validation accuracy: 0.6802624464035034


100%|██████████| 563/563 [00:05<00:00, 95.52it/s] 
100%|██████████| 79/79 [00:01<00:00, 45.19it/s]


Epoch 9, Training Loss: 1.3120564222335815, Validation Loss: 1.244136929512024
Training accuracy: 0.6840121746063232, Validation accuracy: 0.704491913318634


100%|██████████| 563/563 [00:05<00:00, 95.69it/s] 
100%|██████████| 79/79 [00:01<00:00, 62.21it/s]


Epoch 10, Training Loss: 1.2119460105895996, Validation Loss: 1.151075005531311
Training accuracy: 0.7094638347625732, Validation accuracy: 0.7286503911018372


100%|██████████| 563/563 [00:05<00:00, 94.51it/s] 
100%|██████████| 79/79 [00:01<00:00, 62.58it/s]


Epoch 11, Training Loss: 1.1252535581588745, Validation Loss: 1.0704665184020996
Training accuracy: 0.7299399375915527, Validation accuracy: 0.747907817363739


100%|██████████| 563/563 [00:05<00:00, 94.70it/s] 
100%|██████████| 79/79 [00:01<00:00, 62.52it/s]


Epoch 12, Training Loss: 1.049545407295227, Validation Loss: 1.0002787113189697
Training accuracy: 0.7469965815544128, Validation accuracy: 0.7604378461837769


100%|██████████| 563/563 [00:05<00:00, 97.62it/s] 
100%|██████████| 79/79 [00:01<00:00, 62.01it/s]


Epoch 13, Training Loss: 0.9837864637374878, Validation Loss: 0.9388925433158875
Training accuracy: 0.7627128958702087, Validation accuracy: 0.7710539698600769


100%|██████████| 563/563 [00:06<00:00, 87.77it/s] 
100%|██████████| 79/79 [00:01<00:00, 63.01it/s]


Epoch 14, Training Loss: 0.9258691668510437, Validation Loss: 0.8843808770179749
Training accuracy: 0.7755095362663269, Validation accuracy: 0.7852200269699097


100%|██████████| 563/563 [00:06<00:00, 93.69it/s] 
100%|██████████| 79/79 [00:01<00:00, 62.63it/s]


Epoch 15, Training Loss: 0.874458372592926, Validation Loss: 0.8362210392951965
Training accuracy: 0.7850282192230225, Validation accuracy: 0.795270562171936


100%|██████████| 563/563 [00:05<00:00, 96.87it/s] 
100%|██████████| 79/79 [00:01<00:00, 62.32it/s]


Epoch 16, Training Loss: 0.8281237483024597, Validation Loss: 0.7929880023002625
Training accuracy: 0.795695960521698, Validation accuracy: 0.8033741116523743


100%|██████████| 563/563 [00:06<00:00, 92.93it/s] 
100%|██████████| 79/79 [00:01<00:00, 60.84it/s]


Epoch 17, Training Loss: 0.7870351672172546, Validation Loss: 0.7540785074234009
Training accuracy: 0.8020195364952087, Validation accuracy: 0.8134621977806091


100%|██████████| 563/563 [00:06<00:00, 86.80it/s] 
100%|██████████| 79/79 [00:01<00:00, 62.16it/s]


Epoch 18, Training Loss: 0.7502537965774536, Validation Loss: 0.7192580699920654
Training accuracy: 0.810985803604126, Validation accuracy: 0.8187189698219299


100%|██████████| 563/563 [00:05<00:00, 97.74it/s] 
100%|██████████| 79/79 [00:01<00:00, 63.11it/s]


Epoch 19, Training Loss: 0.7171247005462646, Validation Loss: 0.688189685344696
Training accuracy: 0.816924512386322, Validation accuracy: 0.8253206610679626


100%|██████████| 563/563 [00:05<00:00, 96.09it/s] 
100%|██████████| 79/79 [00:01<00:00, 60.44it/s]


Epoch 20, Training Loss: 0.6872262358665466, Validation Loss: 0.6597188115119934
Training accuracy: 0.8227065801620483, Validation accuracy: 0.8298153877258301


100%|██████████| 563/563 [00:05<00:00, 98.03it/s] 
100%|██████████| 79/79 [00:01<00:00, 57.61it/s]


Epoch 21, Training Loss: 0.6601128578186035, Validation Loss: 0.6340222358703613
Training accuracy: 0.8280238509178162, Validation accuracy: 0.8362658023834229


100%|██████████| 563/563 [00:05<00:00, 95.56it/s] 
100%|██████████| 79/79 [00:01<00:00, 48.57it/s]


Epoch 22, Training Loss: 0.635560929775238, Validation Loss: 0.6106876134872437
Training accuracy: 0.8336688876152039, Validation accuracy: 0.8414625525474548


100%|██████████| 563/563 [00:06<00:00, 93.17it/s] 
100%|██████████| 79/79 [00:01<00:00, 63.41it/s]


Epoch 23, Training Loss: 0.613442063331604, Validation Loss: 0.5894584059715271
Training accuracy: 0.8399572968482971, Validation accuracy: 0.8458170890808105


100%|██████████| 563/563 [00:06<00:00, 91.13it/s] 
100%|██████████| 79/79 [00:01<00:00, 41.27it/s]


Epoch 24, Training Loss: 0.5929722189903259, Validation Loss: 0.5697527527809143
Training accuracy: 0.8426066040992737, Validation accuracy: 0.8507006168365479


100%|██████████| 563/563 [00:06<00:00, 92.49it/s] 
100%|██████████| 79/79 [00:01<00:00, 62.60it/s]


Epoch 25, Training Loss: 0.5745753645896912, Validation Loss: 0.5524662733078003
Training accuracy: 0.8476651310920715, Validation accuracy: 0.8539204597473145


100%|██████████| 563/563 [00:05<00:00, 95.19it/s] 
100%|██████████| 79/79 [00:01<00:00, 61.84it/s]


Epoch 26, Training Loss: 0.5575478672981262, Validation Loss: 0.5367143750190735
Training accuracy: 0.8511244058609009, Validation accuracy: 0.857216477394104


100%|██████████| 563/563 [00:06<00:00, 87.64it/s] 
100%|██████████| 79/79 [00:01<00:00, 62.04it/s]


Epoch 27, Training Loss: 0.5422417521476746, Validation Loss: 0.5218763947486877
Training accuracy: 0.854038417339325, Validation accuracy: 0.8595245480537415


100%|██████████| 563/563 [00:05<00:00, 96.97it/s] 
100%|██████████| 79/79 [00:01<00:00, 60.46it/s]


Epoch 28, Training Loss: 0.5276028513908386, Validation Loss: 0.5081110596656799
Training accuracy: 0.8575208783149719, Validation accuracy: 0.8636337518692017


100%|██████████| 563/563 [00:05<00:00, 95.52it/s] 
100%|██████████| 79/79 [00:01<00:00, 62.33it/s]


Epoch 29, Training Loss: 0.5146735310554504, Validation Loss: 0.495522677898407
Training accuracy: 0.8622651696205139, Validation accuracy: 0.865337073802948


100%|██████████| 563/563 [00:05<00:00, 95.56it/s] 
100%|██████████| 79/79 [00:01<00:00, 62.49it/s]


Epoch 30, Training Loss: 0.5028505921363831, Validation Loss: 0.48425188660621643
Training accuracy: 0.8623943328857422, Validation accuracy: 0.8680061101913452


100%|██████████| 563/563 [00:06<00:00, 89.66it/s] 
100%|██████████| 79/79 [00:01<00:00, 62.41it/s]


Epoch 31, Training Loss: 0.4915830194950104, Validation Loss: 0.4734219014644623
Training accuracy: 0.865228533744812, Validation accuracy: 0.8716601729393005


100%|██████████| 563/563 [00:05<00:00, 97.29it/s] 
100%|██████████| 79/79 [00:01<00:00, 63.05it/s]


Epoch 32, Training Loss: 0.48153653740882874, Validation Loss: 0.4635663330554962
Training accuracy: 0.8687949776649475, Validation accuracy: 0.8735336065292358


100%|██████████| 563/563 [00:05<00:00, 95.78it/s] 
100%|██████████| 79/79 [00:01<00:00, 62.38it/s]


Epoch 33, Training Loss: 0.4720224440097809, Validation Loss: 0.4544321894645691
Training accuracy: 0.8705248832702637, Validation accuracy: 0.8753834962844849


100%|██████████| 563/563 [00:05<00:00, 97.32it/s] 
100%|██████████| 79/79 [00:01<00:00, 61.56it/s]


Epoch 34, Training Loss: 0.46280428767204285, Validation Loss: 0.4461827874183655
Training accuracy: 0.8739377856254578, Validation accuracy: 0.8780913949012756


100%|██████████| 563/563 [00:06<00:00, 92.18it/s] 
100%|██████████| 79/79 [00:01<00:00, 44.74it/s]


Epoch 35, Training Loss: 0.4548848867416382, Validation Loss: 0.4385248124599457
Training accuracy: 0.8763234615325928, Validation accuracy: 0.8808251023292542


100%|██████████| 563/563 [00:06<00:00, 92.76it/s] 
100%|██████████| 79/79 [00:01<00:00, 62.27it/s]


Epoch 36, Training Loss: 0.447100967168808, Validation Loss: 0.4306046664714813
Training accuracy: 0.8786221742630005, Validation accuracy: 0.8824982643127441


100%|██████████| 563/563 [00:05<00:00, 94.61it/s] 
100%|██████████| 79/79 [00:01<00:00, 61.12it/s]


Epoch 37, Training Loss: 0.4401857852935791, Validation Loss: 0.4240766167640686
Training accuracy: 0.8792920708656311, Validation accuracy: 0.883765697479248


100%|██████████| 563/563 [00:06<00:00, 90.10it/s] 
100%|██████████| 79/79 [00:01<00:00, 58.86it/s]


Epoch 38, Training Loss: 0.4330959916114807, Validation Loss: 0.41792231798171997
Training accuracy: 0.881756603717804, Validation accuracy: 0.8855172991752625


100%|██████████| 563/563 [00:06<00:00, 92.93it/s] 
100%|██████████| 79/79 [00:01<00:00, 61.11it/s]


Epoch 39, Training Loss: 0.42702358961105347, Validation Loss: 0.41160520911216736
Training accuracy: 0.882312536239624, Validation accuracy: 0.8862296938896179


100%|██████████| 563/563 [00:06<00:00, 87.32it/s] 
100%|██████████| 79/79 [00:01<00:00, 62.78it/s]


Epoch 40, Training Loss: 0.4210152328014374, Validation Loss: 0.4057111144065857
Training accuracy: 0.8833910226821899, Validation accuracy: 0.8880340456962585


100%|██████████| 563/563 [00:05<00:00, 94.44it/s] 
100%|██████████| 79/79 [00:01<00:00, 58.81it/s]


Epoch 41, Training Loss: 0.41548261046409607, Validation Loss: 0.40062442421913147
Training accuracy: 0.8834298849105835, Validation accuracy: 0.8894708752632141


100%|██████████| 563/563 [00:05<00:00, 94.54it/s] 
100%|██████████| 79/79 [00:01<00:00, 60.60it/s]


Epoch 42, Training Loss: 0.41057807207107544, Validation Loss: 0.39598214626312256
Training accuracy: 0.884803831577301, Validation accuracy: 0.8911343812942505


100%|██████████| 563/563 [00:05<00:00, 97.22it/s] 
100%|██████████| 79/79 [00:01<00:00, 61.57it/s]


Epoch 43, Training Loss: 0.40552979707717896, Validation Loss: 0.3911377489566803
Training accuracy: 0.8869889974594116, Validation accuracy: 0.8920427560806274


100%|██████████| 563/563 [00:06<00:00, 88.78it/s] 
100%|██████████| 79/79 [00:01<00:00, 62.73it/s]


Epoch 44, Training Loss: 0.40091654658317566, Validation Loss: 0.3867962956428528
Training accuracy: 0.8877066969871521, Validation accuracy: 0.8925841450691223


100%|██████████| 563/563 [00:05<00:00, 95.94it/s] 
100%|██████████| 79/79 [00:01<00:00, 63.50it/s]


Epoch 45, Training Loss: 0.396762877702713, Validation Loss: 0.3828873932361603
Training accuracy: 0.8874467611312866, Validation accuracy: 0.8939771056175232


100%|██████████| 563/563 [00:05<00:00, 96.94it/s] 
100%|██████████| 79/79 [00:01<00:00, 58.63it/s]


Epoch 46, Training Loss: 0.3924742639064789, Validation Loss: 0.37887683510780334
Training accuracy: 0.8887614011764526, Validation accuracy: 0.893664538860321


100%|██████████| 563/563 [00:05<00:00, 94.22it/s] 
100%|██████████| 79/79 [00:01<00:00, 59.35it/s]


Epoch 47, Training Loss: 0.38857316970825195, Validation Loss: 0.37530261278152466
Training accuracy: 0.8900140523910522, Validation accuracy: 0.8956097960472107


100%|██████████| 563/563 [00:06<00:00, 85.51it/s] 
100%|██████████| 79/79 [00:01<00:00, 62.36it/s]


Epoch 48, Training Loss: 0.38491007685661316, Validation Loss: 0.3715350329875946
Training accuracy: 0.890801727771759, Validation accuracy: 0.8965769410133362


100%|██████████| 563/563 [00:06<00:00, 90.48it/s] 
100%|██████████| 79/79 [00:01<00:00, 58.78it/s]


Epoch 49, Training Loss: 0.38139501214027405, Validation Loss: 0.36840835213661194
Training accuracy: 0.8921926617622375, Validation accuracy: 0.8973777294158936


100%|██████████| 563/563 [00:06<00:00, 90.82it/s] 
100%|██████████| 79/79 [00:01<00:00, 60.08it/s]

Epoch 50, Training Loss: 0.37823253870010376, Validation Loss: 0.3651216924190521
Training accuracy: 0.8923975229263306, Validation accuracy: 0.8979558348655701





Now we can load the best models from the state dict.

In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


best_model_conv = ConvModel()
best_model_conv.load_state_dict(conv_train_dict["best_model_dict"])
best_model_conv = best_model_conv.to(device)

best_model_fc = FCModel()
best_model_fc.load_state_dict(fc_train_dict["best_model_dict"])
best_model_fc = best_model_fc.to(device)



Now we can test the models.

In [19]:
def test_model_loop(model):
    criterion = nn.CrossEntropyLoss()
    accuracy = MulticlassAccuracy(10).to(device)
    
    loss_test = 0
    accuracy_test = 0
    with torch.no_grad():
        model.eval()
        for (inputs, labels) in tqdm(test_dataloader):
            labels = labels.to(device)
            inputs = inputs.to(device)
            logits = model(inputs)
            loss_test += criterion(logits,labels).detach()
            accuracy_test += accuracy(logits,labels)
            
    loss_test /= len(test_dataloader)
    accuracy_test /= len(test_dataloader)
    
    return  loss_test, accuracy_test






In [20]:
loss_test_fc,accuracy_test_fc = test_model_loop(best_model_fc)

100%|██████████| 79/79 [00:01<00:00, 57.45it/s]


In [21]:
loss_test_conv,accuracy_test_conv = test_model_loop(best_model_conv)


100%|██████████| 79/79 [00:01<00:00, 59.24it/s]


In [23]:
print(loss_test_fc,accuracy_test_fc)
print(loss_test_conv,accuracy_test_conv)

tensor(0.3651, device='cuda:0') tensor(0.8980, device='cuda:0')
tensor(0.0562, device='cuda:0') tensor(0.9845, device='cuda:0')


In [24]:
from prettytable import PrettyTable

    
def print_out_test(loss,acc):
    with open("conv_test_metrics.txt","w") as f:
        f.write(f"Model loss:{loss.item()}, Model accuracy:{acc.item()}")
        
def print_out_metrics(train_values,val_values,metric_name,model_name):
    table = PrettyTable()
    table.field_names = ["Train "+ metric_name, "Validation "+ metric_name]

    for train_value, val_value in zip(train_values, val_values):
        table.add_row([train_value.item(), val_value.item()])
        

    with open(model_name + '_' + metric_name + '.txt', 'w') as f:
        f.write(str(table))
      
    
print_out_test(loss_test_conv,accuracy_test_conv)
print_out_metrics(conv_train_dict["train_losses"],conv_train_dict["val_losses"],"loss","convmodel")
print_out_metrics(conv_train_dict["train_accuracies"], conv_train_dict["val_accuracies"],"accuracy","convmodel")

We are going to save the best model from the training.

In [25]:
torch.save(conv_train_dict["best_model_dict"], "model.pth")