In [None]:
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

from models.densenet import DenseNet3
import util.svhn_loader as svhn

In [None]:
# Test-time preprocessing shared by every dataset below: resize the shorter side to 32,
# centre-crop to 32x32, convert to a tensor, then normalize per channel with the
# CIFAR mean/std constants.
transform_cifar = transforms.Compose(
  [
    transforms.Resize(32),
    transforms.CenterCrop(32),
    transforms.ToTensor(),
    transforms.Normalize(
      (0.4914, 0.4822, 0.4465),
      (0.2023, 0.1994, 0.2010),
    ),
  ]
)


def _image_folder(root):
  """Build an ImageFolder dataset over `root` using the shared CIFAR test transform."""
  return torchvision.datasets.ImageFolder(root=root, transform=transform_cifar)


# Construction order matters only for side effects (download / integrity messages);
# it is kept the same as the dict's key order.
datasets = {
  # CIFAR test splits (stored under id_datasets/).
  'CIFAR-10': torchvision.datasets.CIFAR10(
    root='./datasets/id_datasets/', train=False, download=True, transform=transform_cifar),
  'CIFAR-100': torchvision.datasets.CIFAR100(
    root='./datasets/id_datasets/', train=False, download=True, transform=transform_cifar),
  # Test sets stored under ood_datasets/. NOTE: CIFAR-100 and celebA are loaded but are
  # not part of `ood_dls` below.
  'SVHN': svhn.SVHN(
    'datasets/ood_datasets/svhn/', split='test', transform=transform_cifar, download=False),
  'dtd': _image_folder("datasets/ood_datasets/dtd/images"),
  'places365': _image_folder("datasets/ood_datasets/places365/"),
  'celebA': torchvision.datasets.CelebA(
    root='datasets/ood_datasets/', split='test', download=True, transform=transform_cifar),
  'iSUN': _image_folder("./datasets/ood_datasets/iSUN"),
  'LSUN': _image_folder("./datasets/ood_datasets/LSUN"),
  'LSUN_resize': _image_folder("./datasets/ood_datasets/LSUN_resize"),
}

# One unshuffled loader per dataset, keyed by the same names as `datasets`.
dataloaders = {
  name: torch.utils.data.DataLoader(ds, batch_size=512, shuffle=False)
  for name, ds in datasets.items()
}

# Dataset keys evaluated as out-of-distribution against the CIFAR-10 loader.
ood_dls = ['SVHN', 'LSUN', 'LSUN_resize', 'iSUN', 'dtd', 'places365']

In [None]:
# Baseline classifier for the energy cell. p=None / info=None presumably leaves DICE
# sparsification off (the DICE cell passes p=90 and a cached `info` instead) — TODO confirm
# in models/densenet.py. The positional args (100, 10, 12) presumably are depth,
# num_classes and growth rate — verify against DenseNet3's signature.
model = DenseNet3(
  100, 10, 12,
  reduction=0.5, bottleneck=True, dropRate=0.0,
  normalizer=None, p=None, info=None,
)

# `checkpoint` is also reused later by the DICE model cell.
checkpoint = torch.load("./checkpoints/CIFAR-10/densenet/checkpoint_100.pth.tar")
model.load_state_dict(checkpoint['state_dict'])

# nn.Module.eval()/.cuda() both return the module, so they chain; the trailing `;`
# suppresses the module repr in the notebook output.
model.eval().cuda();

In [None]:
import sys
sys.path.append('./code/')

from metrics import BinaryMetrics, Runner

In [None]:
# msp_metrics = BinaryMetrics()
# msp = Runner(lambda x: F.softmax(model.forward(x), -1).max(-1)[0], msp_metrics, dataloaders['CIFAR-10'])

# dict_msp_metrics = []
# for nm_dl in ood_dls:
#   dict_msp_metrics.append(msp.run(dataloaders[nm_dl]))

# msp_df = pd.DataFrame(dict_msp_metrics, index=ood_dls)
# msp_df.loc['Avg.'] = msp_df.mean()
# msp_df[['FPR@95', 'AUROC', 'AUPR_In']]

In [None]:
# Energy-based OOD score: logsumexp over the class logits (the negative free energy at
# temperature 1). The /1000.0 is a monotone rescale and cannot change how samples are
# ranked; presumably it only matters if BinaryMetrics expects a particular score range
# — TODO confirm. Assumes BinaryMetrics treats higher scores as in-distribution.
@torch.no_grad()
def energy_score(x):
  """Return logsumexp over `model`'s logits for batch `x`, divided by 1000.

  Autograd is disabled: scoring never needs gradients, and otherwise each forward pass
  would record an autograd graph and keep its activations alive.
  """
  return torch.logsumexp(model(x), dim=-1) / 1000.0


energy_metrics = BinaryMetrics()
energy = Runner(energy_score, energy_metrics, dataloaders['CIFAR-10'])

# One metrics result per OOD dataset, each compared against the CIFAR-10 loader.
dict_energy_metrics = {
  nm_dl: energy.run(dataloaders[nm_dl])
  for nm_dl in tqdm(ood_dls, desc="energy")
}

# Columns are datasets; add their mean as an extra 'Avg.' column, then show datasets as rows.
energy_df = pd.DataFrame(dict_energy_metrics)
energy_df['Avg.'] = energy_df.mean(axis=1)
energy_df.T[['FPR@95', 'AUROC', 'AUPR_In']]

In [None]:
# Cached per-feature statistic consumed by DenseNet3 together with `p` (presumably DICE's
# sparsification percentile — TODO confirm in models/densenet.py).
DICE_P = 90

info = np.load("./cache/CIFAR-10_densenet_feat_stat.npy")

dice_model = DenseNet3(100, 10, 12, reduction=0.5, bottleneck=True, dropRate=0.0,
                       normalizer=None, p=DICE_P, info=info)
# Reuses `checkpoint` from the baseline-model cell, so both networks share the same
# trained weights; only the constructor's p/info arguments differ.
dice_model.load_state_dict(checkpoint['state_dict'])
dice_model.eval()
dice_model.cuda();

In [None]:
# DICE evaluation: the same logsumexp-of-logits score (rescaled by 1/1000, a monotone
# transform that cannot change ranking), computed through `dice_model`.
@torch.no_grad()
def dice_score(x):
  """Return logsumexp over `dice_model`'s logits for batch `x`, divided by 1000.

  Autograd is disabled: scoring never needs gradients, and otherwise each forward pass
  would record an autograd graph and keep its activations alive.
  """
  return torch.logsumexp(dice_model(x), dim=-1) / 1000.0


dice_metrics = BinaryMetrics()
dice = Runner(dice_score, dice_metrics, dataloaders['CIFAR-10'])

# One metrics result per OOD dataset, each compared against the CIFAR-10 loader.
dict_dice_metrics = {
  nm_dl: dice.run(dataloaders[nm_dl])
  for nm_dl in tqdm(ood_dls, desc="dice")
}

# Columns are datasets; add their mean as an extra 'Avg.' column, then show datasets as rows.
dice_df = pd.DataFrame(dict_dice_metrics)
dice_df['Avg.'] = dice_df.mean(axis=1)
dice_df.T[['FPR@95', 'AUROC', 'AUPR_In']]