In [22]:
# ## This cell contains the essential imports you will need – DO NOT CHANGE THE CONTENTS! ##
# # src: MNIST_Handwritten_Digits_STARTER.ipynb
import torch
from torchvision import datasets
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn

from MNIST_solver import eval_net_on_data
from MNIST_solver import get_max_n_normalized_mean_n_std
from MNIST_solver import get_train_and_test_data_w_batch_size, MNIST_MLP, eval_net_on_data
from MNIST_solver import PerformanceImprover, TrainingStopper, train_network_classification
from MNIST_solver import define_objective_fcn_with_params
from MNIST_solver import get_model_device , get_HW_acceleration_if_available
from MNIST_solver import PerformanceImprover, TrainingStopper

import torch.nn.functional as F
import torchvision.models as models

# Additional optimizer for tuning the hyper-parameters
# src: https://optuna.org
import optuna
import numpy as np
import torchvision.models as models

# Seed PyTorch's global random number generator so that repeated runs start
# from the same random state.
# NOTE(review): only torch is seeded here. The Optuna study in this notebook is
# created without an explicit sampler, so the suggested learning rates
# presumably still vary between runs -- confirm if exact repeatability matters.
torch_seed = 11
torch.manual_seed(torch_seed)

<torch._C.Generator at 0x1155e1710>

In [23]:
# Folder where the MNIST files are stored (downloaded there when missing).
data_folder = r'./data'

# Untransformed datasets; the raw training split is what the normalisation
# statistics below are computed from.
train_raw = datasets.MNIST(
    root=data_folder,
    train=True,
    download=True,
    transform=None,
)
test_raw = datasets.MNIST(
    root=data_folder,
    train=False,
    download=True,
    transform=None,
)

# Project helper; presumably returns the maximum pixel value together with the
# mean / std of the max-normalised pixels -- verify against MNIST_solver.
max_data_value, img_mean, img_std = get_max_n_normalized_mean_n_std(train_raw)

# Convert images to tensors, then standardise them with the training statistics.
transform_pipeline = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((img_mean,), (img_std,)),
])

# Transformed data (to be used by the network)
train_data = datasets.MNIST(
    root=data_folder,
    train=True,
    download=False,
    transform=transform_pipeline,
)
test_data = datasets.MNIST(
    root=data_folder,
    train=False,
    download=False,
    transform=transform_pipeline,
)

# Image height/width taken from the stored data array; their product is the
# size of one flattened image.
_, img_rows, img_cols = train_data.data.numpy().shape
network_input_dim = img_rows * img_cols

In [25]:
# Mini-batch size handed to the loader helper below.
# NOTE(review): nothing visible in this notebook ties the ResNet layers to a
# particular batch size -- confirm whether the restriction stated in the inline
# comment actually holds.
BATCH_SIZE = 64 # cannot be changed unless the architecture of resnet is changed

# Project helper; presumably wraps both datasets in loaders that yield batches
# of BATCH_SIZE -- verify against MNIST_solver.
train_loader, test_loader = get_train_and_test_data_w_batch_size(BATCH_SIZE, train_data, test_data)

In [4]:
# This code was adapted from suggestions made by ChatGPT (version 3.5) when searching for transfer-learning examples.
# Freezing the internal parameters appeared to give poor performance (not verified), so the internal layers are left trainable.


class ResNetForMNIST(nn.Module):
    """Adapt a ResNet-style backbone to single-channel image classification.

    The backbone's first convolution (``conv1``) is replaced by a newly
    initialised convolution that accepts 1-channel input, and the forward pass
    returns log-probabilities (``log_softmax`` over the class dimension).

    Args:
        resnet: backbone module exposing ``conv1`` and ``fc`` attributes.
            It is modified in place (its ``conv1`` is replaced).
        internal_params_frozen: when True, gradients are disabled for every
            backbone parameter except those of ``resnet.fc``.
    """

    def __init__(self, resnet, internal_params_frozen = False):
        super().__init__()

        self.resnet = resnet
        self.internal_params_frozen = internal_params_frozen
        # Replace the stem so the network accepts 1-channel images.
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Kept as an attribute for backward compatibility; note that this is
        # log-softmax, so the output is log-probabilities.
        self.softmax = F.log_softmax

        if self.internal_params_frozen:
            self.freeze_all_layers_but_lastone()

    def freeze_all_layers_but_lastone(self):
        """Disable gradients for all backbone parameters except the ``fc`` head.

        NOTE(review): the replacement ``conv1`` created in ``__init__`` is
        frozen as well, i.e. it stays at its random initialisation. That may be
        why freezing was observed to perform poorly -- confirm whether the stem
        should remain trainable too.
        """
        for param in self.resnet.parameters():
            param.requires_grad = False
        # requires_grad has to be set on the parameters themselves; assigning it
        # on the Module object has no effect on training.
        for param in self.resnet.fc.parameters():
            param.requires_grad = True

    def forward(self, x):
        """Return per-class log-probabilities for the batch ``x``.

        Args:
            x: input batch passed straight to the backbone.

        Returns:
            Tensor of log-probabilities, normalised along dimension 1.
        """
        logits = self.resnet(x)
        log_probabilities = self.softmax(logits, dim=1)
        return log_probabilities


