# CLS

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

import numpy as np
import pandas as pd

from models.densenet import DenseNet3
import util.svhn_loader as svhn

In [2]:
transform_cifar = transforms.Compose([
  transforms.Resize(32),
  transforms.CenterCrop(32),
  transforms.ToTensor(),
  transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

datasets = {
  'CIFAR-10': torchvision.datasets.CIFAR10(root='./datasets/id_datasets/', train=False, download=True, transform=transform_cifar),
  'CIFAR-100': torchvision.datasets.CIFAR100(root='./datasets/id_datasets/', train=False, download=True, transform=transform_cifar), 
  'SVHN': svhn.SVHN('datasets/ood_datasets/svhn/', split='test', transform=transform_cifar, download=False),
  'dtd': torchvision.datasets.ImageFolder(root="datasets/ood_datasets/dtd/images", transform=transform_cifar),
  'places365': torchvision.datasets.ImageFolder(root="datasets/ood_datasets/places365/", transform=transform_cifar),
  'celebA': torchvision.datasets.CelebA(root='datasets/ood_datasets/', split='test', download=True, transform=transform_cifar),
  'iSUN': torchvision.datasets.ImageFolder("./datasets/ood_datasets/iSUN", transform=transform_cifar),
  'LSUN': torchvision.datasets.ImageFolder("./datasets/ood_datasets/LSUN", transform=transform_cifar),
  'LSUN_resize': torchvision.datasets.ImageFolder("./datasets/ood_datasets/LSUN_resize", transform=transform_cifar),
}

dataloaders = {
  k: torch.utils.data.DataLoader(v, batch_size=512, shuffle=False) for k,v in datasets.items()
}

# ood_dls = ['SVHN', 'LSUN', 'LSUN_resize', 'iSUN', 'dtd', 'places365']
ood_dls = ['SVHN', 'LSUN', 'LSUN_resize', 'iSUN', 'dtd']

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [3]:
import sys
sys.path.append('./code/')

from metrics import BinaryMetrics, Runner
from stats import Stats
from dice import DICE

device = 'cuda:0'
result_dfs = {}

## CIFAR-100

In [4]:
densenet = DenseNet3(100, 100, 12, reduction=0.5, bottleneck=True, dropRate=0.0, normalizer=None, p=None, info=None)
checkpoint = torch.load("./checkpoints/CIFAR-100/densenet/checkpoint_100.pth.tar", map_location=device)
densenet.load_state_dict(checkpoint['state_dict'])
densenet.eval();

In [5]:
from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

from tqdm.auto import tqdm

create_feature_extractor(densenet, {'view': 'feature', 'fc': 'logit'}).graph.print_tabular()

opcode         name                    target                                                  args                                       kwargs
-------------  ----------------------  ------------------------------------------------------  -----------------------------------------  --------
placeholder    x                       x                                                       ()                                         {}
call_module    conv1                   conv1                                                   (x,)                                       {}
call_module    block1_layer_0_bn1      block1.layer.0.bn1                                      (conv1,)                                   {}
call_module    block1_layer_0_relu     block1.layer.0.relu                                     (block1_layer_0_bn1,)                      {}
call_module    block1_layer_0_conv1    block1.layer.0.conv1                                    (block1_layer_0_relu,)                     {}
cal

In [6]:
class NewCLS(nn.Module):
  def __init__(self, model, device='cuda:0'):
    super(self.__class__, self).__init__()
    model.eval()
    model.to(device)
    self.model = create_feature_extractor(model, {'view': 'feature', 'fc': 'logit'})
    for p in self.model.parameters():
      p.requires_grad_(False)
    self.new_cls = nn.Linear(342, 1, device=device)

  def forward(self, x):
    res = self.model(x)
    x = self.new_cls(res['feature'])
    return torch.concat([res['logit'], x], dim=1)
  
model = NewCLS(densenet, device=device)

In [9]:
K = 100
epochs = 10

if K == 10:
  ds = torchvision.datasets.CIFAR10(root='./datasets/id_datasets/', train=True, download=True, transform=transform_cifar)
elif K == 100:
  ds = torchvision.datasets.CIFAR100(root='./datasets/id_datasets/', train=False, download=True, transform=transform_cifar)
else:
  raise

dl = torch.utils.data.DataLoader(ds, batch_size=512, shuffle=True)

optimizer = optim.AdamW(model.parameters(), lr=1e-5)

for e in range(epochs):
  with tqdm(dl) as pbar:
    for x, y in pbar:
        e = torch.rand((x.shape[0],))
        noise = torch.randn(x.shape)
        x = (1-e).view(-1,1,1,1)*x + e.view(-1,1,1,1)*noise
        x = x/torch.sqrt((1-e)**2 + e**2).view(-1,1,1,1)
        y = torch.concat([(1-e).view(-1,1)*F.one_hot(y, K), e.view(-1,1)], dim=1)

        y_hat = model(x.to(device))
        loss = F.cross_entropy(y_hat, y.to(device))

        model.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_postfix({'loss': loss.item()})

Files already downloaded and verified


  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

In [10]:
cls_metrics = BinaryMetrics()
cls = Runner(lambda x: 1-F.softmax(model(x), -1)[-1], cls_metrics, dataloaders['CIFAR-100'], device)

dict_cls_metrics = {}
for nm_dl in ood_dls:
  dict_cls_metrics[nm_dl] = cls.run(dataloaders[nm_dl])

cls_df = pd.DataFrame(dict_cls_metrics)
cls_df['Avg.'] = cls_df.mean(axis=1)
result_dfs['CIFAR-100_cls'] = cls_df

In-dist:   0%|          | 0/20 [00:00<?, ?it/s]



Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]



Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/20 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/18 [00:00<?, ?it/s]

Out-of-dist:   0%|          | 0/12 [00:00<?, ?it/s]

In [11]:
cls_df.T[['FPR@95', 'AUROC', 'AUPR_In']]

Unnamed: 0,FPR@95,AUROC,AUPR_In
SVHN,0.821287,0.719022,0.781617
LSUN,0.560891,0.91234,0.93431
LSUN_resize,0.85198,0.732777,0.790774
iSUN,0.823982,0.752149,0.816728
dtd,0.846535,0.675339,0.826769
Avg.,0.780935,0.758326,0.83004
