# DICE
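[DICE: Leveraging Sparsification for Out-of-Distribution Detection](https://arxiv.org/abs/2111.09805)

This notebook compares the energy score of a DenseNet with and without DICE sparsification ($p=90$), using CIFAR-10 and CIFAR-100 as in-distribution data. The OOD sets are SVHN, LSUN, LSUN_resize, iSUN, DTD and Places365. For each pair it reports FPR@95, AUROC and AUPR_In.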

In [1]:
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

from models.densenet import DenseNet3
import util.svhn_loader as svhn

In [2]:
# Evaluation data: the CIFAR test splits are in-distribution, the rest are OOD.
ID_ROOT = './datasets/id_datasets/'
OOD_ROOT = './datasets/ood_datasets/'
BATCH_SIZE = 512
# Optional cap on images per OOD dataset (None = use every image).
# places365 is roughly 30x larger than a CIFAR test split (642 vs. 20 loader
# batches), which dominates imbalance-sensitive metrics such as AUPR_In.
# Set e.g. 10_000 to evaluate a fixed, seeded random subset instead.
MAX_OOD_SAMPLES = None
OOD_SUBSET_SEED = 0

# Same preprocessing for every dataset so OOD images are normalised with the
# CIFAR channel statistics, exactly like the in-distribution ones.
transform_cifar = transforms.Compose([
  transforms.Resize(32),
  transforms.CenterCrop(32),
  transforms.ToTensor(),
  transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

datasets = {
  'CIFAR-10': torchvision.datasets.CIFAR10(root=ID_ROOT, train=False, download=True, transform=transform_cifar),
  'CIFAR-100': torchvision.datasets.CIFAR100(root=ID_ROOT, train=False, download=True, transform=transform_cifar),
  'SVHN': svhn.SVHN(OOD_ROOT + 'svhn/', split='test', transform=transform_cifar, download=False),
  'dtd': torchvision.datasets.ImageFolder(root=OOD_ROOT + 'dtd/images', transform=transform_cifar),
  'places365': torchvision.datasets.ImageFolder(root=OOD_ROOT + 'places365/', transform=transform_cifar),
  # Loaded but not evaluated below (not listed in `ood_dls`).
  'celebA': torchvision.datasets.CelebA(root=OOD_ROOT, split='test', download=True, transform=transform_cifar),
  'iSUN': torchvision.datasets.ImageFolder(OOD_ROOT + 'iSUN', transform=transform_cifar),
  'LSUN': torchvision.datasets.ImageFolder(OOD_ROOT + 'LSUN', transform=transform_cifar),
  'LSUN_resize': torchvision.datasets.ImageFolder(OOD_ROOT + 'LSUN_resize', transform=transform_cifar),
}

# OOD datasets evaluated against each in-distribution model, in report order.
ood_dls = ['SVHN', 'LSUN', 'LSUN_resize', 'iSUN', 'dtd', 'places365']


def subsample_dataset(dataset, max_samples=None, seed=0):
  """Return `dataset`, or a seeded random Subset of at most `max_samples` items.

  With `max_samples=None`, or a dataset that is already small enough, the
  dataset is returned unchanged. The default therefore reproduces the
  full-dataset evaluation.
  """
  if max_samples is None or len(dataset) <= max_samples:
    return dataset
  generator = torch.Generator().manual_seed(seed)
  indices = torch.randperm(len(dataset), generator=generator)[:max_samples].tolist()
  return torch.utils.data.Subset(dataset, indices)


# shuffle=False keeps the evaluation order identical across runs.
dataloaders = {
  name: torch.utils.data.DataLoader(
    subsample_dataset(ds, MAX_OOD_SAMPLES, OOD_SUBSET_SEED) if name in ood_dls else ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
  )
  for name, ds in datasets.items()
}

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [3]:
import sys

# Make the project's metric helpers importable. The guard keeps re-runs of this
# cell from appending the same entry to sys.path again and again.
CODE_DIR = './code/'
if CODE_DIR not in sys.path:
  sys.path.append(CODE_DIR)

from metrics import BinaryMetrics, Runner

# Fall back to CPU when no GPU is visible, so the notebook still runs (slowly).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# '<ID dataset>_<method>' -> DataFrame of metrics per OOD dataset (+ 'Avg.').
result_dfs = {}

## CIFAR-100

* No Sparsity

In [4]:
# Baseline CIFAR-100 network without DICE (p=None, info=None).
# DenseNet3(100, 100, 12, ...) presumably means depth=100, 100 classes and
# growth rate 12 -- TODO confirm against models/densenet.py.
densenet = DenseNet3(100, 100, 12, reduction=0.5, bottleneck=True, dropRate=0.0, normalizer=None, p=None, info=None)
# map_location='cpu' lets a checkpoint saved on any GPU load on this machine;
# the model is then moved to `device` explicitly.
checkpoint = torch.load("./checkpoints/CIFAR-100/densenet/checkpoint_100.pth.tar", map_location='cpu')
densenet.load_state_dict(checkpoint['state_dict'])
densenet.eval()
densenet.to(device);

In [5]:
# Energy-score baseline on CIFAR-100: score = logsumexp of the logits.
# `model=densenet` binds the current network at definition time. Without it the
# lambda would look up the global `densenet` on every call, so this Runner would
# silently use a different model after later cells reassign `densenet`.
# Dividing by 1000 is a positive rescale and does not change the ordering of the
# scores; presumably BinaryMetrics' rank-based metrics are unaffected by it.
energy_metrics = BinaryMetrics()
energy = Runner(
  lambda x, model=densenet: torch.logsumexp(model.forward(x), -1) / 1000.0,
  energy_metrics,
  dataloaders['CIFAR-100'],
  device,
)

dict_energy_metrics = {nm_dl: energy.run(dataloaders[nm_dl]) for nm_dl in ood_dls}

energy_df = pd.DataFrame(dict_energy_metrics)
energy_df['Avg.'] = energy_df.mean(axis=1)  # each metric averaged over the OOD datasets
result_dfs['CIFAR-100_energy'] = energy_df

In-dist:   0%|          | 0/20 [00:00<?, ?it/s]



Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]



Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/18 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/12 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/642 [00:00<?, ?it/s]

* With Sparsity $p=90$

In [6]:
# DICE model for CIFAR-100: the same trained weights, built with p=90 and the
# cached feature statistics passed to DenseNet3 as `info`.
# The checkpoint is reloaded here instead of reusing `checkpoint` from another
# cell. Otherwise the weights could silently come from a different section
# (e.g. CIFAR-10) if cells are run out of order.
checkpoint = torch.load("./checkpoints/CIFAR-100/densenet/checkpoint_100.pth.tar", map_location='cpu')
info = np.load("./cache/CIFAR-100_densenet_feat_stat.npy")

densenet = DenseNet3(100, 100, 12, reduction=0.5, bottleneck=True, dropRate=0.0, normalizer=None, p=90, info=info)
densenet.load_state_dict(checkpoint['state_dict'])
densenet.eval()
densenet.to(device);

In [7]:
# Energy score of the DICE-sparsified CIFAR-100 model.
# `model=densenet` binds the sparse network at definition time. A plain
# `densenet` reference would be resolved at call time and follow later
# reassignments of the global.
dice_metrics = BinaryMetrics()
dice = Runner(
  lambda x, model=densenet: torch.logsumexp(model.forward(x), -1) / 1000.0,
  dice_metrics,
  dataloaders['CIFAR-100'],
  device,
)

dict_dice_metrics = {nm_dl: dice.run(dataloaders[nm_dl]) for nm_dl in ood_dls}

dice_df = pd.DataFrame(dict_dice_metrics)
dice_df['Avg.'] = dice_df.mean(axis=1)  # each metric averaged over the OOD datasets
result_dfs['CIFAR-100_dice'] = dice_df

In-dist:   0%|          | 0/20 [00:00<?, ?it/s]



Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]



Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/18 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/12 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/642 [00:00<?, ?it/s]

## CIFAR-10

* No Sparsity

In [8]:
# Baseline CIFAR-10 network without DICE (p=None, info=None).
# DenseNet3(100, 10, 12, ...) presumably means depth=100, 10 classes and
# growth rate 12 -- TODO confirm against models/densenet.py.
densenet = DenseNet3(100, 10, 12, reduction=0.5, bottleneck=True, dropRate=0.0, normalizer=None, p=None, info=None)
# map_location='cpu' lets a checkpoint saved on any GPU load on this machine;
# the model is then moved to `device` explicitly.
checkpoint = torch.load("./checkpoints/CIFAR-10/densenet/checkpoint_100.pth.tar", map_location='cpu')
densenet.load_state_dict(checkpoint['state_dict'])
densenet.eval()
densenet.to(device);

In [9]:
# Energy-score baseline on CIFAR-10: score = logsumexp of the logits.
# `model=densenet` binds the current network at definition time. A plain
# `densenet` reference would be resolved at call time and follow later
# reassignments (e.g. the DICE model built further down).
energy_metrics = BinaryMetrics()
energy = Runner(
  lambda x, model=densenet: torch.logsumexp(model.forward(x), -1) / 1000.0,
  energy_metrics,
  dataloaders['CIFAR-10'],
  device,
)

dict_energy_metrics = {nm_dl: energy.run(dataloaders[nm_dl]) for nm_dl in ood_dls}

energy_df = pd.DataFrame(dict_energy_metrics)
energy_df['Avg.'] = energy_df.mean(axis=1)  # each metric averaged over the OOD datasets
result_dfs['CIFAR-10_energy'] = energy_df

In-dist:   0%|          | 0/20 [00:00<?, ?it/s]



Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]



Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/18 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/12 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/642 [00:00<?, ?it/s]

In [10]:
# Optional MSP (maximum softmax probability) baseline for the CIFAR-10 model,
# currently disabled. Uncomment to add 'CIFAR-10_msp' to `result_dfs`.
# The score reads the global `densenet`, so this must run while `densenet` is
# still the CIFAR-10 model without sparsity, i.e. before the DICE cell below.
# msp_metrics = BinaryMetrics()
# msp = Runner(lambda x: F.softmax(densenet.forward(x), -1).max(-1)[0], msp_metrics, dataloaders['CIFAR-10'], device)

# dict_msp_metrics = {}
# for nm_dl in ood_dls:
#   dict_msp_metrics[nm_dl] = msp.run(dataloaders[nm_dl])

# msp_df = pd.DataFrame(dict_msp_metrics)
# msp_df['Avg.'] = msp_df.mean(axis=1)
# result_dfs['CIFAR-10_msp'] = msp_df

# msp_df.T[['FPR@95', 'AUROC', 'AUPR_In']]

* With Sparsity $p=90$

In [11]:
# DICE model for CIFAR-10: the same trained weights, built with p=90 and the
# cached feature statistics passed to DenseNet3 as `info`.
# The checkpoint is reloaded here instead of reusing `checkpoint` from another
# cell. Otherwise the weights could silently come from a different section
# (e.g. CIFAR-100) if cells are run out of order.
checkpoint = torch.load("./checkpoints/CIFAR-10/densenet/checkpoint_100.pth.tar", map_location='cpu')
info = np.load("./cache/CIFAR-10_densenet_feat_stat.npy")

densenet = DenseNet3(100, 10, 12, reduction=0.5, bottleneck=True, dropRate=0.0, normalizer=None, p=90, info=info)
densenet.load_state_dict(checkpoint['state_dict'])
densenet.eval()
densenet.to(device);

In [12]:
# Energy score of the DICE-sparsified CIFAR-10 model.
# `model=densenet` binds the sparse network at definition time. A plain
# `densenet` reference would be resolved at call time and follow later
# reassignments of the global.
dice_metrics = BinaryMetrics()
dice = Runner(
  lambda x, model=densenet: torch.logsumexp(model.forward(x), -1) / 1000.0,
  dice_metrics,
  dataloaders['CIFAR-10'],
  device,
)

dict_dice_metrics = {nm_dl: dice.run(dataloaders[nm_dl]) for nm_dl in ood_dls}

dice_df = pd.DataFrame(dict_dice_metrics)
dice_df['Avg.'] = dice_df.mean(axis=1)  # each metric averaged over the OOD datasets
result_dfs['CIFAR-10_dice'] = dice_df

In-dist:   0%|          | 0/20 [00:00<?, ?it/s]



Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]



Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/18 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/12 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/642 [00:00<?, ?it/s]

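## Results

Each table has one row per OOD dataset, plus their mean (`Avg.`).

places365 is far larger than the in-distribution test sets (642 vs. 20 loader batches of 512 images). Its AUPR_In, which treats in-distribution samples as the positive class, is therefore dominated by that imbalance: for example ≈0.14 on CIFAR-100 despite an AUROC of ≈0.78. This also pulls down the `Avg.` AUPR_In. AUROC and FPR@95 are the more comparable metrics for that row.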

In [13]:
# CIFAR-100, energy baseline: rows = OOD datasets (+ Avg.), columns = metrics.
result_dfs['CIFAR-100_energy'].transpose().loc[:, ['FPR@95', 'AUROC', 'AUPR_In']]

Unnamed: 0,FPR@95,AUROC,AUPR_In
SVHN,0.8754,0.818515,0.863067
LSUN,0.1478,0.974348,0.976312
LSUN_resize,0.7073,0.801376,0.81424
iSUN,0.746106,0.789539,0.817554
dtd,0.843617,0.710096,0.765005
places365,0.782843,0.781495,0.14173
Avg.,0.683844,0.812562,0.729651


In [14]:
# CIFAR-100, DICE (p=90): rows = OOD datasets (+ Avg.), columns = metrics.
result_dfs['CIFAR-100_dice'].transpose().loc[:, ['FPR@95', 'AUROC', 'AUPR_In']]

Unnamed: 0,FPR@95,AUROC,AUPR_In
SVHN,0.5946,0.885669,0.903368
LSUN,0.0093,0.997393,0.997414
LSUN_resize,0.5176,0.893207,0.902982
iSUN,0.496359,0.895126,0.910083
dtd,0.615603,0.771157,0.793482
places365,0.803705,0.774739,0.140498
Avg.,0.506194,0.869548,0.774638


In [15]:
# CIFAR-10, energy baseline: rows = OOD datasets (+ Avg.), columns = metrics.
result_dfs['CIFAR-10_energy'].transpose().loc[:, ['FPR@95', 'AUROC', 'AUPR_In']]

Unnamed: 0,FPR@95,AUROC,AUPR_In
SVHN,0.4063,0.939926,0.95334
LSUN,0.0381,0.991504,0.992562
LSUN_resize,0.0928,0.981238,0.984918
iSUN,0.100616,0.980683,0.986024
dtd,0.563121,0.864202,0.898692
places365,0.397665,0.918054,0.385097
Avg.,0.266434,0.945935,0.866772


In [16]:
# CIFAR-10, DICE (p=90): rows = OOD datasets (+ Avg.), columns = metrics.
result_dfs['CIFAR-10_dice'].transpose().loc[:, ['FPR@95', 'AUROC', 'AUPR_In']]

Unnamed: 0,FPR@95,AUROC,AUPR_In
SVHN,0.297,0.946658,0.952062
LSUN,0.0038,0.998986,0.998997
LSUN_resize,0.0447,0.990291,0.991147
iSUN,0.051541,0.989713,0.991639
dtd,0.45922,0.869672,0.888168
places365,0.451291,0.901546,0.333267
Avg.,0.217925,0.949478,0.859213
