In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from models import VAE  # Import your VAE architecture
from dataloader import get_dataloader_OOD, get_dataloader_vae
from utils import extract_feature_maps 
import numpy as np
from sklearn.metrics
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.linalg import inv
from scipy.spatial.distance import mahalanobis
import pandas as pd


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Define the VAE model
model = VAE()
model.to(device)


VAE(
  (encoder): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): ReLU()
    (9): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (10): ReLU()
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (12): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU()
    (14): Flatten(start_dim=1, end_dim=-1)
    (15): Linear(in_features=512, out_features=256, bias=True)
    (16): ReLU()
    (17): Linear(in_features=256, out_features=128, bias=True)
    (18): ReLU()
    (19): Linear(in_features=12

# GRAM Metrices

## CIFAR  10 In distribution

In [3]:


def calculate_gram_matrix(feature_maps):
    gram_matrix = torch.zeros((len(feature_maps), len(feature_maps)))
    for i in range(len(feature_maps)):
        for j in range(i, len(feature_maps)):  
            gram_matrix[i, j] = torch.dot(torch.from_numpy(feature_maps[i].flatten()), torch.from_numpy(feature_maps[j].flatten()))
            gram_matrix[j, i] = gram_matrix[i, j]  
    return gram_matrix

def train_and_extract_features(model, data_loader, device):
    model.to(device)
    model.train()
    feature_maps = []
    for data, _ in data_loader:
        data = data.to(device)
        reconstruction, mu, logvar, z = model(data)
        current_feature_map = extract_feature_maps(model, data)  # Ensure correct function call

        if isinstance(current_feature_map, list):
            # Convert all NumPy arrays to PyTorch tensors and ensure they are on the correct device
            current_feature_map = [torch.from_numpy(fmap).to(device) if isinstance(fmap, np.ndarray) else fmap for fmap in current_feature_map]

            # Filter out and process only 4D feature maps
            valid_maps = [fmap for fmap in current_feature_map if fmap.dim() == 4]

            if not valid_maps:  # Check if the list is empty
                continue  # Skip if no valid feature maps are found

            # Determine the maximum dimensions to pad to
            max_height = max(fmap.shape[2] for fmap in valid_maps)
            max_width = max(fmap.shape[3] for fmap in valid_maps)

            # Pad all feature maps to the same size
            padded_maps = [F.pad(fmap, (0, max_width - fmap.shape[3], 0, max_height - fmap.shape[2])) for fmap in valid_maps]
            current_feature_map = torch.cat(padded_maps, dim=1)

        z_expanded = z.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, current_feature_map.size(2), current_feature_map.size(3))
        extended_feature_map = torch.cat((current_feature_map, z_expanded), 1)
        feature_maps.append(extended_feature_map)

    gram_matrices = [calculate_gram_matrix(fmap) for fmap in feature_maps]
    return gram_matrices

def calculate_frobenius_distance(A, B):
    return torch.norm(A - B, p='fro').item()

def calculate_ood_scores(mean_gram_in, gram_matrix_ood):
    """Calculate OOD scores based on Frobenius distance from the ID mean Gram matrix."""
    scores = [calculate_frobenius_distance(mean_gram_in, gram) for gram in gram_matrix_ood]
    return scores

In [4]:
model = VAE()
model.load_state_dict(torch.load('../models/CIFAR10.pt'))
cifar10_dataset = get_dataloader_OOD('cifar10', batch_size=128)
data = next(iter(cifar10_dataset)) # get a batch of data
feature_maps, _, _ = extract_feature_maps(model, data[0])[:3]
gram_matrix_in = calculate_gram_matrix(feature_maps)
mean_gram_in = gram_matrix_in.mean()

id_scores_cifar10 = calculate_ood_scores(mean_gram_in, gram_matrix_in)
threshold = np.mean(id_scores_cifar10) + np.std(id_scores_cifar10)

print(mean_gram_in)

Files already downloaded and verified
tensor(2446.2739)


## OOD

In [5]:
#CIFAR-100

