In [None]:
# Colab-only setup: mount Google Drive at /content/drive so that the
# following cells can reach the files stored there.
from google.colab import drive
drive.mount("/content/drive")

In [None]:
cd "/content/drive/MyDrive/AML_Assignment_2_2021"

In [None]:
ls

In [None]:
"""
Created on Thu Nov 11 17:09:22 2021

@author: marco
"""

import itertools
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

from ex3_convnet_utils import get_dataset_loaders, weights_init, PrintModelSize, VisualizeFilter
from ex3_convnet_utils import ConvNet
from ex3_convnet_utils import complete_training_and_validation
from ex3_convnet_utils import test_model

# Results CSV of the grid search and the directories that receive the plots.
gridsearch_data_default_path = "ex3_convnet_gridsearch_results_complete.csv"
train_valid_plot_root = "./train_valid_history_plots/"
valid_accuracy_history_root = "./valid_accuracy_history_plots/"
pre_training_filters_root = "./pre_training_filters/"
post_training_filters_root = "./post_training_filters/"

# Create the output directories up front: writing a figure into a directory
# that does not exist raises FileNotFoundError, and that would otherwise only
# surface after a complete training run has already been paid for.
for output_root in (train_valid_plot_root,
                    valid_accuracy_history_root,
                    pre_training_filters_root,
                    post_training_filters_root):
    os.makedirs(output_root, exist_ok=True)

# Start from an empty table; it is replaced by the results of earlier runs
# when a results file is found.
gridsearch_data = pd.DataFrame()

try:
    gridsearch_data = pd.read_csv(gridsearch_data_default_path, sep = ';', index_col = None)
    print("Number of lines in csv: ", len(gridsearch_data))

except (FileNotFoundError, pd.errors.EmptyDataError) as E:
    # Only a missing or empty file means "nothing has been computed yet".
    # Any other failure (e.g. a parse error) is allowed to propagate, so that
    # a damaged results file is not mistaken for a first run and overwritten
    # the next time the results are saved to the same path.
    print(E)
    print("No worries! Seems like this is the first run, or maybe the file is actually not there.")


In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Every completed hyperparameter combination adds exactly one row to the
# results table, so the number of rows already loaded is the index of the
# first combination that still has to be trained.
index_last_checkpoint = len(gridsearch_data.index)

# Column layout of one results row; the order must match the values assembled
# for `new_row` at the end of each grid-search iteration.
# NOTE: the last column is named 'early_stopped_best_accuracy' but receives the
# early-stopped model's *test* accuracy. The name is kept unchanged so that CSV
# files written by earlier runs stay compatible.
cols = ['num_epochs', 'batch_size', 'hidden_size', 'learning_rate', 'lr_decay', 'reg', 'norm_layer',
        'history_loss_train', 'history_loss_validation', 'best_model_valid_accuracy', 'early_stopped_valid_accuracy', 'is_early_stopped', 'train_valid_loss_hist_plot_path',
        'valid_accuracy_plot_path', 'path_pre_train_filter_plots', 'path_post_train_filter_plots', 'best_model_test_accuracy', 'early_stopped_best_accuracy']


def _run_tag(num_epochs, batch_size, learning_rate, learning_rate_decay, reg):
    """Return the '<epochs>_<batch>_<lr>_<decay>_<reg>' suffix shared by all file names of one run."""
    return "{}_{}_{}_{}_{}".format(num_epochs, batch_size, learning_rate, learning_rate_decay, reg)


def _run_description(num_epochs, batch_size, learning_rate, learning_rate_decay, reg):
    """Return the human-readable hyperparameter summary used in the plot titles."""
    return "Epochs: {} | Batch size: {} \n Learning Rate: {} | LR Decay: {} | Reg: {}".format(
        num_epochs, batch_size, learning_rate, learning_rate_decay, reg)


def _save_history_plot(title, curves, y_label, output_path):
    """Plot per-epoch curves on a fresh 7x5 figure, save it to `output_path` and close it.

    `curves` is an iterable of (values, matplotlib format string, legend label).
    The figure is closed explicitly so figures do not pile up across the grid search.
    """
    fig, ax = plt.subplots(figsize = (7, 5))
    fig.suptitle(title)
    for values, line_format, label in curves:
        ax.plot(values, line_format, label = label)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(y_label)
    ax.legend()
    fig.savefig(output_path)
    plt.close(fig)


# ----- DATA LOADING -----

# Augmentation transforms handed to get_dataset_loaders; un-comment to enable.
data_aug_transforms = []
#data_aug_transforms.append(transforms.ColorJitter(brightness=.5, hue=.05, saturation=.05))
#data_aug_transforms.append(transforms.RandomPerspective(distortion_scale=0.6, p=1.0))
#data_aug_transforms.append(transforms.RandomHorizontalFlip(p=0.3))
#data_aug_transforms.append(transforms.RandomRotation(20))
#data_aug_transforms.append(transforms.RandomInvert(p=0.2))

# ----- END OF DATA LOADING -----

# Settings that stay fixed for every run of the grid search.
input_size = 3
num_classes = 10
hidden_size = [128, 512, 512, 512, 512]
num_training = 49000
num_validation = 1000
norm_layer = 'BN'