def create_model():
    """Build a ResNet-18 based classifier for the 10 MNIST digit classes.

    The pre-trained backbone gets a new 10-output linear head, is wrapped in
    ``ResNetForMNIST`` (default arguments, so no layers are frozen) and is
    moved to the device reported by ``get_HW_acceleration_if_available``.

    Returns:
        The wrapped model, already moved to that device.
    """
    n_digit_classes = 10

    # Pre-trained ResNet-18 whose classification head is swapped for one that
    # emits a raw score (logit) per digit class.
    backbone = models.resnet18(weights='ResNet18_Weights.DEFAULT')
    backbone.fc = nn.Linear(backbone.fc.in_features, n_digit_classes)

    mnist_net = ResNetForMNIST(backbone)
    # Project helper; presumably returns an accelerator device when one is
    # available -- verify against MNIST_solver.
    mnist_net.to(get_HW_acceleration_if_available())
    return mnist_net

In [8]:
def objective_function(trial):
    """Optuna objective: train a fresh model and report its best validation accuracy.

    A learning rate is sampled log-uniformly from [1e-5, 1e-2], a new model is
    trained with Adam until the stopping rule ends training, and the maximum
    of the recorded validation accuracies is returned. The accuracy and the
    returned model's ``state_dict`` are stored as user attributes on the trial.

    NOTE(review): ``test_loader`` is passed where the validation loader
    presumably goes, so the tuned accuracy is measured on the test split --
    confirm this is intended.

    Args:
        trial: the Optuna trial supplying the hyper-parameter suggestion.

    Returns:
        The maximum validation accuracy observed during training.
    """
    learning_rate = trial.suggest_float("lr", 1e-5, 1e-2, log=True)

    # Every trial starts from a newly created model.
    net = create_model()
    adam = optim.Adam(net.parameters(), lr=learning_rate)

    # Stopping rule driven by the improver's `is_improving` callable.
    stopper = TrainingStopper(PerformanceImprover().is_improving)

    net, (_, _), (val_loss, val_acc) = train_network_classification(
        net, train_loader, test_loader, adam, stopper
    )

    top_val_accuracy = np.max(val_acc)

    # Record the result (and the weights of the returned model) on the trial.
    previously_recorded = trial.user_attrs.get("best_accuracy", -1.0)
    if top_val_accuracy > previously_recorded:
        trial.set_user_attr("best_accuracy", top_val_accuracy)
        trial.set_user_attr("best_state_dict", net.state_dict())

    return top_val_accuracy

In [9]:
# The objective returns a validation accuracy, so the study maximises it.
study = optuna.create_study(direction='maximize')  # Use 'minimize' instead when the objective returns a quantity to be reduced

[I 2023-10-17 12:34:55,281] A new study created in memory with name: no-name-785bf80a-8259-4772-bede-df77bca55a6b


In [10]:
# Run 10 trials; every trial builds and trains a new model inside
# objective_function. Going by the log timestamps below, the recorded run took
# roughly two and a half hours in total.
study.optimize(objective_function, n_trials=10)