cifar100_dataset = get_dataloader_OOD('cifar100', batch_size=128)
data = next(iter(cifar100_dataset))  # get a batch of data
feature_maps, _, _ = extract_feature_maps(model, data[0])[:3]
gram_matrix_ood = calculate_gram_matrix(feature_maps)
g1_scores_cifar100 = calculate_ood_scores(mean_gram_in, gram_matrix_ood)
ood_results = ['OOD' if score > threshold else 'ID' for score in g1_scores_cifar100]
print(g1_scores_cifar100)
print("OOD Detection Results:", ood_results)

#SVHN

svhn_dataset = get_dataloader_OOD('svhn', batch_size=128)
data = next(iter(svhn_dataset))  
feature_maps, _, _ = extract_feature_maps(model, data[0])[:3]
gram_matrix_ood = calculate_gram_matrix(feature_maps)
g1_scores_svhn = calculate_ood_scores(mean_gram_in, gram_matrix_ood)
ood_results = ['OOD' if score > threshold else 'ID' for score in g1_scores_svhn]

print(g1_scores_svhn)
print("OOD Detection Results:", ood_results)

#LSUN
lsun_dataset = get_dataloader_OOD('lsun', batch_size=128)
data = next(iter(lsun_dataset))  
feature_maps, _, _ = extract_feature_maps(model, data[0])[:3]
gram_matrix_ood = calculate_gram_matrix(feature_maps)
g1_scores_lsun = calculate_ood_scores(mean_gram_in, gram_matrix_ood)
ood_results = ['OOD' if score > threshold else 'ID' for score in g1_scores_lsun]

print(g1_scores_lsun)
print("OOD Detection Results:", ood_results)

