In [None]:
import torch
device = torch.device("cuda:0")
import torch.nn.functional as F
import wandb
import copy
from torchvision.models import resnet18, resnet34, resnet50, efficientnet_b1
from torch.utils import data
from torchvision import datasets
from torchvision import transforms
import torch.nn as nn
from torch.optim import Adam, SGD, RMSprop
import numpy as np
from large_margin import LargeMarginLoss
import time
import argparse
from data import build_split_datasets
import pickle as pkl
from random_split_generator import FourWayClassSplit
from torch.utils.data import ConcatDataset
import datetime
from experiment import FeatureExtractor
from sklearn.metrics import roc_auc_score, confusion_matrix
import pandas as pd
from large_margin import _get_grad

def test_distance(logits, features, top_k, device):
    """Estimate per-sample distances to the decision boundary without labels.

    The predicted (arg-max) class is treated as the "correct" class. For each
    of the ``top_k`` most probable *other* classes the softmax probability
    gap is divided by the L2 norm of its gradient with respect to the last
    feature map, giving a first-order estimate of the distance to the
    boundary between the two classes.

    Args:
        logits: Model outputs; softmax is taken over ``dim=1``.
        features: Sequence of feature maps produced by the same forward pass
            as ``logits``. Only the last entry is used.
        top_k: Number of runner-up classes to compare against.
        device: Unused; kept so existing callers keep working. All
            intermediate tensors are created on the device of ``logits``.

    Returns:
        Tensor with one row per sample and one column per runner-up class.
        It is still attached to the autograd graph through the probability
        gap; callers should detach it before storing it.

    Raises:
        ValueError: If ``features`` is empty.
    """
    if len(features) == 0:
        raise ValueError("test_distance needs at least one feature map")

    eps = 1e-8
    prob = F.softmax(logits, dim=1)

    # No labels at test time: the predicted class plays the role of the
    # ground-truth class.
    pseudo_correct_prob, max_indices = torch.max(prob, dim=1, keepdim=True)

    # Zero out the predicted class so top-k only sees the remaining classes.
    # scatter() is out-of-place, so ``prob`` itself is left untouched.
    pseudo_other_prob = prob.scatter(1, max_indices, 0.)

    # Grabs the next most likely class probabilities
    topk_prob, _ = pseudo_other_prob.topk(top_k, dim=1)

    pseudo_diff_prob = pseudo_correct_prob - topk_prob

    # Only the last feature map takes part in the distance computation.
    feature_map = features[-1]
    # NOTE(review): the norm over dim=2 assumes _get_grad returns one
    # flattened gradient row per sample -- confirm against large_margin.
    diff_grad = torch.stack(
        [_get_grad(pseudo_diff_prob[:, k], feature_map) for k in range(top_k)],
        dim=1,
    )
    # The gradient norm is a constant scaling factor, not part of the graph.
    diff_gradnorm = torch.norm(diff_grad, p=2, dim=2).detach()

    return pseudo_diff_prob / (diff_gradnorm + eps)


def _lm_ls_score_batches(model, loader, top_k, device, count_correct=False):
    """Run ``model`` over ``loader`` and collect per-batch anomaly scores.

    Args:
        model: Network whose forward pass returns ``(output, features)`` and
            which exposes ``clear_features()``.
        loader: Iterable of ``(inputs, targets)`` batches.
        top_k: Forwarded to ``test_distance``.
        device: Device the inputs (and targets, when counted) are moved to.
        count_correct: When true, also count arg-max predictions that equal
            the batch targets.

    Returns:
        Tuple ``(margin_batches, max_logit_batches, correct)`` where the
        first two are lists of detached 1-D CPU tensors (one per batch) and
        ``correct`` is the number of correct predictions (0 if not counted).
    """
    margin_batches = []
    max_logit_batches = []
    correct = 0

    for inputs, targets in loader:
        inputs = inputs.to(device)
        # presumably drops feature maps captured by the previous forward
        # pass -- verify against the model implementation
        model.clear_features()
        output, features = model(inputs)
        # Kept from the original evaluation loop; _get_grad may rely on it.
        for feature in features:
            feature.retain_grad()

        distance = torch.abs(test_distance(output, features, top_k, device))
        max_logit, predicted = torch.max(output, dim=1)

        if count_correct:
            correct += (predicted == targets.to(device)).sum().item()

        # Higher score == more anomalous: a small distance to the decision
        # boundary, or a small maximum logit.
        margin_scores, _ = torch.max(-1 * distance, dim=1)

        # Detaching releases the autograd graph held by each batch.
        margin_batches.append(margin_scores.detach().cpu())
        max_logit_batches.append((-1 * max_logit).detach().cpu())

    return margin_batches, max_logit_batches, correct