EPOCH: 1
Training Accuracy: 94.2333; Validation Accuracy: 97.5500
EPOCH: 2
Training Accuracy: 98.1300; Validation Accuracy: 98.5900
EPOCH: 3
Training Accuracy: 98.7000; Validation Accuracy: 98.7900
EPOCH: 4
Training Accuracy: 98.8300; Validation Accuracy: 98.9200
EPOCH: 5
Training Accuracy: 99.0650; Validation Accuracy: 98.9000
EPOCH: 6
Training Accuracy: 99.1667; Validation Accuracy: 99.2700
EPOCH: 7
Training Accuracy: 99.3083; Validation Accuracy: 99.2200
EPOCH: 8
Training Accuracy: 99.2833; Validation Accuracy: 98.8000
EPOCH: 9
Training Accuracy: 99.4400; Validation Accuracy: 99.1600
EPOCH: 10
Training Accuracy: 99.4567; Validation Accuracy: 99.0700
EPOCH: 11
Training Accuracy: 99.5517; Validation Accuracy: 98.9900
EPOCH: 12
Training Accuracy: 99.5433; Validation Accuracy: 99.0600
EPOCH: 13


[I 2023-10-17 12:43:29,545] Trial 0 finished with value: 99.3499984741211 and parameters: {'lr': 0.0003501125788634786}. Best is trial 0 with value: 99.3499984741211.


Training Accuracy: 99.6200; Validation Accuracy: 99.3500
EPOCH: 14
EPOCH: 1
Training Accuracy: 87.1950; Validation Accuracy: 96.5400
EPOCH: 2
Training Accuracy: 96.7950; Validation Accuracy: 97.7300
EPOCH: 3
Training Accuracy: 98.0250; Validation Accuracy: 97.9700
EPOCH: 4
Training Accuracy: 98.5533; Validation Accuracy: 98.2900
EPOCH: 5
Training Accuracy: 98.9400; Validation Accuracy: 98.1800
EPOCH: 6
Training Accuracy: 99.1483; Validation Accuracy: 98.4000
EPOCH: 7
Training Accuracy: 99.2900; Validation Accuracy: 98.6400
EPOCH: 8
Training Accuracy: 99.3483; Validation Accuracy: 98.5800
EPOCH: 9
Training Accuracy: 99.5067; Validation Accuracy: 98.3500
EPOCH: 10
Training Accuracy: 99.4950; Validation Accuracy: 98.7800
EPOCH: 11
Training Accuracy: 99.5967; Validation Accuracy: 98.8200
EPOCH: 12
Training Accuracy: 99.6233; Validation Accuracy: 98.7000
EPOCH: 13
Training Accuracy: 99.6300; Validation Accuracy: 98.8000
EPOCH: 14
Training Accuracy: 99.7033; Validation Accuracy: 98.8200
EPOC

[I 2023-10-17 13:02:17,790] Trial 1 finished with value: 99.04999542236328 and parameters: {'lr': 4.549220768103177e-05}. Best is trial 0 with value: 99.3499984741211.


Training Accuracy: 99.9100; Validation Accuracy: 98.8400
EPOCH: 30
EPOCH: 1
Training Accuracy: 89.5483; Validation Accuracy: 97.8900
EPOCH: 2
Training Accuracy: 96.2367; Validation Accuracy: 91.6800
EPOCH: 3
Training Accuracy: 96.8867; Validation Accuracy: 98.2900
EPOCH: 4
Training Accuracy: 97.2367; Validation Accuracy: 98.4100
EPOCH: 5
Training Accuracy: 98.0000; Validation Accuracy: 98.4000
EPOCH: 6
Training Accuracy: 98.3500; Validation Accuracy: 98.6600
EPOCH: 7
Training Accuracy: 98.6467; Validation Accuracy: 98.3000
EPOCH: 8
Training Accuracy: 98.7467; Validation Accuracy: 98.8900
EPOCH: 9
Training Accuracy: 98.0817; Validation Accuracy: 98.3300
EPOCH: 10
Training Accuracy: 98.5150; Validation Accuracy: 98.9600
EPOCH: 11
Training Accuracy: 98.8083; Validation Accuracy: 98.7400
EPOCH: 12
Training Accuracy: 98.9183; Validation Accuracy: 98.7900
EPOCH: 13
Training Accuracy: 99.0983; Validation Accuracy: 98.9400
EPOCH: 14
Training Accuracy: 99.1117; Validation Accuracy: 98.8800
EPOC

[I 2023-10-17 13:25:05,324] Trial 2 finished with value: 99.22000122070312 and parameters: {'lr': 0.009725659259442599}. Best is trial 0 with value: 99.3499984741211.


