In [1]:
import torch
device = torch.device("cuda:0")
import torch.nn.functional as F
import wandb
import copy
from torchvision.models import resnet18, resnet34, resnet50, efficientnet_b1
from torch.utils import data
from torchvision import datasets
from torchvision import transforms
import torch.nn as nn
from torch.optim import Adam, SGD, RMSprop
import numpy as np
from large_margin import LargeMarginLoss
import time
import argparse
from data import build_split_datasets
import pickle as pkl
from random_split_generator import FourWayClassSplit
from torch.utils.data import ConcatDataset
import datetime
from experiment import FeatureExtractor
from sklearn.metrics import roc_auc_score
import pandas as pd
from large_margin import _get_grad

def test_distance(logits, features, device, top_k=None):
    """Estimate first-order distances from each sample to its nearest decision boundaries.

    The predicted ("pseudo-correct") class is the arg-max of the softmax of
    ``logits``.  For each of the ``top_k`` runner-up classes, the probability
    gap to the predicted class is divided by the L2 norm of the gradient of
    that gap with respect to the last entry of ``features``.

    Args:
        logits: 2-D tensor of raw class scores; softmax is taken over dim 1.
        features: non-empty sequence of intermediate activations that are part
            of the autograd graph producing ``logits``.  Only the last entry
            is used.
        device: kept for backward compatibility with existing callers.  All
            intermediate tensors are created on the device of ``logits``.
        top_k: number of runner-up classes to measure against.  When omitted,
            the module-level ``top_k`` variable is used if it exists (this is
            what the function always did before), otherwise 1.

    Returns:
        Tensor holding, per sample and per runner-up class, the probability gap
        divided by the gradient norm.  Expected shape is ``(batch, top_k)``,
        assuming ``_get_grad`` returns gradients flattened to ``(batch, D)``
        -- TODO confirm against ``large_margin._get_grad``.

    Raises:
        ValueError: if ``top_k`` is smaller than 1 or ``features`` is empty.
    """
    if top_k is None:
        # Preserve the historical behaviour of reading the notebook-level
        # ``top_k`` while allowing callers to pass it explicitly instead.
        top_k = globals().get("top_k", 1)
    if top_k < 1:
        raise ValueError("top_k must be >= 1, got {}".format(top_k))
    if len(features) == 0:
        raise ValueError("features must contain at least one feature map")

    eps = 1e-8  # guards against division by a zero gradient norm

    prob = F.softmax(logits, dim=1)
    max_indices = torch.argmax(prob, dim=1)
    pseudo_correct_prob, _ = torch.max(prob, dim=1, keepdim=True)

    # Copy of the probabilities with the predicted class zeroed out, so that
    # max/top-k below only ever selects runner-up classes.  ``clone`` keeps the
    # copy differentiable and on the same device as ``prob``.
    pseudo_other_prob = prob.clone()
    row_indices = torch.arange(prob.shape[0], device=prob.device)
    pseudo_other_prob[row_indices, max_indices] = 0.

    # Grabs the next most likely class probabilities
    if top_k > 1:
        topk_prob, _ = pseudo_other_prob.topk(top_k, dim=1)
    else:
        topk_prob, _ = pseudo_other_prob.max(dim=1, keepdim=True)

    # Broadcasts (batch, 1) against (batch, top_k).
    pseudo_diff_prob = pseudo_correct_prob - topk_prob

    # Only the last feature map is used for the distance estimate.
    feature_map = features[-1]
    diff_grad = torch.stack(
        [_get_grad(pseudo_diff_prob[:, k], feature_map) for k in range(top_k)],
        dim=1,
    )
    # The norm acts as a constant scaling factor; keep it out of the graph.
    diff_gradnorm = torch.norm(diff_grad, p=2, dim=2).detach()

    return pseudo_diff_prob / (diff_gradnorm + eps)