Files already downloaded and verified
[246294.546875, 65332.64453125, 73011.484375, 89714.6796875, 139127.375, 96324.28125, 112138.984375, 76926.9296875, 173914.9375, 48549.5234375, 222962.765625, 137855.203125, 107111.28125, 134616.0625, 348858.5, 51844.19921875, 207317.453125, 111230.6640625, 69338.1484375, 106253.3359375, 152063.3125, 157391.25, 149111.609375, 87086.5703125, 160108.21875, 81601.484375, 65857.046875, 165738.03125, 108743.7890625, 323287.5, 298840.4375, 79528.21875, 310319.96875, 294001.90625, 81033.421875, 37685.703125, 240133.28125, 69287.546875, 223347.078125, 138095.296875, 61795.79296875, 160303.5625, 185301.046875, 109881.3359375, 122273.2265625, 94725.5, 63714.9296875, 94414.4296875, 211234.015625, 110400.78125, 148991.265625, 73287.5703125, 47660.90625, 121045.9921875, 67392.71875, 87359.0234375, 74304.5, 128934.2890625, 187093.8125, 52767.92578125, 80483.5234375, 145790.265625, 82520.921875, 90101.34375, 162653.421875, 100107.1015625, 58330.2890625, 152220.39



[539602.1875, 410764.34375, 576685.5, 430533.90625, 464351.21875, 424073.15625, 438916.4375, 239169.1875, 341692.25, 363074.5, 290310.40625, 537796.0625, 571774.125, 286822.78125, 531604.875, 263272.46875, 561494.75, 365337.375, 343127.6875, 366207.8125, 235299.96875, 253613.109375, 282788.5625, 542917.8125, 456994.46875, 242005.625, 273975.96875, 250869.3125, 459289.46875, 242025.21875, 441638.5, 345717.9375, 298885.21875, 501669.0, 256629.625, 296528.3125, 343136.875, 439562.96875, 355030.4375, 360453.21875, 275888.53125, 579497.9375, 367836.9375, 345572.71875, 406688.375, 321876.90625, 332484.5, 403827.40625, 536881.0625, 418280.375, 484736.125, 371797.59375, 321210.71875, 243749.8125, 267883.59375, 450983.125, 269181.4375, 486815.9375, 238718.515625, 312708.21875, 570263.9375, 339594.21875, 444943.46875, 333610.8125, 412100.28125, 271623.75, 507924.96875, 341275.4375, 400156.46875, 236392.0625, 295566.6875, 338717.40625, 408102.25, 279772.96875, 336994.84375, 272389.5625, 238480.34

## CIFAR-100 In distribution

In [6]:
# Get the trained model
model.load_state_dict(torch.load('../models/CIFAR100.pt'))
cifar100_dataset = get_dataloader_OOD('cifar100', batch_size=128)
data = next(iter(cifar100_dataset)) # get a batch of data
feature_maps, _, _ = extract_feature_maps(model, data[0])[:3]
gram_matrix_in = calculate_gram_matrix(feature_maps)
mean_gram_in = gram_matrix_in.mean()
id_scores_cifar100 = calculate_ood_scores(mean_gram_in, gram_matrix_in)
threshold = np.mean(id_scores_cifar100) + np.std(id_scores_cifar100)

Files already downloaded and verified


## OOD

In [7]:
#CIFAR-10
cifar10_dataset = get_dataloader_OOD('cifar10', batch_size=128)
data = next(iter(cifar10_dataset))  # get a batch of data
feature_maps, _, _ = extract_feature_maps(model, data[0])[:3]
gram_matrix_ood = calculate_gram_matrix(feature_maps)
cifar10_ood_scores = calculate_ood_scores(mean_gram_in, gram_matrix_ood)

# Determine if each OOD score indicates an OOD instance
ood_results = ['OOD' if score > threshold else 'ID' for score in cifar10_ood_scores]
print("OOD Detection Results:", ood_results)

# SVHN

svhn_dataset = get_dataloader_OOD('svhn', batch_size=128)
data = next(iter(svhn_dataset))  # get a batch of data
feature_maps, _, _ = extract_feature_maps(model, data[0])[:3]
gram_matrix_ood = calculate_gram_matrix(feature_maps)
g2_svhn_scores = calculate_ood_scores(mean_gram_in, gram_matrix_ood)


# Determine if each OOD score indicates an OOD instance
ood_results = ['OOD' if score > threshold else 'ID' for score in g2_svhn_scores]
print("OOD Detection Results SVHN:", ood_results)

# LSUN

lsun_dataset = get_dataloader_OOD('lsun', batch_size=128)
data = next(iter(lsun_dataset))  
feature_maps, _, _ = extract_feature_maps(model, data[0])[:3]
gram_matrix_ood = calculate_gram_matrix(feature_maps)
g2_scores_lsun = calculate_ood_scores(mean_gram_in, gram_matrix_ood)

# Determine if each OOD score indicates an OOD instance
ood_results = ['OOD' if score > threshold else 'ID' for score in g2_scores_lsun]
print("OOD Detection Results LSUN:", ood_results)

Files already downloaded and verified
OOD Detection Results: ['ID', 'OOD', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'OOD', 'ID', 'ID', 'OOD', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'OOD', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'OOD', 'ID', 'ID', 'ID', 'ID', 'OOD', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'OOD', 'OOD', 'OOD', 'ID', 'OOD', 'OOD', 'ID', 'ID', 'OOD', 'ID', 'ID', 'ID', 'ID', 'OOD', 'ID', 'ID', 'ID', 'ID', 'ID', 'OOD', 'ID', 'ID', 'ID', 'OOD', 'ID', 'ID', 'OOD', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID', 'OOD', 'ID', 'ID', 'OOD', 'ID', 'ID', 'ID', 'ID', 'ID', 'ID']
Using downloaded and verified file: data\test_32x32.mat
OOD Detection Results SVHN: ['OOD', 'ID', 'ID', 'ID', 'ID', 'ID', 'OOD', 'OOD', 'ID', 'ID', 'ID'

# MAHANOLOBUS Distance

In [8]:
model = VAE()
model.to(device)

def compute_mahalanobis_parameters(model, data_loader):
    model.eval()
    features = []
    with torch.no_grad():
        for i, (images, _) in enumerate(data_loader):
            images = images.to(device)
            _, _, _, z = model(images)
            features.append(z.cpu().numpy())

    features = np.concatenate(features, axis=0)
    mean_vector = np.mean(features, axis=0)
    covariance_matrix = np.cov(features, rowvar=False)
    inv_covariance_matrix = inv(covariance_matrix)
    
    return mean_vector, inv_covariance_matrix

def calculate_mahalanobis(model, data_loader, mean_vector, inv_covariance_matrix):
    model.eval()
    distances = []
    with torch.no_grad():
        for i, (images, _) in enumerate(data_loader):
            images = images.to(device)
            _, _, _, z = model(images)
            z = z.cpu().numpy()
            for sample in z:
                dist = mahalanobis(sample, mean_vector, inv_covariance_matrix)
                distances.append(dist)
    return distances



In [9]:
model_state = torch.load('../models/CIFAR10.pt', map_location=device)
model.load_state_dict(model_state)

train_loader, _ = get_dataloader_vae('cifar10', train=True)
_, test_loader = get_dataloader_vae('cifar10', train=False)

train_mean, train_inv_cov = compute_mahalanobis_parameters(model, train_loader)

# Calculate Mahalanobis distance for the test data
test_distances = calculate_mahalanobis(model, test_loader, train_mean, train_inv_cov)

threshold = np.mean(test_distances) + 3 * np.std(test_distances) 


Files already downloaded and verified
Files already downloaded and verified


In [10]:
# CIFAR100 OOD
ood_loader = get_dataloader_OOD('cifar100')
ood_distances_cifar100 = calculate_mahalanobis(model, ood_loader, train_mean, train_inv_cov)
# OOD Detection
mh_ood_cifar100 = [dist > threshold for dist in ood_distances_cifar100]
print(mh_ood_cifar100)

# SVHN

ood_loader = get_dataloader_OOD('svhn')
ood_distances_svhn1 = calculate_mahalanobis(model, ood_loader, train_mean, train_inv_cov)
# OOD Detection
mh_ood_svhn_1 = [dist > threshold for dist in ood_distances_svhn1]
print(mh_ood_svhn_1)

# LSUN
ood_loader = get_dataloader_OOD('lsun')
ood_distances_lsun1 = calculate_mahalanobis(model, ood_loader, train_mean, train_inv_cov)
# OOD Detection
mh_ood_lsun_1 = [dist > threshold for dist in ood_distances_lsun1]
print(mh_ood_lsun_1)

Files already downloaded and verified
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, Fal

## CIFAR100 - IN

In [11]:
model_state = torch.load('../models/CIFAR100.pt', map_location=device)
model.load_state_dict(model_state)

train_loader, _ = get_dataloader_vae('cifar100', train=True)
_, test_loader = get_dataloader_vae('cifar100', train=False)

train_mean, train_inv_cov = compute_mahalanobis_parameters(model, train_loader)

# Calculate Mahalanobis distance for the test data
test_distances_cifar100 = calculate_mahalanobis(model, test_loader, train_mean, train_inv_cov)

threshold = np.percentile(test_distances_cifar100, 95) 

Files already downloaded and verified
Files already downloaded and verified


## OOD

In [12]:
# CIFAR10
ood_loader = get_dataloader_OOD('cifar10')
ood_distances_cifar10 = calculate_mahalanobis(model, ood_loader, train_mean, train_inv_cov)
# OOD Detection
is_ood_cifar10 = [dist > threshold for dist in ood_distances_cifar10]
print(is_ood_cifar10)

# SVHN
ood_loader = get_dataloader_OOD('svhn')
ood_distances_svhn2 = calculate_mahalanobis(model, ood_loader, train_mean, train_inv_cov)
# OOD Detection
mh_ood_svhn_2 = [dist > threshold for dist in ood_distances_svhn2]
print(mh_ood_svhn_2)

#LSUN

ood_loader = get_dataloader_OOD('lsun')
ood_distances_lsun2 = calculate_mahalanobis(model, ood_loader, train_mean, train_inv_cov)
# OOD Detection
mh_ood_lsun_2 = [dist > threshold for dist in ood_distances_lsun2]
print(mh_ood_lsun_2)

Files already downloaded and verified
[False, False, False, False, False, False, False, False, False, True, False, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, False

# Log Likelihood

In [13]:
def calculate_vae_loss(model, data_loader):
    model.eval()
    individual_losses = []

    with torch.no_grad():
        for images, _ in data_loader:
            images = images.to(device)
            reconstruction, mu, logvar, _ = model(images)

            # Reconstruction loss (MSE)
            reconstruction_loss = ((reconstruction - images) ** 2).mean(dim=[1, 2, 3])

            # KL divergence
            kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())


            # Calculate mean losses
            nll = reconstruction_loss + kl_divergence
            individual_losses.extend(nll.cpu().numpy())

    return individual_losses


## In distribution CIFAR-10

In [14]:
_, test_loader = get_dataloader_vae('cifar10', train=False)

# Calculate Losses
individual_id_nll = calculate_vae_loss(model, test_loader)
#Threshold
threshold = np.percentile(individual_id_nll, 95)
print(f"Threshold for OOD Detection: {threshold:.4f}")

Files already downloaded and verified
Threshold for OOD Detection: 3073.4840


## OOD

In [15]:
#CIFAR-100
ood_loader = get_dataloader_OOD('cifar100')
ll_ood_cifar100 = calculate_vae_loss(model, ood_loader)

threshold = np.mean(ll_ood_cifar100)
ll_score_cifar100 = [nll < threshold for nll in ll_ood_cifar100]
if True:
    print(ll_score_cifar100)

#svhn


Files already downloaded and verified
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True, True, True, True, True, True,

In [16]:
ood_loader = get_dataloader_OOD('svhn')
ll1_ood_svhn = calculate_vae_loss(model, ood_loader)

ll1_score_svhn = [nll > threshold for nll in ll1_ood_svhn]
print(ll1_score_svhn)

#lsun
ood_loader = get_dataloader_OOD('lsun')
ll1_ood_lsun = calculate_vae_loss(model, ood_loader)

ll1_score_lsun = [nll > threshold for nll in ll1_ood_lsun]
print(ll1_score_lsun)

Using downloaded and verified file: data\test_32x32.mat
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False

## IN - CIFAR-100

In [18]:
_, test_loader = get_dataloader_vae('cifar100', train=False)

# Calculate Losses
individual_id_nll_cifar100 = calculate_vae_loss(model, test_loader)

#Threshold
threshold = np.percentile(individual_id_nll_cifar100, 95)
print(f"Threshold for OOD Detection: {threshold:.4f}")


Files already downloaded and verified
Threshold for OOD Detection: 3156.5848


## OOD

In [19]:
ood_loader = get_dataloader_OOD('cifar10')
ll_ood_cifar10 = calculate_vae_loss(model, ood_loader)

ll_score_cifar10 = [nll > threshold for nll in ll_ood_cifar100]
print(ll_score_cifar10)

#svhn
ood_loader = get_dataloader_OOD('svhn')
ll2_ood_svhn = calculate_vae_loss(model, ood_loader)

ll2_score_svhn = [nll > threshold for nll in ll1_ood_svhn]
print(ll2_score_svhn)

#lsun
ood_loader = get_dataloader_OOD('lsun')
ll2_ood_lsun = calculate_vae_loss(model, ood_loader)

ll2_score_lsun = [nll > threshold for nll in ll1_ood_lsun]
print(ll2_score_lsun)

Files already downloaded and verified
[True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, Fals

# Combined Score

In [20]:

def normalize_scores(scores):
    scores = np.array(scores)
    std_dev = np.std(scores)
    if std_dev == 0:
        std_dev = 1e-10
    return (scores - np.mean(scores)) / (np.std(scores)+1e-10)

# CIFAR-10(IN Distribution) 
log_likelihoods_normalized = normalize_scores(individual_id_nll)
gram_scores_normalized = normalize_scores(id_scores_cifar10)
test_distances_normalize = normalize_scores(test_distances)

combined_scores_cifar10 = (log_likelihoods_normalized + gram_scores_normalized[:, None] + test_distances)

threshold = np.mean(combined_scores_cifar10) - 3 * np.std(combined_scores_cifar10)
print(threshold)
ood_inputs_in_cifar10 = combined_scores_cifar10 < threshold

11.271016418498466


In [21]:
#CIFAR-100(OOD)

log_likelihoods_normalized = normalize_scores(ll_ood_cifar100)
gram_scores_normalized = normalize_scores(g1_scores_cifar100)
ood_distances_cifar100_normalized = normalize_scores(ood_distances_cifar100) 

combined_scores_cifar100 = (log_likelihoods_normalized + gram_scores_normalized[:, None] + ood_distances_cifar100_normalized ) / 3

print(combined_scores_cifar10)

ood_inputs_od_cifar100 = combined_scores_cifar100 < threshold


[[14.53807246 14.98272211 14.01707764 ... -8.13844263 -8.29528523
  -8.41350118]
 [16.6555229  17.10017255 16.13452808 ... -6.02099219 -6.17783479
  -6.29605073]
 [15.56281006 16.00745971 15.04181524 ... -7.11370503 -7.27054763
  -7.38876357]
 ...
 [14.09081384 14.53546349 13.56981902 ... -8.58570124 -8.74254384
  -8.86075979]
 [14.31235337 14.75700302 13.79135855 ... -8.36416171 -8.52100431
  -8.63922026]
 [15.07465849 15.51930814 14.55366367 ... -7.60185659 -7.75869919
  -7.87691514]]


In [22]:

# CIFAR-10 (In-Distribution)
log_likelihoods_normalized_cifar10 = normalize_scores(individual_id_nll)
gram_scores_normalized_cifar10 = normalize_scores(id_scores_cifar10)
test_distances_normalize_cifar10 = normalize_scores(test_distances)

combined_scores_cifar10 = (log_likelihoods_normalized_cifar10 + gram_scores_normalized_cifar10[:, None] + test_distances_normalize_cifar10)

# SVHN (Out-of-Distribution)
log_likelihoods_normalized_svhn = normalize_scores(ll1_score_svhn)
gram_scores_normalized_svhn = normalize_scores(g1_scores_svhn)
ood_distances_svhn1_normalized = normalize_scores(ood_distances_svhn1)

combined_scores_svhn = (log_likelihoods_normalized_svhn + gram_scores_normalized_svhn[:, None] + ood_distances_svhn1_normalized) / 3

# LSUN (Out-of-Distribution)
log_likelihoods_normalized_lsun = normalize_scores(ll1_score_lsun)
gram_scores_normalized_lsun = normalize_scores(g1_scores_lsun)
ood_distances_lsun1_normalized = normalize_scores(ood_distances_lsun1)

combined_scores_lsun = (log_likelihoods_normalized_lsun + gram_scores_normalized_lsun[:, None] + ood_distances_lsun1_normalized) / 3

# Reshape the arrays to have the same number of columns
combined_scores_cifar10 = combined_scores_cifar10.reshape(-1, 1)
combined_scores_svhn = combined_scores_svhn.reshape(-1, 1)
combined_scores_lsun = combined_scores_lsun.reshape(-1, 1)

# Combine the scores from both datasets
all_scores = np.concatenate([combined_scores_cifar10, combined_scores_svhn], axis=0)

# Create binary labels
labels = np.concatenate([np.ones_like(combined_scores_cifar10), np.zeros_like(combined_scores_svhn)])

In [30]:
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve
# Combine the scores from both datasets
all_scores = np.concatenate([combined_scores_cifar10, combined_scores_svhn], axis=0)

# Create binary labels
labels = np.concatenate([np.ones_like(combined_scores_cifar10), np.zeros_like(combined_scores_svhn)])

auroc = roc_auc_score(labels, all_scores)
aupr = average_precision_score(labels, all_scores)

fpr, tpr, thresholds = roc_curve(labels, all_scores)
fpr_at_tpr95 = fpr[np.argmax(tpr >= 0.95)]

print("SVHN")
print("AUROC", auroc)
print("aupr", aupr)
print("fpr_at_tpr95", fpr_at_tpr95)

SVHN
AUROC 0.4840543000329692
aupr 0.47070233619177665
fpr_at_tpr95 1.0


In [29]:


all_scores = np.concatenate([combined_scores_cifar10, combined_scores_lsun], axis=0)
# Create binary labels
labels = np.concatenate([np.ones_like(combined_scores_cifar10), np.zeros_like(combined_scores_lsun)])
auroc = roc_auc_score(labels, all_scores)
aupr = average_precision_score(labels, all_scores)

fpr, tpr, thresholds = roc_curve(labels, all_scores)
fpr_at_tpr95 = fpr[np.argmax(tpr >= 0.95)]
print("LSUN")
print("AUROC", auroc)
print("aupr", aupr)
print("fpr_at_tpr95", fpr_at_tpr95)

LSUN
AUROC 0.481160849818115
aupr 0.6290770495422051
fpr_at_tpr95 0.9889296875


In [25]:
# CIFAR-100(IN Distribution) 
log_likelihoods_normalized = normalize_scores(individual_id_nll_cifar100)
gram_scores_normalized = normalize_scores(id_scores_cifar100)
test_distances_normalize = normalize_scores(test_distances_cifar100)

combined_scores_cifar100 = (log_likelihoods_normalized + gram_scores_normalized[:, None] + test_distances)

threshold = np.mean(combined_scores_cifar100) - 3 * np.std(combined_scores_cifar100)
print(threshold)
ood_inputs_in_cifar100 = combined_scores_cifar100 < threshold

11.281657656278838


In [32]:

# SVHN (Out-of-Distribution)
log_likelihoods_normalized_svhn = normalize_scores(ll2_score_svhn)
gram_scores_normalized_svhn = normalize_scores(g2_svhn_scores)
ood_distances_svhn1_normalized = normalize_scores(ood_distances_svhn2)

combined_scores_svhn = (log_likelihoods_normalized_svhn + gram_scores_normalized_svhn[:, None] + ood_distances_svhn1_normalized) / 3

# Reshape the arrays to have the same number of columns
combined_scores_cifar100 = combined_scores_cifar100.reshape(-1, 1)
combined_scores_svhn = combined_scores_svhn.reshape(-1, 1)

# Combine the scores from both datasets
all_scores = np.concatenate([combined_scores_cifar100, combined_scores_svhn], axis=0)

# Create binary labels
labels = np.concatenate([np.ones_like(combined_scores_cifar100), np.zeros_like(combined_scores_svhn)])

all_scores = np.concatenate([combined_scores_cifar100, combined_scores_svhn], axis=0)
# Create binary labels
labels = np.concatenate([np.ones_like(combined_scores_cifar100), np.zeros_like(combined_scores_svhn)])
auroc = roc_auc_score(labels, all_scores)
aupr = average_precision_score(labels, all_scores)

fpr, tpr, thresholds = roc_curve(labels, all_scores)
fpr_at_tpr95 = fpr[np.argmax(tpr >= 0.95)]

print("SVHN")
print("AUROC", auroc)
print("aupr", aupr)
print("fpr_at_tpr95", fpr_at_tpr95)

AUROC 0.4840543000329692
aupr 0.47070233619177665
fpr_at_tpr95 1.0


In [28]:
# LSUN (Out-of-Distribution)
log_likelihoods_normalized_lsun = normalize_scores(ll2_score_lsun)
gram_scores_normalized_lsun = normalize_scores(g2_scores_lsun)
ood_distances_lsun2_normalized = normalize_scores(ood_distances_lsun2)

combined_scores_lsun = (log_likelihoods_normalized_lsun + gram_scores_normalized_lsun[:, None] + ood_distances_lsun2_normalized) / 3

# Reshape the arrays to have the same number of columns
combined_scores_cifar100 = combined_scores_cifar100.reshape(-1, 1)
combined_scores_lsun = combined_scores_lsun.reshape(-1, 1)

# Combine the scores from both datasets
all_scores = np.concatenate([combined_scores_cifar100, combined_scores_lsun], axis=0)

# Create binary labels
labels = np.concatenate([np.ones_like(combined_scores_cifar100), np.zeros_like(combined_scores_lsun)])

auroc = roc_auc_score(labels, all_scores)
aupr = average_precision_score(labels, all_scores)

fpr, tpr, thresholds = roc_curve(labels, all_scores)
fpr_at_tpr95 = fpr[np.argmax(tpr >= 0.95)]

print("AUROC", auroc)
print("aupr", aupr)
print("fpr_at_tpr95", fpr_at_tpr95)

AUROC 0.481160849818115
aupr 0.6290770495422051
fpr_at_tpr95 0.9889296875