def test_lm_ls(model, top_k, id_loader, ood_loader, device):
    """Evaluate a margin-trained model with latent-space (LS) detection.

    Two anomaly scores are computed for every in-distribution (ID) and
    out-of-distribution (OOD) sample: the negated smallest absolute boundary
    distance from ``test_distance`` and the negated maximum logit. ID samples
    are labelled 0 ("nominal") and OOD samples 1 ("anomaly") for AUROC.

    Gradients must be enabled when this is called, because
    ``test_distance`` differentiates through the forward pass.

    Args:
        model: Network returning ``(output, features)``; put in eval mode.
        top_k: Number of runner-up classes used by ``test_distance``.
        id_loader: Loader of in-distribution ``(data, target)`` batches; its
            ``dataset`` length is the accuracy denominator.
        ood_loader: Loader of out-of-distribution batches (targets ignored).
        device: Device on which the model is evaluated.

    Returns:
        Tuple ``(accuracy, margin_AUROC, max_logit_AUROC)`` with accuracy in
        percent over the ID dataset.
    """
    model.eval()

    id_margin, id_max_logit, correct = _lm_ls_score_batches(
        model, id_loader, top_k, device, count_correct=True)
    ood_margin, ood_max_logit, _ = _lm_ls_score_batches(
        model, ood_loader, top_k, device)

    # Scores are concatenated ID-first, so the labels follow the same order.
    margin_anom_scores = torch.hstack(id_margin + ood_margin).numpy()
    max_logit_anom_scores = torch.hstack(id_max_logit + ood_max_logit).numpy()

    num_id = sum(len(batch) for batch in id_margin)
    num_ood = sum(len(batch) for batch in ood_margin)
    # 0 indicates "nominal", 1 indicates "anomaly"
    anom_labels = np.concatenate([np.zeros(num_id), np.ones(num_ood)])

    margin_AUROC    = roc_auc_score(anom_labels, margin_anom_scores)
    max_logit_AUROC = roc_auc_score(anom_labels, max_logit_anom_scores)

    accuracy = 100. * correct / len(id_loader.dataset)
    print('Test set: Accuracy: {}/{} ({:.0f}%)'.format(
        correct, len(id_loader.dataset), accuracy))
    print('Test Set: margin AUROC: {}\n'.format(margin_AUROC))
    print('Test Set: max logit AUROC: {}\n'.format(max_logit_AUROC))

    return accuracy, margin_AUROC, max_logit_AUROC


# TODO: Verify Done
# This works as a test function for our baseline
def test_ce_ls(model, id_loader, ood_loader, device):
    """Evaluate the cross-entropy baseline with max-logit anomaly scoring.

    Each sample's anomaly score is its negated maximum logit. Samples from
    ``id_loader`` are labelled 0 ("nominal"), samples from ``ood_loader``
    are labelled 1 ("anomaly"), and the AUROC is computed over both.

    Args:
        model: Network whose forward pass returns ``(output, features)``.
        id_loader: In-distribution loader; its ``dataset`` length is the
            accuracy denominator.
        ood_loader: Out-of-distribution loader (targets are not compared).
        device: Device the batches are moved to.

    Returns:
        Tuple ``(accuracy, AUROC)`` with accuracy in percent.
    """
    model.eval()
    num_correct = 0
    labels = []
    score_batches = []

    with torch.no_grad():
        for inputs, targets in id_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits, _ = model(inputs)

            # One max() call yields both the score and the predicted class.
            top_logit, top_class = torch.max(logits, dim=1)
            num_correct += (top_class == targets).sum().item()

            score_batches.append(-1 * top_logit)
            # 0 indicates "nominal"
            labels.extend([0.] * len(targets))

        for inputs, targets in ood_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits, _ = model(inputs)

            top_logit, _ = torch.max(logits, dim=1)
            score_batches.append(-1 * top_logit)
            # 1 indicates "anomaly"
            labels.extend([1.] * len(targets))

    anom_scores = torch.hstack(score_batches).cpu().numpy()
    anom_labels = np.asarray(labels)

    AUROC = roc_auc_score(anom_labels, anom_scores)

    total = len(id_loader.dataset)
    accuracy = 100. * num_correct / total
    print('Test set: Accuracy: {}/{} ({:.0f}%)'.format(
        num_correct, total, accuracy))
    print('Test Set: AUROC: {}\n'.format(AUROC))

    return accuracy, AUROC