Training Accuracy: 99.7433; Validation Accuracy: 98.8600
EPOCH: 36
EPOCH: 1
Training Accuracy: 78.7233; Validation Accuracy: 94.1100
EPOCH: 2
Training Accuracy: 94.3100; Validation Accuracy: 96.4000
EPOCH: 3
Training Accuracy: 96.5200; Validation Accuracy: 97.2700
EPOCH: 4
Training Accuracy: 97.6950; Validation Accuracy: 97.6200
EPOCH: 5
Training Accuracy: 98.2017; Validation Accuracy: 97.8700
EPOCH: 6
Training Accuracy: 98.7167; Validation Accuracy: 98.0800
EPOCH: 7
Training Accuracy: 99.1250; Validation Accuracy: 97.9200
EPOCH: 8
Training Accuracy: 99.3000; Validation Accuracy: 97.9600
EPOCH: 9
Training Accuracy: 99.3967; Validation Accuracy: 97.9800
EPOCH: 10
Training Accuracy: 99.5650; Validation Accuracy: 98.3300
EPOCH: 11
Training Accuracy: 99.5800; Validation Accuracy: 98.1100
EPOCH: 12
Training Accuracy: 99.6367; Validation Accuracy: 98.2800
EPOCH: 13
Training Accuracy: 99.6950; Validation Accuracy: 98.3100
EPOCH: 14
Training Accuracy: 99.7683; Validation Accuracy: 98.2100
EPOC

[I 2023-10-17 13:41:45,851] Trial 3 finished with value: 98.47999572753906 and parameters: {'lr': 1.8990360289473992e-05}. Best is trial 0 with value: 99.3499984741211.


Training Accuracy: 99.8933; Validation Accuracy: 98.4100
EPOCH: 27
EPOCH: 1
Training Accuracy: 92.5667; Validation Accuracy: 97.4900
EPOCH: 2
Training Accuracy: 96.9283; Validation Accuracy: 98.4300
EPOCH: 3
Training Accuracy: 97.7117; Validation Accuracy: 98.3600
EPOCH: 4
Training Accuracy: 98.4683; Validation Accuracy: 98.9400
EPOCH: 5
Training Accuracy: 98.6867; Validation Accuracy: 97.4500
EPOCH: 6
Training Accuracy: 98.6433; Validation Accuracy: 97.8400
EPOCH: 7
Training Accuracy: 98.7733; Validation Accuracy: 98.9500
EPOCH: 8
Training Accuracy: 99.0267; Validation Accuracy: 99.0100
EPOCH: 9
Training Accuracy: 99.1283; Validation Accuracy: 98.4700
EPOCH: 10
Training Accuracy: 99.1733; Validation Accuracy: 99.1300
EPOCH: 11
Training Accuracy: 99.0583; Validation Accuracy: 99.0900
EPOCH: 12
Training Accuracy: 99.3933; Validation Accuracy: 99.2500
EPOCH: 13
Training Accuracy: 99.4400; Validation Accuracy: 99.2600
EPOCH: 14
Training Accuracy: 99.3000; Validation Accuracy: 99.2500
EPOC

[I 2023-10-17 13:55:14,960] Trial 4 finished with value: 99.33000183105469 and parameters: {'lr': 0.004716395958581933}. Best is trial 0 with value: 99.3499984741211.


Training Accuracy: 99.6817; Validation Accuracy: 99.2100
EPOCH: 22
EPOCH: 1
Training Accuracy: 94.4350; Validation Accuracy: 98.3600
EPOCH: 2
Training Accuracy: 97.9833; Validation Accuracy: 96.9300
EPOCH: 3
Training Accuracy: 98.6250; Validation Accuracy: 98.6200
EPOCH: 4
Training Accuracy: 98.6100; Validation Accuracy: 97.5200
EPOCH: 5
Training Accuracy: 98.7667; Validation Accuracy: 98.3300
EPOCH: 6
Training Accuracy: 99.0367; Validation Accuracy: 99.0700
EPOCH: 7
Training Accuracy: 99.2150; Validation Accuracy: 99.0100
EPOCH: 8
Training Accuracy: 99.2450; Validation Accuracy: 99.1100
EPOCH: 9
Training Accuracy: 99.4217; Validation Accuracy: 99.1300
EPOCH: 10
Training Accuracy: 99.3233; Validation Accuracy: 97.9900
EPOCH: 11
Training Accuracy: 99.4250; Validation Accuracy: 99.3600
EPOCH: 12
Training Accuracy: 99.5933; Validation Accuracy: 99.2500
EPOCH: 13
Training Accuracy: 99.5450; Validation Accuracy: 99.2500
EPOCH: 14
Training Accuracy: 99.5433; Validation Accuracy: 98.9500
EPOC

