In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import models
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, accuracy_score
from skimage import transform
from torch.utils import data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch
from tqdm import tqdm
from colorize.util import reshape

In [None]:
def comparison_model_metrics(model, val_loader, val_name, save_dir='data'):
    """Compute AUC and accuracy metrics for a comparison model on a validation dataset.

    The model is evaluated in eval mode under ``torch.no_grad()``; its original
    train/eval mode is restored before returning (also if an error is raised).

    Parameters
    ----------
    model : torch.nn.Module
        Classifier whose forward pass returns per-class scores (logits); softmax is
        taken over the last dimension. Assumed shape (batch, n_classes) — TODO confirm
        for any new model passed in.
    val_loader : iterable of (X, y)
        Yields input tensors and integer class-label tensors, e.g. a
        ``torch.utils.data.DataLoader``. Inputs are moved to the model's device.
    val_name : str
        Prefix for the saved prediction files.
    save_dir : str, optional
        Existing directory the prediction arrays are written to (default ``'data'``,
        the previously hard-coded location).

    Returns
    -------
    dict
        ``'AUC macro'``: one-vs-one ROC AUC averaged with every class weighted equally;
        ``'AUC weighted'``: one-vs-one ROC AUC weighted by class prevalence;
        ``'Accuracy'``: top-1 accuracy of the argmax predictions.

    Raises
    ------
    ValueError
        If ``val_loader`` yields no batches (and, from scikit-learn, if the labels do
        not cover the model's classes as ``roc_auc_score`` requires).

    Side effects
    ------------
    Writes ``{save_dir}/{val_name}_Y_pred.npy``, ``{save_dir}/{val_name}_y_true.npy``
    and ``{save_dir}/{val_name}_y_pred_labels.npy``.
    """
    # Remember the caller's mode so evaluation does not silently leave the model changed.
    was_training = model.training
    # Eval mode is required for correct inference: in train mode BatchNorm uses per-batch
    # statistics and Dropout is active, and torchvision's Inception V3 returns an
    # InceptionOutputs tuple (with aux logits) instead of a logits tensor.
    model.eval()

    # Run inputs on whatever device the model's parameters live on (CPU if it has none).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    Y_pred = []
    y_true = []
    try:
        # No autograd graph is needed for evaluation; building one only wastes memory.
        with torch.no_grad():
            for X, y in tqdm(val_loader):
                logits = model(X.to(device))
                Y_pred.append(logits.softmax(dim=-1).cpu().numpy())
                y_true.append(y.numpy())
    finally:
        model.train(was_training)

    if not Y_pred:
        raise ValueError(f'val_loader for {val_name!r} yielded no batches')

    Y_pred = np.vstack(Y_pred)
    y_true = np.concatenate(y_true)
    y_pred_labels = np.argmax(Y_pred, axis=1)

    np.save(f'{save_dir}/{val_name}_Y_pred', Y_pred)
    np.save(f'{save_dir}/{val_name}_y_true', y_true)
    np.save(f'{save_dir}/{val_name}_y_pred_labels', y_pred_labels)

    # 'ovo' scores every pair of classes, so its cost grows quadratically with the
    # number of classes — expect this to be slow for 1000 ImageNet classes.
    AUC_macro = roc_auc_score(y_true, Y_pred, multi_class="ovo", average="macro")
    AUC_weighted = roc_auc_score(y_true, Y_pred, multi_class="ovo", average="weighted")
    acc = accuracy_score(y_true, y_pred_labels)

    return {'AUC macro': AUC_macro, 'AUC weighted': AUC_weighted, 'Accuracy': acc}

In [None]:
# Load Inception V3 pretrained on ImageNet, switched to eval mode because it is only used
# for inference (in train mode its forward returns an InceptionOutputs tuple, not logits).
# NOTE(review): torchvision documents Inception V3 as expecting 299x299 inputs, while
# `val_loader` center-crops to 224 — use a 299 crop before evaluating this model.
inceptionV3 = models.inception_v3(pretrained=True).eval()

In [None]:
# Load VGG16 with batch norm, pretrained on ImageNet, in eval mode so that BatchNorm uses
# its running statistics and Dropout is disabled while computing validation metrics.
VGG16 = models.vgg16_bn(pretrained=True).eval()

In [None]:
# Folder for the ImageNet validation set.
# NOTE(review): ImageFolder expects one sub-folder per class and assigns label indices in
# sorted folder-name order; that order must match the pretrained models' output classes.
# If the validation images sit in a single flat folder they must be reorganized first — confirm.
val_folder = 'data/ILSVRC2012_img_val/'

# Dataloader constants
BS = 40
N_WORKERS = 1

# Per-channel normalization used by torchvision's ImageNet-pretrained models
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

val_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(val_folder, transforms.Compose([
        # Resize is the replacement for the deprecated (and since removed) transforms.Scale:
        # an int argument scales the shorter image side to 256.
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])),
    # shuffle=False keeps predictions aligned with the dataset order for the saved arrays.
    batch_size=BS, shuffle=False,
    num_workers=N_WORKERS, pin_memory=True)

In [None]:
# Evaluate VGG16-BN on the ImageNet validation loader; prediction arrays are saved with the 'imagenet' prefix.
model_metrics = comparison_model_metrics(
    model=VGG16,
    val_loader=val_loader,
    val_name='imagenet',
)

In [None]:
# Show the metrics dict ('AUC macro', 'AUC weighted', 'Accuracy') as this cell's output.
model_metrics