# TEMP

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

import numpy as np
import pandas as pd

from models.densenet import DenseNet3
import util.svhn_loader as svhn

In [16]:
transform_cifar = transforms.Compose([
  transforms.Resize(32),
  transforms.CenterCrop(32),
  transforms.ToTensor(),
  transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

datasets = {
  'CIFAR-10': torchvision.datasets.CIFAR10(root='./datasets/id_datasets/', train=False, download=True, transform=transform_cifar),
  'CIFAR-100': torchvision.datasets.CIFAR100(root='./datasets/id_datasets/', train=False, download=True, transform=transform_cifar), 
  'SVHN': svhn.SVHN('datasets/ood_datasets/svhn/', split='test', transform=transform_cifar, download=False),
  'dtd': torchvision.datasets.ImageFolder(root="datasets/ood_datasets/dtd/images", transform=transform_cifar),
  'places365': torchvision.datasets.ImageFolder(root="datasets/ood_datasets/places365/", transform=transform_cifar),
  'celebA': torchvision.datasets.CelebA(root='datasets/ood_datasets/', split='test', download=True, transform=transform_cifar),
  'iSUN': torchvision.datasets.ImageFolder("./datasets/ood_datasets/iSUN", transform=transform_cifar),
  'LSUN': torchvision.datasets.ImageFolder("./datasets/ood_datasets/LSUN", transform=transform_cifar),
  'LSUN_resize': torchvision.datasets.ImageFolder("./datasets/ood_datasets/LSUN_resize", transform=transform_cifar),
}

dataloaders = {
  k: torch.utils.data.DataLoader(v, batch_size=512, shuffle=False) for k,v in datasets.items()
}

ood_dls = ['SVHN', 'LSUN', 'LSUN_resize', 'iSUN', 'dtd', 'places365']

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [3]:
import sys
sys.path.append('./code/')

from metrics import BinaryMetrics, Runner
from stats import Stats
from dice import DICE

device = 'cuda:1'
result_dfs = {}

## CIFAR-100

In [4]:
densenet = DenseNet3(100, 100, 12, reduction=0.5, bottleneck=True, dropRate=0.0, normalizer=None, p=None, info=None)
checkpoint = torch.load("./checkpoints/CIFAR-100/densenet/checkpoint_100.pth.tar", map_location=device)
densenet.load_state_dict(checkpoint['state_dict'])
densenet.eval();

In [5]:
from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

from tqdm.auto import tqdm

create_feature_extractor(densenet, {'view': 'feature', 'fc': 'logit'}).graph.print_tabular()

opcode         name                    target                                                  args                                       kwargs
-------------  ----------------------  ------------------------------------------------------  -----------------------------------------  --------
placeholder    x                       x                                                       ()                                         {}
call_module    conv1                   conv1                                                   (x,)                                       {}
call_module    block1_layer_0_bn1      block1.layer.0.bn1                                      (conv1,)                                   {}
call_module    block1_layer_0_relu     block1.layer.0.relu                                     (block1_layer_0_bn1,)                      {}
call_module    block1_layer_0_conv1    block1.layer.0.conv1                                    (block1_layer_0_relu,)                     {}
cal

In [6]:
class DynamicTemp(nn.Module):
  def __init__(self, model, device='cuda:0', eps=torch.finfo(torch.float32).eps):
    super(self.__class__, self).__init__()
    model.eval()
    model.to(device)
    self.model = create_feature_extractor(model, {'view': 'feature', 'fc': 'logit'})
    for p in self.model.parameters():
      p.requires_grad_(False)
    self.temp1 = nn.Linear(342, 128, device=device)
    self.temp2 = nn.Linear(128, 1, device=device)
    self.eps = eps

  def temperature(self, x):
    x = self.model(x)['feature']
    x = F.relu(self.temp1(x))
    x = F.relu(self.temp2(x))
    return x + self.eps

  def forward(self, x):
    res = self.model(x)
    x = F.relu(self.temp1(res['feature']))
    x = F.relu(self.temp2(x)) + self.eps
    return res['logit']/x
  
model = DynamicTemp(densenet, device=device)

In [9]:
K = 100
epochs = 10

if K == 10:
  ds = torchvision.datasets.CIFAR10(root='./datasets/id_datasets/', train=True, download=True, transform=transform_cifar)
elif K == 100:
  ds = torchvision.datasets.CIFAR100(root='./datasets/id_datasets/', train=False, download=True, transform=transform_cifar)
else:
  raise

dl = torch.utils.data.DataLoader(ds, batch_size=512, shuffle=True)

optimizer = optim.AdamW(model.parameters(), lr=1e-5)

for e in range(epochs):
  with tqdm(dl) as pbar:
    for x, y in pbar:
        e = torch.rand(x.shape[0])
        noise = torch.randn(x.shape)
        x = (1-e).view(-1,1,1,1)*x + e.view(-1,1,1,1)*noise
        x = x/torch.sqrt((1-e)**2 + e**2).view(-1,1,1,1)
        y = (1-e).view(-1,1)*F.one_hot(y, K) + e.view(-1,1)/K

        y_hat = model(x.to(device))
        loss = F.cross_entropy(y_hat, y.to(device))

        model.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_postfix({'loss': loss.item()})

Files already downloaded and verified


  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

* Energy

In [17]:
energy_metrics = BinaryMetrics()
energy = Runner(lambda x: torch.logsumexp(model(x), -1)/1000.0, energy_metrics, dataloaders['CIFAR-10'], device)

dict_energy_metrics = {}
for nm_dl in ood_dls:
  dict_energy_metrics[nm_dl] = energy.run(dataloaders[nm_dl])

energy_df = pd.DataFrame(dict_energy_metrics)
energy_df['Avg.'] = energy_df.mean(axis=1)
result_dfs['CIFAR-100_energy'] = energy_df

In-dist:   0%|          | 0/20 [00:00<?, ?it/s]



Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]



Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/18 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/12 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/642 [00:00<?, ?it/s]

* MSP

In [18]:
msp_metrics = BinaryMetrics()
msp = Runner(lambda x: F.softmax(model(x), -1).max(dim=-1)[0], msp_metrics, dataloaders['CIFAR-100'], device)

dict_msp_metrics = {}
for nm_dl in ood_dls:
  dict_msp_metrics[nm_dl] = msp.run(dataloaders[nm_dl])

msp_df = pd.DataFrame(dict_msp_metrics)
msp_df['Avg.'] = msp_df.mean(axis=1)
result_dfs['CIFAR-100_msp'] = msp_df

In-dist:   0%|          | 0/20 [00:00<?, ?it/s]



Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]



Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/18 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/12 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/642 [00:00<?, ?it/s]