# Values explored by the grid search (every combination is trained once).
num_epochs_gs = [5, 10, 30, 50]
batch_size_gs = [200, 400]
learning_rate_gs = [1e-3, 2e-3, 1e-2]
learning_rate_decay_gs = [0.99, 0.95, 0.9]
reg_gs = [0.001, 0.005]

print("Starting training now...")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device: %s'%device)

# itertools.product varies its last argument fastest, i.e. it visits the
# combinations in the same order as nested loops with num_epochs outermost and
# reg innermost. That order is what ties a row index of the results table to a
# grid position, so it must not change between runs.
hyperparameter_grid = itertools.product(num_epochs_gs, batch_size_gs, learning_rate_gs,
                                        learning_rate_decay_gs, reg_gs)

for current_iteration, (num_epochs, batch_size, learning_rate,
                        learning_rate_decay, reg) in enumerate(hyperparameter_grid):

    # Resume support: combinations that already have a row in the table are skipped.
    if current_iteration < index_last_checkpoint:
        print("Skipping iteration, currently at: ", current_iteration)
        continue

    run_tag = _run_tag(num_epochs, batch_size, learning_rate, learning_rate_decay, reg)
    run_description = _run_description(num_epochs, batch_size, learning_rate, learning_rate_decay, reg)

    train_loader, val_loader, test_loader = get_dataset_loaders(data_aug_transforms,
                                                                batch_size,
                                                                num_training,
                                                                num_validation)

    model = ConvNet(input_size, hidden_size, num_classes, norm_layer=norm_layer).to(device)
    model.apply(weights_init)

    # Model size and filters before training (model size does not change)
    PrintModelSize(model)

    full_path_pre_train_filter_plots = pre_training_filters_root + "pre_training_filters_" + run_tag + ".png"
    VisualizeFilter(model, full_path_pre_train_filter_plots, save_to_disk = True, prefix = "Pre-training | ")

    # NOTE(review): neither object is passed to complete_training_and_validation
    # below, so they look unused here -- confirm whether they can be dropped.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=reg)

    results = complete_training_and_validation(model = model,
                                               num_epochs = num_epochs,
                                               train_loader = train_loader,
                                               val_loader = val_loader,
                                               device = device,
                                               learning_rate = learning_rate,
                                               learning_rate_decay = learning_rate_decay,
                                               reg = reg,
                                               batch_size = batch_size)

    # The helper's result is read positionally; the names reflect how each
    # entry is used further down.
    best_model = results[0]
    early_stopped_model = results[1]
    loss_train = results[2]
    loss_val = results[3]
    best_model_accuracy = results[4]
    early_stopped_accuracy = results[5]
    is_early_stopped = results[6]
    accuracy_val = results[7]

    # Test accuracies stay NaN when the corresponding model was not returned.
    best_model_test_accuracy = np.nan
    early_stopped_test_accuracy = np.nan

    print("Early stopped flag: ", is_early_stopped)

    if best_model is not None:
        best_model_test_accuracy = test_model(best_model, test_loader, device)
        print("Best model test accuracy: ", best_model_test_accuracy)

    if early_stopped_model is not None:
        early_stopped_test_accuracy = test_model(early_stopped_model, test_loader, device)
        print("Early stopped test accuracy: ", early_stopped_test_accuracy)

    # NOTE(review): the filters are drawn from `model`, not from `best_model`
    # or `early_stopped_model` -- confirm that this is the intended network.
    full_path_post_train_filter_plots = post_training_filters_root + "post_training_filters_" + run_tag + ".png"
    VisualizeFilter(model, full_path_post_train_filter_plots, save_to_disk = True, prefix = "Post-training | ")

    full_path_train_valid_loss_plots = train_valid_plot_root + "train_val_loss_hist_" + run_tag + ".png"
    _save_history_plot(title = "Training-Validation loss history > " + run_description,
                       curves = [(loss_train, 'r', 'Train loss'),
                                 (loss_val, 'g', 'Val loss')],
                       y_label = "Loss",
                       output_path = full_path_train_valid_loss_plots)

    full_path_valid_accuracy_plots = valid_accuracy_history_root + "valid_accuracy_hist_" + run_tag + ".png"
    _save_history_plot(title = "Accuracy history > " + run_description,
                       curves = [(accuracy_val, 'r', 'Validation accuracy')],
                       y_label = "Accuracy",
                       output_path = full_path_valid_accuracy_plots)

    new_row = pd.DataFrame(data = [[num_epochs, batch_size, hidden_size, learning_rate, learning_rate_decay, reg, norm_layer,
                                    loss_train, loss_val, best_model_accuracy, early_stopped_accuracy, is_early_stopped,
                                    full_path_train_valid_loss_plots, full_path_valid_accuracy_plots,
                                    full_path_pre_train_filter_plots, full_path_post_train_filter_plots,
                                    best_model_test_accuracy, early_stopped_test_accuracy]], columns = cols)

    # DataFrame.append was removed in pandas 2.0; pd.concat is the replacement.
    # An empty table is replaced outright so that no empty frame is passed to concat.
    if gridsearch_data.empty:
        gridsearch_data = new_row
    else:
        gridsearch_data = pd.concat([gridsearch_data, new_row], ignore_index = True)

    # Checkpoint after every combination so an interrupted session can resume.
    gridsearch_data.to_csv(gridsearch_data_default_path, sep = ';', na_rep = 'nan', index = False)
    print("\n")