[I 2023-10-17 14:09:39,608] Trial 5 finished with value: 99.40999603271484 and parameters: {'lr': 0.0007653110733369356}. Best is trial 5 with value: 99.40999603271484.


Training Accuracy: 99.7950; Validation Accuracy: 99.3700
EPOCH: 23
EPOCH: 1
Training Accuracy: 92.9833; Validation Accuracy: 97.8900
EPOCH: 2
Training Accuracy: 97.9100; Validation Accuracy: 98.3600
EPOCH: 3
Training Accuracy: 98.4700; Validation Accuracy: 98.8900
EPOCH: 4
Training Accuracy: 98.8350; Validation Accuracy: 98.5400
EPOCH: 5
Training Accuracy: 98.9850; Validation Accuracy: 99.1200
EPOCH: 6
Training Accuracy: 99.1150; Validation Accuracy: 98.9900
EPOCH: 7
Training Accuracy: 99.2667; Validation Accuracy: 98.9100
EPOCH: 8
Training Accuracy: 99.3583; Validation Accuracy: 99.2000
EPOCH: 9
Training Accuracy: 99.4117; Validation Accuracy: 98.0900
EPOCH: 10
Training Accuracy: 99.4967; Validation Accuracy: 99.1100
EPOCH: 11
Training Accuracy: 99.5550; Validation Accuracy: 99.0900
EPOCH: 12
Training Accuracy: 99.5783; Validation Accuracy: 98.9000
EPOCH: 13
Training Accuracy: 99.5800; Validation Accuracy: 99.0200
EPOCH: 14
Training Accuracy: 99.7050; Validation Accuracy: 99.0300
EPOC

[I 2023-10-17 14:25:02,530] Trial 6 finished with value: 99.33000183105469 and parameters: {'lr': 0.0001587458543858116}. Best is trial 5 with value: 99.40999603271484.


Training Accuracy: 99.8167; Validation Accuracy: 99.1200
EPOCH: 25
EPOCH: 1
Training Accuracy: 94.9417; Validation Accuracy: 98.4300
EPOCH: 2
Training Accuracy: 97.9150; Validation Accuracy: 98.2700
EPOCH: 3
Training Accuracy: 98.3367; Validation Accuracy: 98.8100
EPOCH: 4
Training Accuracy: 98.8417; Validation Accuracy: 98.9500
EPOCH: 5
Training Accuracy: 98.8867; Validation Accuracy: 98.6300
EPOCH: 6
Training Accuracy: 98.7950; Validation Accuracy: 98.8000
EPOCH: 7
Training Accuracy: 99.1917; Validation Accuracy: 99.1000
EPOCH: 8
Training Accuracy: 99.3083; Validation Accuracy: 98.8800
EPOCH: 9
Training Accuracy: 99.4267; Validation Accuracy: 99.2000
EPOCH: 10
Training Accuracy: 99.4900; Validation Accuracy: 98.9300
EPOCH: 11
Training Accuracy: 99.3033; Validation Accuracy: 99.0900
EPOCH: 12
Training Accuracy: 99.5700; Validation Accuracy: 99.2100
EPOCH: 13
Training Accuracy: 99.5583; Validation Accuracy: 99.2000
EPOCH: 14
Training Accuracy: 99.6767; Validation Accuracy: 99.3100
EPOC

[I 2023-10-17 14:41:41,150] Trial 7 finished with value: 99.44999694824219 and parameters: {'lr': 0.0015069559237179816}. Best is trial 7 with value: 99.44999694824219.