# TODO: Verify Done
def test_ce_ks(model, id_loader, ood_loader, device):
    """Evaluate a model that has an explicit anomaly output (KS detection).

    The anomaly score of a sample is the model output in column
    ``anomaly_index`` (10), so the model must produce at least 11 outputs.
    Samples from ``id_loader`` are labelled 0 ("nominal"), samples from
    ``ood_loader`` are labelled 1 ("anomaly"), and the AUROC is computed
    over both. Accuracy is measured on the ID samples only. No confusion
    matrix is computed or logged here.

    Args:
        model: Network whose forward pass returns ``(output, features)``.
        id_loader: In-distribution loader; its ``dataset`` length is the
            accuracy denominator.
        ood_loader: Out-of-distribution loader (targets are ignored).
        device: Device the batches are moved to.

    Returns:
        Tuple ``(accuracy, AUROC)`` with accuracy in percent.
    """
    model.eval()
    anomaly_index = 10
    correct = 0
    anom_labels = []
    anom_score_sequence = []

    with torch.no_grad():
        for id_data, id_target in id_loader:
            id_data, id_target = id_data.to(device), id_target.to(device)
            id_output, _ = model(id_data)

            # Compute number of correctly classified id instances
            correct += (id_output.argmax(dim=1) == id_target).sum().item()

            anom_score_sequence.append(id_output[:, anomaly_index])
            # 0 indicates "nominal"
            anom_labels.extend([0.] * len(id_target))

        for ood_data, ood_target in ood_loader:
            ood_output, _ = model(ood_data.to(device))

            anom_score_sequence.append(ood_output[:, anomaly_index])
            # 1 indicates "anomaly"
            anom_labels.extend([1.] * len(ood_target))

    anom_scores = torch.hstack(anom_score_sequence).cpu().numpy()
    anom_labels = np.asarray(anom_labels)

    AUROC = roc_auc_score(anom_labels, anom_scores)

    accuracy = 100. * correct / len(id_loader.dataset)
    print('Test set: Accuracy: {}/{} ({:.0f}%)'.format(
        correct, len(id_loader.dataset), accuracy))
    print('Test Set: AUROC: {}\n'.format(AUROC))

    return accuracy, AUROC


# TODO: Verify Correct
def test_lm_ks(model, id_loader, ood_loader, device):
    """Evaluate a margin-trained model that has an explicit anomaly output.

    At test time the KS evaluation does not depend on the training loss:
    accuracy is measured on the ID samples and the AUROC uses the anomaly
    output column as the score. That is exactly what ``test_ce_ks`` does, so
    this function delegates to it instead of keeping a second copy of the
    same loop.

    Args:
        model: Network whose forward pass returns ``(output, features)``.
        id_loader: In-distribution loader.
        ood_loader: Out-of-distribution loader.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(accuracy, AUROC)`` with accuracy in percent.
    """
    return test_ce_ks(model, id_loader, ood_loader, device)


To start testing, fill out the experiment form in the cell below.

In [None]:
# Root directory containing the saved checkpoints; the checkpoint file name
# is derived from the form values below.
BASE_PATH = '/nfs/stak/users/noelt/Documents/Project/noelt_masters_project/models/' 

###########################################################################################
# BEGIN EXPERIMENT FORM
###########################################################################################

#########################
# Experiment Parameters #
#########################

# Batch size used by every DataLoader built from this form.
batch_size=256

# Detection type: "LS" or "KS". "KS" switches the model to 11 output classes.
detection_type = "LS"

# Training loss of the checkpoint to evaluate: "CE" or "margin"
loss = "CE"

# Class split index, 0 through 4
split = 0

# Epoch of the checkpoint to load (part of the checkpoint file name).
best_epoch = 60

