In [None]:
from google.colab import drive

# Mount Google Drive; the project's helper library (added to sys.path in the next cell) lives there.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
import os
import torch
torch.cuda.empty_cache()

import torchvision
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.backends.cudnn as cudnn

from torch import nn
from torchvision import datasets
from torchvision.transforms import ToTensor
from timeit import default_timer as timer 

%matplotlib inline
from matplotlib import pyplot as plt

import sys
sys.path.append('/content/drive/MyDrive/NN_Course_Project/project/lib')

from data import loader
from train_test import TrainTestModel
from helper_funcs import MetricsComputation, select_model, save_model
from m_vgg16 import VGG16

In [None]:
# Pick the compute target: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

In [None]:
# Loss functions and shared optimisation hyperparameters.
loss_fn = nn.CrossEntropyLoss()  # baseline classification loss
# NOTE(review): not passed to TrainTestModel.train in this notebook's visible cells;
# kept in case a ranking-loss variant elsewhere uses it.
ranking_criterion = nn.MarginRankingLoss(margin=0.0)
learning_rate = 0.1

# Seed PyTorch's RNGs (CPU and all CUDA devices) so weight initialisation and
# data shuffling are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

In [None]:
# Train VGG-16 on CIFAR-10, evaluate it on the test split, report the
# confidence/failure-prediction metrics, and save the trained weights.
DATASET_NAME = 'CIFAR10'
NUM_CLASSES = 10
BATCH_SIZE = 128
EPOCHS = 300
# NOTE(review): the usual 300-epoch CIFAR schedule decays at [150, 250]; the extra
# milestone at 10 drops the LR by 10x very early. Value preserved — confirm it is intended.
LR_MILESTONES = [10, 150, 250]
# NOTE(review): relative path — on Colab this lands in the ephemeral working directory
# (not on the mounted Drive) unless save_model prefixes a directory itself; confirm.
MODEL_PATH = 'vgg16_cifar10.pth'

# Move the model to the target device *before* constructing the optimizer, as the
# torch.optim docs recommend, so the optimizer is built over the on-device parameters.
M_VGG16 = VGG16(num_classes=NUM_CLASSES).to(device)

# Use the configured `learning_rate` instead of repeating the literal 0.1.
optimizer = optim.SGD(M_VGG16.parameters(), lr=learning_rate, momentum=0.9,
                      weight_decay=0.0001, nesterov=False)
# Multiplies the LR by gamma once scheduler.step() has been called a milestone number
# of times (presumably once per epoch inside TrainTestModel.train — confirm).
scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=LR_MILESTONES, gamma=0.1)

# Train / test dataloaders for the chosen dataset.
dataloader_train, dataloader_test = loader(DATASET_NAME, batch_size=BATCH_SIZE)

start_time = timer()

VGG_C10 = TrainTestModel()
results = VGG_C10.train(model=M_VGG16,
                        train_dataloader=dataloader_train,
                        loss_fn=loss_fn,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        epochs=EPOCHS,
                        device=device)

end_time = timer()
print(f"[INFO] Total training time: {end_time-start_time:.3f} seconds")

print("\n========================================================================================\n")

test_loss, test_acc, binary_labels, confidence_scores = VGG_C10.test(model=M_VGG16,
                                                                     dataloader=dataloader_test,
                                                                     loss_fn=loss_fn,
                                                                     device=device)

# Metrics are computed from the per-sample confidence scores and binary labels returned by test().
VGG_C10_Metrics = MetricsComputation(confidence_scores, binary_labels)
aurc, eaurc, pr_auc, fpr_in_tpr_95 = VGG_C10_Metrics.compute_metrics()

print(f"Area Under Risk Curve (AURC): {aurc}")
print(f"Excessive-AURC (E-AURC): {eaurc}")
print(f"Area Under Precision-Recall Curve (AUPR): {pr_auc}")
print(f"False Positive Rate (FPR) at 95% True Positive Rate (TPR): {fpr_in_tpr_95}\n")

# Save model
save_model(MODEL_PATH, M_VGG16)