# TODO: Check that ID margins are larger than OOD margins; this is an empirical sanity check
def sanity_test_lm_ls(model, top_k, id_loader, ood_loader, device):
    """Evaluate ID accuracy and margin-based OOD AUROC on paired ID/OOD batches.

    ID and OOD batches are consumed in lock-step with ``zip``, so iteration
    stops as soon as the shorter loader is exhausted.  For every batch the
    margin distances from ``test_distance`` are computed and printed; the
    anomaly score of a sample is the negative of its smallest absolute
    distance, so samples closer to a decision boundary score higher.

    Args:
        model: network returning ``(logits, features)`` and exposing
            ``clear_features()``.
        top_k: accepted for interface compatibility.  ``test_distance`` is
            called without it here, so the value is not used by this function.
        id_loader: iterable of ``(data, target)`` in-distribution batches.
        ood_loader: iterable of ``(data, target)`` out-of-distribution batches;
            the targets are only used for the batch size.
        device: device that inputs and targets are moved to.

    Returns:
        ``(accuracy, AUROC, id_distance, ood_distance)`` where ``accuracy`` is
        a percentage over the ID samples actually evaluated and the two
        distance tensors belong to the last processed batch pair.

    Raises:
        ValueError: if no batch pair could be drawn from the loaders.
    """
    model.eval()
    correct = 0
    id_seen = 0  # ID samples actually evaluated (zip may truncate id_loader)
    num_classes = 10
    anom_labels = []
    anom_score_sequence = []
    id_distance = ood_distance = None

    for (id_data, id_target), (ood_data, ood_target) in zip(id_loader, ood_loader):
        # Debug targets: one-hot for ID samples, uniform for OOD samples.
        # They are only printed, not fed into any computation below.
        id_one_hot = torch.zeros(len(id_target), num_classes).scatter_(1, id_target.unsqueeze(1), 1.).float()
        ood_one_hot = (1 / num_classes) * torch.ones((len(ood_target), num_classes))
        id_one_hot, ood_one_hot = id_one_hot.to(device), ood_one_hot.to(device)

        print("id_one_hot: {}".format(id_one_hot))
        print("ood_one_hot: {}".format(ood_one_hot))

        id_data, id_target = id_data.to(device), id_target.to(device)
        ood_data = ood_data.to(device)

        ###########################
        # ID Distance Computation #
        ###########################
        model.clear_features()
        id_output, id_features = model(id_data)
        for id_feature in id_features:
            id_feature.retain_grad()
        print("entering ID margin comp")
        print("raw id logits: {}".format(id_output[0,:]))

        raw_id_distance = test_distance(id_output, id_features, device)
        id_distance = torch.abs(raw_id_distance)
        print("raw id distance: {}".format(id_distance))

        ############################
        # OOD Distance Computation #
        ############################
        model.clear_features()
        ood_output, ood_features = model(ood_data)
        for ood_feature in ood_features:
            ood_feature.retain_grad()
        print("entering OOD margin comp")
        print("raw ood logits: {}".format(ood_output[0,:]))

        raw_ood_distance = test_distance(ood_output, ood_features, device)
        ood_distance = torch.abs(raw_ood_distance)
        print("raw ood distance: {}".format(ood_distance))

        # Compute number of correctly classified id instances
        _, id_idx = id_output.max(dim=1)
        correct += (id_idx == id_target).sum().item()
        id_seen += len(id_target)

        # Anomaly scores (order matters: OOD first, then ID, matching labels).
        # Detaching drops the autograd graph so it is not kept alive for
        # every batch until the end of the evaluation.
        ood_scores, _ = torch.max(-1 * ood_distance, dim=1)
        anom_score_sequence.append(ood_scores.detach().cpu())
        anom_labels.extend([1.] * len(ood_target))  # 1 indicates "anomaly"

        id_scores, _ = torch.max(-1 * id_distance, dim=1)
        anom_score_sequence.append(id_scores.detach().cpu())
        anom_labels.extend([0.] * len(id_target))  # 0 indicates "nominal"

    if not anom_score_sequence:
        raise ValueError("id_loader and ood_loader must each yield at least one batch")

    anom_scores = torch.hstack(anom_score_sequence).numpy()
    anom_labels = np.asarray(anom_labels)

    AUROC = roc_auc_score(anom_labels, anom_scores)

    accuracy = 100. * correct / id_seen
    print('Test set: Accuracy: {}/{} ({:.0f}%)'.format(
        correct, id_seen, accuracy))
    print('Test Set: AUROC: {}\n'.format(AUROC))

    return accuracy, AUROC, id_distance, ood_distance


def _margin_scores_for_batch(model, batch_data, device):
    """Forward one batch and derive its margin distances and anomaly scores.

    Returns ``(logits, distance, scores)`` where ``distance`` is the absolute
    value of ``test_distance`` (still attached to the autograd graph) and
    ``scores`` is the per-sample maximum of ``-distance`` over dim 1, detached
    and moved to the CPU.
    """
    batch_data = batch_data.to(device)
    model.clear_features()
    output, features = model(batch_data)
    for feature in features:
        feature.retain_grad()

    distance = torch.abs(test_distance(output, features, device))
    scores, _ = torch.max(-1 * distance, dim=1)
    # Detaching is important here because it removes these scores from the
    # computational graph, so the graph of each batch can be freed.
    return output, distance, scores.detach().cpu()


