# DICE (std. not avg.)
[DICE: Leveraging Sparsification for Out-of-Distribution Detection](https://arxiv.org/abs/2111.09805)

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from models.densenet import DenseNet3
import util.svhn_loader as svhn

In [2]:
transform_cifar = transforms.Compose([
  transforms.Resize(32),
  transforms.CenterCrop(32),
  transforms.ToTensor(),
  transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

datasets = {
  'CIFAR-10': torchvision.datasets.CIFAR10(root='./datasets/id_datasets/', train=False, download=True, transform=transform_cifar),
  'CIFAR-100': torchvision.datasets.CIFAR100(root='./datasets/id_datasets/', train=False, download=True, transform=transform_cifar), 
  'SVHN': svhn.SVHN('datasets/ood_datasets/svhn/', split='test', transform=transform_cifar, download=False),
  'dtd': torchvision.datasets.ImageFolder(root="datasets/ood_datasets/dtd/images", transform=transform_cifar),
  'places365': torchvision.datasets.ImageFolder(root="datasets/ood_datasets/places365/", transform=transform_cifar),
  'celebA': torchvision.datasets.CelebA(root='datasets/ood_datasets/', split='test', download=True, transform=transform_cifar),
  'iSUN': torchvision.datasets.ImageFolder("./datasets/ood_datasets/iSUN", transform=transform_cifar),
  'LSUN': torchvision.datasets.ImageFolder("./datasets/ood_datasets/LSUN", transform=transform_cifar),
  'LSUN_resize': torchvision.datasets.ImageFolder("./datasets/ood_datasets/LSUN_resize", transform=transform_cifar),
}

dataloaders = {
  k: torch.utils.data.DataLoader(v, batch_size=512, shuffle=False) for k,v in datasets.items()
}

ood_dls = ['SVHN', 'LSUN', 'LSUN_resize', 'iSUN', 'dtd', 'places365']

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [3]:
import sys
sys.path.append('./code/')

from metrics import BinaryMetrics, Runner
from stats import Stats
from dice import DICE

device = 'cuda:0'
result_dfs = {}

## CIFAR-100

In [4]:
densenet = DenseNet3(100, 100, 12, reduction=0.5, bottleneck=True, dropRate=0.0, normalizer=None, p=None, info=None)
checkpoint = torch.load("./checkpoints/CIFAR-100/densenet/checkpoint_100.pth.tar", map_location=device)
densenet.load_state_dict(checkpoint['state_dict'])
densenet.eval()

model = DICE(densenet, device, mode='energy')

* With Sparsity $p=90$

In [5]:
return_nodes = {
  'view': 'feature',
  # 'fc': 'logit',
}
s = Stats(densenet, return_nodes, device)
s.run(dataloaders['CIFAR-100'])

# avg_features = torch.stack([s.compute(target=i)['feature'] for i in range(100)])
# model.set_dice_(avg_features.to(device), p=90)
# std_features = torch.stack([s.compute(target=i, std=True)[1]['feature'] for i in range(100)])
# model.set_dice_(std_features.to(device), p=90)
std_features = s.compute(std=True)[1]['feature']
model.set_dice_(std_features.to(device), p=90)

  0%|          | 0/20 [00:00<?, ?it/s]

In [6]:
dice_metrics = BinaryMetrics()
dice = Runner(model, dice_metrics, dataloaders['CIFAR-100'], device)

dict_dice_metrics = {}
for nm_dl in ood_dls:
  dict_dice_metrics[nm_dl] = dice.run(dataloaders[nm_dl])

dice_df = pd.DataFrame(dict_dice_metrics)
dice_df['Avg.'] = dice_df.mean(axis=1)
result_dfs['CIFAR-100_dice'] = dice_df

# dice_df.T[['FPR@95', 'AUROC', 'AUPR_In']]

In-dist:   0%|          | 0/20 [00:00<?, ?it/s]



Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]



Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/18 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/12 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/642 [00:00<?, ?it/s]

## CIFAR-10

In [7]:
densenet = DenseNet3(100, 10, 12, reduction=0.5, bottleneck=True, dropRate=0.0, normalizer=None, p=None, info=None)
checkpoint = torch.load("./checkpoints/CIFAR-10/densenet/checkpoint_100.pth.tar", map_location=device)
densenet.load_state_dict(checkpoint['state_dict'])
densenet.eval()

model = DICE(densenet, device, mode='energy')

* With Sparsity $p=90$

In [8]:
return_nodes = {
  'view': 'feature',
  # 'fc': 'logit',
}
s = Stats(densenet, return_nodes, device)
s.run(dataloaders['CIFAR-10'])

# avg_features = torch.stack([s.compute(target=i)['feature'] for i in range(10)])
# model.set_dice_(avg_features.to(device), p=90)
# std_features = torch.stack([s.compute(target=i, std=True)[1]['feature'] for i in range(10)])
# model.set_dice_(std_features.to(device), p=90)
std_features = s.compute(std=True)[1]['feature']
model.set_dice_(std_features.to(device), p=90)

  0%|          | 0/20 [00:00<?, ?it/s]

In [9]:
dice_metrics = BinaryMetrics()
dice = Runner(model, dice_metrics, dataloaders['CIFAR-10'], device)

dict_dice_metrics = {}
for nm_dl in ood_dls:
  dict_dice_metrics[nm_dl] = dice.run(dataloaders[nm_dl])

dice_df = pd.DataFrame(dict_dice_metrics)
dice_df['Avg.'] = dice_df.mean(axis=1)
result_dfs['CIFAR-10_dice'] = dice_df

# dice_df.T[['FPR@95', 'AUROC', 'AUPR_In']]

In-dist:   0%|          | 0/20 [00:00<?, ?it/s]



Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]



Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/18 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/12 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/642 [00:00<?, ?it/s]

## Results

In [10]:
result_dfs['CIFAR-100_dice'].T[['FPR@95', 'AUROC', 'AUPR_In']]

Unnamed: 0,FPR@95,AUROC,AUPR_In
SVHN,0.6078,0.884362,0.902689
LSUN,0.0095,0.997497,0.997506
LSUN_resize,0.5209,0.892435,0.902281
iSUN,0.493894,0.895941,0.910864
dtd,0.604433,0.775184,0.796065
places365,0.793306,0.782958,0.148863
Avg.,0.504972,0.871396,0.776378


In [11]:
result_dfs['CIFAR-10_dice'].T[['FPR@95', 'AUROC', 'AUPR_In']]

Unnamed: 0,FPR@95,AUROC,AUPR_In
SVHN,0.2957,0.946619,0.952029
LSUN,0.0038,0.998989,0.998999
LSUN_resize,0.0439,0.990308,0.991158
iSUN,0.051317,0.989727,0.991648
dtd,0.457092,0.869977,0.888359
places365,0.449796,0.901538,0.332931
Avg.,0.216934,0.949526,0.859187