* Temperature

In [19]:
temp_metrics = BinaryMetrics()
temp = Runner(lambda x: 1/model.temperature(x).flatten(), temp_metrics, dataloaders['CIFAR-100'], device)

dict_temp_metrics = {}
for nm_dl in ood_dls:
  dict_temp_metrics[nm_dl] = temp.run(dataloaders[nm_dl])

temp_df = pd.DataFrame(dict_temp_metrics)
temp_df['Avg.'] = temp_df.mean(axis=1)
result_dfs['CIFAR-100_temp'] = temp_df

In-dist:   0%|          | 0/20 [00:00<?, ?it/s]



Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]



Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/18 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/12 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/642 [00:00<?, ?it/s]

## Results

In [20]:
result_dfs['CIFAR-100_energy'].T[['FPR@95', 'AUROC', 'AUPR_In']]

Unnamed: 0,FPR@95,AUROC,AUPR_In
SVHN,0.9977,0.458282,0.542955
LSUN,0.937,0.529611,0.543413
LSUN_resize,0.2369,0.949753,0.948275
iSUN,0.27126,0.932952,0.931982
dtd,0.686348,0.726083,0.792639
places365,0.866627,0.605802,0.039474
Avg.,0.665973,0.700414,0.633123


In [21]:
result_dfs['CIFAR-100_msp'].T[['FPR@95', 'AUROC', 'AUPR_In']]

Unnamed: 0,FPR@95,AUROC,AUPR_In
SVHN,0.8657,0.737108,0.758497
LSUN,0.828,0.74131,0.754353
LSUN_resize,0.4081,0.914934,0.923912
iSUN,0.400112,0.913741,0.927265
dtd,0.620745,0.832412,0.894439
places365,0.752499,0.771134,0.130909
Avg.,0.645859,0.81844,0.731563


In [22]:
result_dfs['CIFAR-100_temp'].T[['FPR@95', 'AUROC', 'AUPR_In']]

Unnamed: 0,FPR@95,AUROC,AUPR_In
SVHN,0.9991,0.310502,0.405907
LSUN,1.0,0.0142,0.307416
LSUN_resize,0.2756,0.927263,0.912987
iSUN,0.296583,0.916194,0.908679
dtd,0.657979,0.61802,0.668804
places365,0.929102,0.527516,0.03221
Avg.,0.693061,0.552282,0.539334