# Passed to build_split_datasets; also selects the "OE" vs "new class" wording
# in the final summary message.
oe_test = True

##########################
# !!! VERY IMPORTANT !!! #
##########################
# If true, trumps loss and detection_type
baseline = True

# "CE" or "margin"; only relevant if baseline is True
baseline_type = "margin"

###########################################################################################
# END EXPERIMENT FORM
###########################################################################################

# Resolve the configuration that is actually evaluated. A baseline run
# overrides both `loss` and `detection_type`: baselines are always scored
# with the LS test functions (test_ce_ls / test_lm_ls).
if baseline:
    eval_loss = "CE" if baseline_type == "CE" else "margin"
    eval_detection = "LS"
else:
    eval_loss, eval_detection = loss, detection_type

# Fail early on a typo in the form instead of hitting a NameError after the
# datasets and the model have already been loaded.
if eval_detection not in ("LS", "KS") or eval_loss not in ("CE", "margin"):
    raise ValueError(
        "Unsupported configuration: detection_type={!r}, loss={!r}".format(
            eval_detection, eval_loss))

# KS models carry one extra output for the explicit anomaly class.
# NOTE(review): baselines are assumed to be 10-class checkpoints because they
# are only ever evaluated with the LS test functions -- confirm.
num_classes = 11 if eval_detection == "KS" else 10

if baseline:
    prefix = "baseline" if baseline_type == "CE" else "baseline_margin"
else:
    prefix = "{}_{}".format(loss, detection_type)

directory = 'val_{}_new_class_test/epoch_{}_split_{}.pth'.format(prefix, best_epoch, split)
PATH = BASE_PATH + directory

# Right now, hardcoded to use CIFAR-100, with hardcoded splits
# defined and generated in random_split_generator.py
id_data, ood_data = build_split_datasets(split, oe_test)

id_train_data, id_val_data, id_test_data    = id_data
ood_train_data, ood_val_data, ood_test_data = ood_data

# Constructing Dataloaders
id_train_loader = data.DataLoader(id_train_data, batch_size=batch_size, shuffle=True, drop_last=False)
id_val_loader   = data.DataLoader(id_val_data, batch_size=batch_size, shuffle=True, drop_last=False)
id_test_loader   = data.DataLoader(id_test_data, batch_size=batch_size, shuffle=True, drop_last=False)
ood_train_loader   = data.DataLoader(ood_train_data, batch_size=batch_size, shuffle=True, drop_last=False)
ood_val_loader   = data.DataLoader(ood_val_data, batch_size=batch_size, shuffle=True, drop_last=False)
ood_test_loader   = data.DataLoader(ood_test_data, batch_size=batch_size, shuffle=True, drop_last=False)

model = FeatureExtractor("efficientnet_b1", num_classes)
# map_location lets a checkpoint saved on another device load onto `device`.
model.load_state_dict(torch.load(PATH, map_location=device)['model_state_dict'])
model.eval()
model = model.to(device)

top_k=1

# Exactly one evaluation runs per execution. Only the margin + LS evaluation
# reports two AUROC values.
reports_margin_auc = eval_detection == "LS" and eval_loss == "margin"

if reports_margin_auc:
    acc, margin_auc, max_logit_auc = test_lm_ls(model, top_k, id_test_loader, ood_test_loader, device)
elif eval_detection == "LS":
    acc, auc = test_ce_ls(model, id_test_loader, ood_test_loader, device)
elif eval_loss == "margin":
    acc, auc = test_lm_ks(model, id_test_loader, ood_test_loader, device)
else:
    acc, auc = test_ce_ks(model, id_test_loader, ood_test_loader, device)

print("Accuracy: {}".format(acc))

if reports_margin_auc:
    print("Margin AUROC: {}".format(margin_auc))
    print("Max Logit AUROC: {}".format(max_logit_auc))
else:
    print("AUROC: {}".format(auc))

if baseline:
    print("{}baseline experiment on Split {}, from epoch {}, {} test, Done".format("" if baseline_type == "CE" else "margin ", split, best_epoch, "OE" if oe_test else "new class"))
else:
    print("{} + {} Experiment on Split {}, from epoch {}, {} test, Done".format(detection_type, loss, split, best_epoch, "OE" if oe_test else "new class"))