def split_test_lm_ls(model, top_k, id_loader, ood_loader, device):
    """Evaluate ID accuracy and margin-based OOD AUROC, one loader after the other.

    All in-distribution batches are processed first, then all
    out-of-distribution batches, so the two loaders may have different
    lengths.  The anomaly score of a sample is the negative of its smallest
    absolute margin distance; ID samples are labelled 0 and OOD samples 1 for
    the AUROC computation.

    Args:
        model: network returning ``(logits, features)`` and exposing
            ``clear_features()``.
        top_k: accepted for interface compatibility.  ``test_distance`` is
            called without it here, so the value is not used by this function.
        id_loader: iterable of ``(data, target)`` in-distribution batches.
        ood_loader: iterable of ``(data, target)`` out-of-distribution batches;
            the targets are only used for the batch size.
        device: device that inputs and targets are moved to.

    Returns:
        ``(accuracy, AUROC, id_distance, ood_distance)`` where ``accuracy`` is
        a percentage over the ID samples evaluated and the two distance
        tensors belong to the last ID and the last OOD batch respectively.

    Raises:
        ValueError: if either loader yields no batches.
    """
    model.eval()
    correct = 0
    id_seen = 0
    anom_labels = []
    anom_score_sequence = []
    id_distance = None
    ood_distance = None

    for batch_idx, (id_data, id_target) in enumerate(id_loader):
        print("id batch {} starting".format(batch_idx))
        id_target = id_target.to(device)
        id_output, id_distance, id_scores = _margin_scores_for_batch(model, id_data, device)

        # Compute number of correctly classified id instances
        _, id_idx = id_output.max(dim=1)
        correct += (id_idx == id_target).sum().item()
        id_seen += len(id_target)

        anom_score_sequence.append(id_scores)
        anom_labels.extend([0.] * len(id_target))  # 0 indicates "nominal"

    for batch_idx, (ood_data, ood_target) in enumerate(ood_loader):
        print("ood batch {} starting".format(batch_idx))
        _, ood_distance, ood_scores = _margin_scores_for_batch(model, ood_data, device)

        anom_score_sequence.append(ood_scores)
        anom_labels.extend([1.] * len(ood_target))  # 1 indicates "anomaly"

    if id_distance is None or ood_distance is None:
        raise ValueError("id_loader and ood_loader must each yield at least one batch")

    anom_scores = torch.hstack(anom_score_sequence).numpy()
    anom_labels = np.asarray(anom_labels)

    AUROC = roc_auc_score(anom_labels, anom_scores)

    accuracy = 100. * correct / id_seen
    print('Test set: Accuracy: {}/{} ({:.0f}%)'.format(
        correct, id_seen, accuracy))
    print('Test Set: AUROC: {}\n'.format(AUROC))

    return accuracy, AUROC, id_distance, ood_distance


# Checkpoint to evaluate. NOTE: this is an absolute, machine-specific path;
# adjust it when running the notebook elsewhere.
PATH = '/nfs/stak/users/noelt/Documents/Project/noelt_masters_project/models/val_baseline_margin_new_class_test/epoch_259_split_0.pth'
batch_size=256

# Right now, hardcoded to use CIFAR-100, with hardcoded splits
# defined and generated in random_split_generator.py
# NOTE(review): the arguments (0, False) are presumably the split index and a
# flag -- confirm against data.build_split_datasets.
id_data, ood_data = build_split_datasets(0, False)

id_train_data, id_val_data, id_test_data    = id_data
ood_train_data, ood_val_data, ood_test_data = ood_data

# Constructing Dataloaders
# Training loaders keep shuffling. Validation/test loaders are not shuffled:
# accuracy and AUROC do not depend on sample order, and a fixed order makes
# the per-batch distances returned by the evaluation reproducible.
id_train_loader  = data.DataLoader(id_train_data, batch_size=batch_size, shuffle=True, drop_last=False)
id_val_loader    = data.DataLoader(id_val_data, batch_size=batch_size, shuffle=False, drop_last=False)
id_test_loader   = data.DataLoader(id_test_data, batch_size=batch_size, shuffle=False, drop_last=False)
ood_train_loader = data.DataLoader(ood_train_data, batch_size=batch_size, shuffle=True, drop_last=False)
ood_val_loader   = data.DataLoader(ood_val_data, batch_size=batch_size, shuffle=False, drop_last=False)
ood_test_loader  = data.DataLoader(ood_test_data, batch_size=batch_size, shuffle=False, drop_last=False)

model = FeatureExtractor("efficientnet_b1", 10)
# map_location places the checkpoint tensors on the evaluation device no
# matter which device they were saved from.
checkpoint = torch.load(PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
model = model.to(device)

# Number of runner-up classes used for the margin distance. It must be
# defined before the evaluation call: test_distance is invoked without an
# explicit top_k and relies on this module-level value.
top_k=1

# lm = LargeMarginLoss(
#     gamma=19600,
#     alpha_factor=7,
#     top_k=top_k,
#     dist_norm=2
# )

acc, auc, id_dist, ood_dist = split_test_lm_ls(model, top_k, id_val_loader, ood_val_loader, device)

print("Accuracy: {}".format(acc))
print("AUROC: {}".format(auc))

Files already downloaded and verified
Files already downloaded and verified
id batch 0 starting
id batch 1 starting
id batch 2 starting
id batch 3 starting
ood batch 0 starting
ood batch 1 starting
ood batch 2 starting
ood batch 3 starting
ood batch 4 starting
ood batch 5 starting
ood batch 6 starting
ood batch 7 starting
ood batch 8 starting
ood batch 9 starting
ood batch 10 starting
ood batch 11 starting
ood batch 12 starting
ood batch 13 starting
ood batch 14 starting
ood batch 15 starting
ood batch 16 starting
ood batch 17 starting
ood batch 18 starting
ood batch 19 starting
ood batch 20 starting
ood batch 21 starting
ood batch 22 starting
ood batch 23 starting
Test set: Accuracy: 649/1000 (65%)
Test Set: AUROC: 0.5688960000000001

Accuracy: 64.9
AUROC: 0.5688960000000001