Training Accuracy: 99.7517; Validation Accuracy: 99.4500
EPOCH: 27
EPOCH: 1
Training Accuracy: 93.9167; Validation Accuracy: 98.2000
EPOCH: 2
Training Accuracy: 98.0333; Validation Accuracy: 97.9800
EPOCH: 3
Training Accuracy: 98.5667; Validation Accuracy: 98.5200
EPOCH: 4
Training Accuracy: 98.7450; Validation Accuracy: 97.5700
EPOCH: 5
Training Accuracy: 99.0450; Validation Accuracy: 98.9400
EPOCH: 6
Training Accuracy: 99.1500; Validation Accuracy: 99.0000
EPOCH: 7
Training Accuracy: 99.2467; Validation Accuracy: 98.5500
EPOCH: 8
Training Accuracy: 99.3583; Validation Accuracy: 98.8200
EPOCH: 9
Training Accuracy: 99.3917; Validation Accuracy: 98.8300
EPOCH: 10
Training Accuracy: 99.4783; Validation Accuracy: 99.3300
EPOCH: 11
Training Accuracy: 99.5717; Validation Accuracy: 99.0500
EPOCH: 12
Training Accuracy: 99.6050; Validation Accuracy: 99.2700
EPOCH: 13
Training Accuracy: 99.6183; Validation Accuracy: 99.1600
EPOCH: 14
Training Accuracy: 99.6467; Validation Accuracy: 99.1400
EPOC

[I 2023-10-17 14:59:08,884] Trial 8 finished with value: 99.47999572753906 and parameters: {'lr': 0.00024460936466759193}. Best is trial 8 with value: 99.47999572753906.


Training Accuracy: 99.8717; Validation Accuracy: 98.5200
EPOCH: 28
EPOCH: 1
Training Accuracy: 93.1783; Validation Accuracy: 97.5600
EPOCH: 2
Training Accuracy: 97.9850; Validation Accuracy: 98.2100
EPOCH: 3
Training Accuracy: 98.6033; Validation Accuracy: 98.3900
EPOCH: 4
Training Accuracy: 98.8700; Validation Accuracy: 98.7700
EPOCH: 5
Training Accuracy: 99.0933; Validation Accuracy: 98.7800
EPOCH: 6
Training Accuracy: 99.2367; Validation Accuracy: 98.8000
EPOCH: 7
Training Accuracy: 99.2800; Validation Accuracy: 99.1300
EPOCH: 8
Training Accuracy: 99.3533; Validation Accuracy: 99.1000
EPOCH: 9
Training Accuracy: 99.4700; Validation Accuracy: 98.8500
EPOCH: 10
Training Accuracy: 99.5033; Validation Accuracy: 98.9700
EPOCH: 11
Training Accuracy: 99.5750; Validation Accuracy: 99.0200
EPOCH: 12
Training Accuracy: 99.6583; Validation Accuracy: 98.9700
EPOCH: 13
Training Accuracy: 99.6633; Validation Accuracy: 99.0200
EPOCH: 14
Training Accuracy: 99.6583; Validation Accuracy: 99.2500
EPOC

[I 2023-10-17 15:12:07,714] Trial 9 finished with value: 99.25 and parameters: {'lr': 0.00016589967072105623}. Best is trial 8 with value: 99.47999572753906.


Training Accuracy: 99.8383; Validation Accuracy: 99.2100
EPOCH: 21


In [12]:
# Fresh model instance with the same architecture as the ones trained in the
# trials; it receives the best trial's weights further down.
model = create_model()

In [14]:
# Trial with the highest objective value found by the study.
best_trial = study.best_trial
best_trial.value # expected: about 99.48 (best validation accuracy, in percent)

99.47999572753906

In [15]:
# Hyper-parameters of the best trial (only the learning rate "lr" is tuned).
best_trial.params

{'lr': 0.00024460936466759193}

In [16]:
# Copy the weights saved by the best trial into the fresh model.
# NOTE(review): .get() yields None when the attribute is missing, in which case
# load_state_dict would fail -- the attribute is set by objective_function.
model.load_state_dict(best_trial.user_attrs.get("best_state_dict"))

<All keys matched successfully>

In [18]:
# Sanity check of the restored weights on the test loader. The recorded output
# is a pair whose second entry is the count of correct predictions; the first
# entry is presumably the accumulated loss -- verify against MNIST_solver.
eval_net_on_data(test_loader, model) # expected 9948 correct predictions, i.e. 99.48%

(3.1649037593723506, tensor(9948, device='mps:0'))

In [19]:
# Persist the tuned model to disk.
# NOTE(review): this pickles the whole model object, so loading the file later
# needs the ResNetForMNIST class to be importable; saving model.state_dict() is
# the more portable alternative -- confirm what downstream code expects.
torch.save(model, 'resnet_MNIST.pt')