### Steps:

* Train a LeNet-5 model on MNIST
* Apply the ODIN postprocessor, using FashionMNIST as the out-of-distribution (OOD) set for MNIST
* Compute the OOD evaluation metrics
* Load the iWildCam dataset
* Train a ResNet-50 on iWildCam
* Apply the ODIN postprocessor

In [161]:
import os

import numpy as np
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import Module
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.datasets import mnist, FashionMNIST
from torchvision.models.resnet import Bottleneck, ResNet
from torchvision.transforms import ToTensor

from openood.evaluators import metrics

### Supported Activation Functions

We consider three activation functions: ReLU, Softplus, and Swish (implemented in PyTorch as `nn.SiLU`). *Note that, depending on the available compute resources, we may run experiments for only a subset of them.*

In [165]:
def get_activation_fn(activation):
    """Return a new activation module for the given name.

    Args:
        activation: one of 'relu', 'softplus' or 'swish' (case-sensitive).

    Returns:
        A fresh ``nn.Module`` instance on every call, so it can be placed in
        several layers without sharing. 'swish' maps to ``nn.SiLU``, PyTorch's
        implementation of Swish (x * sigmoid(x)); ``torch.nn`` has no
        ``Swish`` class.

    Raises:
        ValueError: if ``activation`` is not a supported name. Returning
            ``None`` instead would only fail later, when the missing
            activation is called inside ``forward``.
    """
    factories = {
        'relu': nn.ReLU,
        'softplus': nn.Softplus,
        'swish': nn.SiLU,
    }
    if activation not in factories:
        raise ValueError(
            'Unsupported activation {!r}; expected one of {}'.format(
                activation, sorted(factories)))
    return factories[activation]()

### Supported Networks

Currently, we support LeNet and ResNet50.

In [166]:
class LeNet(nn.Module):
    """LeNet-5 style CNN with a configurable activation function.

    Args:
        num_classes: number of output logits.
        num_channel: number of input image channels (the MNIST model uses 1).
        activation: name passed to ``get_activation_fn`` for every
            non-linearity in the network.

    The first convolution uses ``padding=2``, so a 28x28 input (as in MNIST)
    becomes 28 -> 14 -> 10 -> 5 -> 1 spatially and leaves ``block3`` as
    1x1x120, matching ``classifier1``'s ``in_features=120``. Other input
    sizes would need a different ``classifier1``.
    """

    def __init__(self, num_classes, num_channel=3, activation='relu'):
        super().__init__()
        self.num_classes = num_classes
        # Width of the penultimate feature returned with return_feature=True.
        self.feature_size = 84
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=num_channel,
                      out_channels=6,
                      kernel_size=5,
                      stride=1,
                      padding=2),
            get_activation_fn(activation),
            nn.MaxPool2d(kernel_size=2))

        self.block2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
            get_activation_fn(activation),
            nn.MaxPool2d(kernel_size=2))

        self.block3 = nn.Sequential(
            nn.Conv2d(in_channels=16,
                      out_channels=120,
                      kernel_size=5,
                      stride=1),
            get_activation_fn(activation))

        self.classifier1 = nn.Linear(in_features=120, out_features=84)
        # Named `relu` for backward compatibility; it holds whichever
        # activation was requested.
        self.relu = get_activation_fn(activation)
        self.fc = nn.Linear(in_features=84, out_features=num_classes)

    def get_fc(self):
        """Return the final layer's (weight, bias) as detached numpy arrays."""
        fc = self.fc
        return fc.weight.cpu().detach().numpy(), fc.bias.cpu().detach().numpy()

    def _feature_list(self, x):
        """Run everything before ``fc``; return [block1, block2, flat block3, penultimate]."""
        feature1 = self.block1(x)
        feature2 = self.block2(feature1)
        feature3 = self.block3(feature2)
        # Flatten per sample: (batch, 120) for 28x28 inputs.
        feature3 = feature3.view(feature3.shape[0], -1)
        feature = self.relu(self.classifier1(feature3))
        return [feature1, feature2, feature3, feature]

    def forward(self, x, return_feature=False, return_feature_list=False):
        """Compute class logits.

        Returns:
            ``logits`` by default; ``(logits, penultimate_feature)`` when
            ``return_feature`` is set (it takes precedence); otherwise
            ``(logits, feature_list)`` when ``return_feature_list`` is set.
        """
        feature_list = self._feature_list(x)
        feature = feature_list[-1]
        logits_cls = self.fc(feature)
        if return_feature:
            return logits_cls, feature
        elif return_feature_list:
            return logits_cls, feature_list
        else:
            return logits_cls

    def forward_threshold(self, x, threshold):
        """Compute logits after clipping the penultimate feature at ``threshold`` (upper bound only)."""
        feature = self._feature_list(x)[-1].clip(max=threshold)
        return self.fc(feature)

### Supported Post-Hoc OODN Processors

In [170]:
class ODINPostprocessor():
    """ODIN out-of-distribution scorer: temperature scaling + input perturbation.

    For each batch, the input is nudged against the sign of the gradient of
    the cross-entropy loss between the temperature-scaled logits and the
    network's own predictions. The maximum softmax probability of the
    perturbed input (again temperature-scaled) is returned as the confidence.

    Args:
        temperature: softmax temperature used in both forward passes.
        noise: perturbation magnitude (step size applied to the signed gradient).
        input_std: per-channel divisors applied to the signed gradient, i.e.
            the normalisation std of each input channel. The default keeps the
            previous behaviour of scaling only channel 0 by 63/255 (a value
            taken from the original ODIN code). Entries beyond the number of
            input channels are ignored; pass ``(1.0,)`` for un-normalised
            single-channel inputs.
    """

    def __init__(self, temperature, noise, input_std=(63.0 / 255.0,)):
        self.temperature = temperature
        self.noise = noise
        self.input_std = tuple(input_std)

    def postprocess(self, net: nn.Module, data):
        """Score one batch.

        The caller's ``data`` tensor is not modified, and no gradients are
        accumulated into the network's parameters.

        Returns:
            (pred, conf): per-sample predicted class index and maximum softmax
            probability of the perturbed, temperature-scaled logits.
        """
        net.eval()
        # Private leaf copy, so the caller's tensor is never flagged
        # requires_grad or left holding a .grad.
        data = data.detach().clone().requires_grad_(True)
        output = net(data)

        # Pseudo-labels: the network's own predictions on the clean input.
        labels = output.detach().argmax(dim=1)

        # Cross-entropy of the temperature-scaled logits w.r.t. the pseudo-labels.
        criterion = nn.CrossEntropyLoss()
        loss = criterion(output / self.temperature, labels)
        # autograd.grad computes only the input gradient; backward() would
        # also accumulate into every parameter's .grad.
        (input_grad,) = torch.autograd.grad(loss, data)

        # Sign of the gradient mapped to {-1, +1} (zero counts as +1).
        gradient = (torch.ge(input_grad, 0).float() - 0.5) * 2

        # Per-channel scaling by the input normalisation std.
        num_channels = gradient.shape[1]
        for channel, std in enumerate(self.input_std[:num_channels]):
            gradient[:, channel] = gradient[:, channel] / std

        # The second pass only needs values, so skip building an autograd graph.
        with torch.no_grad():
            perturbed = torch.add(data.detach(), gradient, alpha=-self.noise)
            scaled_logits = net(perturbed) / self.temperature
            probs = torch.softmax(scaled_logits, dim=1)
            conf, pred = probs.max(dim=1)

        return pred, conf

    def inference(self, net: nn.Module, data_loader: DataLoader):
        """Run ``postprocess`` over every ``(data, label)`` batch of ``data_loader``.

        Each batch is moved to the device of the network's first parameter
        (when it has any) before scoring.

        Returns:
            (pred_list, conf_list, label_list): 1-D numpy arrays with one entry
            per sample: int predictions, float64 confidences, int labels.
        """
        first_param = next(net.parameters(), None)
        device = first_param.device if first_param is not None else None

        preds, confs, labels = [], [], []
        for data, label in data_loader:
            if device is not None:
                data = data.to(device)
            pred, conf = self.postprocess(net, data)
            preds.append(pred.cpu())
            confs.append(conf.cpu())
            labels.append(torch.as_tensor(label).cpu())

        if not preds:
            return (np.array([], dtype=int), np.array([], dtype=float),
                    np.array([], dtype=int))

        # float32 -> float64 is exact, matching the values a per-element
        # .tolist() round-trip would produce.
        pred_list = torch.cat(preds).numpy().astype(int)
        conf_list = torch.cat(confs).numpy().astype(np.float64)
        label_list = torch.cat(labels).numpy().astype(int)

        return pred_list, conf_list, label_list

### Supported Out of Distribution Detection Metrics

In [171]:
def print_formatted_metrics(metrics, dataset_name):
    """Print OOD metrics as percentages and return them as a result row.

    Args:
        metrics: sequence of nine numbers in the order
            [fpr, auroc, aupr_in, aupr_out, ccr_4, ccr_3, ccr_2, ccr_1, accuracy];
            each is multiplied by 100 for display. (The parameter name shadows
            the module-level ``metrics`` import only inside this function.)
        dataset_name: stored under the 'dataset' key of the returned dict.

    Returns:
        dict with 'dataset' plus one entry per metric, each formatted as a
        two-decimal percentage string (e.g. '97.41'), so rows for several
        datasets can be collected into a table.
    """
    (fpr, auroc, aupr_in, aupr_out,
     ccr_4, ccr_3, ccr_2, ccr_1, accuracy) = metrics

    named_values = [
        ('FPR@95', fpr),
        ('AUROC', auroc),
        ('AUPR_IN', aupr_in),
        ('AUPR_OUT', aupr_out),
        ('CCR_4', ccr_4),
        ('CCR_3', ccr_3),
        ('CCR_2', ccr_2),
        ('CCR_1', ccr_1),
        ('ACC', accuracy),
    ]
    write_content = {'dataset': dataset_name}
    write_content.update(
        (name, '{:.2f}'.format(100 * value)) for name, value in named_values)

    # print ood metric results
    print('FPR@95: {:.2f}, AUROC: {:.2f}'.format(100 * fpr, 100 * auroc),
          end=' ',
          flush=True)
    print('AUPR_IN: {:.2f}, AUPR_OUT: {:.2f}'.format(
        100 * aupr_in, 100 * aupr_out),
          flush=True)
    print('CCR: {:.2f}, {:.2f}, {:.2f}, {:.2f},'.format(
        ccr_4 * 100, ccr_3 * 100, ccr_2 * 100, ccr_1 * 100),
          end=' ',
          flush=True)
    print('ACC: {:.2f}'.format(accuracy * 100), flush=True)
    print(u'\u2500' * 70, flush=True)

    return write_content

In [167]:
class ResNet50(ResNet):
    """torchvision ``ResNet`` configured as ResNet-50, with extra forward modes.

    Adds ``return_feature`` / ``return_feature_list`` options to ``forward``
    and a ``forward_threshold`` that clips the pooled feature before ``fc``.

    Args:
        block: residual block class forwarded to ``ResNet``.
        layers: number of blocks in each of the four stages; ``(3, 4, 6, 3)``
            with ``Bottleneck`` gives ResNet-50. A tuple default avoids one
            mutable list being shared by every call.
        num_classes: output size of ``fc``.
    """

    def __init__(self,
                 block=Bottleneck,
                 layers=(3, 4, 6, 3),
                 num_classes=1000):
        super().__init__(block=block,
                         layers=layers,
                         num_classes=num_classes)
        # Penultimate feature width for the default Bottleneck configuration.
        self.feature_size = 2048

    def _feature_list(self, x):
        """Run the backbone through global average pooling; return every stage's output."""
        feature1 = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        feature2 = self.layer1(feature1)
        feature3 = self.layer2(feature2)
        feature4 = self.layer3(feature3)
        # Still 4-D here (pooled); flattened by the callers.
        feature5 = self.avgpool(self.layer4(feature4))
        return [feature1, feature2, feature3, feature4, feature5]

    def forward(self, x, return_feature=False, return_feature_list=False):
        """Compute class logits.

        Returns:
            ``logits`` by default; ``(logits, flattened_pooled_feature)`` when
            ``return_feature`` is set (it takes precedence); otherwise
            ``(logits, feature_list)`` when ``return_feature_list`` is set.
        """
        feature_list = self._feature_list(x)
        feature5 = feature_list[-1]
        feature = feature5.view(feature5.size(0), -1)
        logits_cls = self.fc(feature)

        if return_feature:
            return logits_cls, feature
        elif return_feature_list:
            return logits_cls, feature_list
        else:
            return logits_cls

    def forward_threshold(self, x, threshold):
        """Compute logits after clipping the pooled feature at ``threshold`` (upper bound only)."""
        feature = self._feature_list(x)[-1].clip(max=threshold)
        feature = feature.view(feature.size(0), -1)
        return self.fc(feature)

    def get_fc(self):
        """Return the final layer's (weight, bias) as detached numpy arrays."""
        fc = self.fc
        return fc.weight.cpu().detach().numpy(), fc.bias.cpu().detach().numpy()

### Full OODN Flow On LeNet5

### MNIST Training and In-Distribution Test Dataset

In [168]:
# download=True lets the notebook survive Restart & Run All on a fresh machine;
# torchvision skips the download when the files already exist under data/.
train_dataset = mnist.MNIST(root='data', download=True, train=True, transform=ToTensor())
test_dataset = mnist.MNIST(root='data', download=True, train=False, transform=ToTensor())

batch_size = 128

# Reshuffle the training set every epoch; evaluation keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

### Training LeNet5 On MNIST

In [153]:
# Seed so weight init and the shuffled batch order are reproducible.
torch.manual_seed(0)

model = LeNet(num_classes=10, num_channel=1, activation='softplus')
sgd = SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0005)
loss_fn = CrossEntropyLoss()
all_epoch = 5
save_every = 25  # checkpoint interval in epochs; the final epoch is always saved

os.makedirs('models', exist_ok=True)

for current_epoch in range(all_epoch):
    model.train()
    for idx, (train_x, train_label) in enumerate(train_loader):
        sgd.zero_grad()
        predict_y = model(train_x.float())
        loss = loss_fn(predict_y, train_label.long())
        if idx % 100 == 0:
            print('idx: {}, loss: {}'.format(idx, loss.item()))
        loss.backward()
        sgd.step()

    # Test-set accuracy; no autograd graph is needed for evaluation.
    all_correct_num = 0
    all_sample_num = 0
    model.eval()
    with torch.no_grad():
        for test_x, test_label in test_loader:
            predict_y = model(test_x.float()).argmax(dim=-1)
            all_correct_num += (predict_y == test_label).sum().item()
            all_sample_num += test_label.shape[0]
    acc = all_correct_num / all_sample_num
    print('accuracy: {:.2f}'.format(acc))

    # Save every `save_every` completed epochs and always after the last one,
    # so the trained model (not just an early checkpoint) is on disk.
    is_last_epoch = current_epoch == all_epoch - 1
    if (current_epoch + 1) % save_every == 0 or is_last_epoch:
        torch.save(model, 'models/mnist_{:.2f}.pkl'.format(acc))

idx: 0, loss: 2.423340082168579
idx: 100, loss: 2.3033783435821533
idx: 200, loss: 2.309399366378784
idx: 300, loss: 2.059486150741577
idx: 400, loss: 0.45519348978996277
accuracy: 0.86
idx: 0, loss: 0.42277392745018005
idx: 100, loss: 0.31062746047973633
idx: 200, loss: 0.186587855219841
idx: 300, loss: 0.14288684725761414
idx: 400, loss: 0.527208149433136
accuracy: 0.95
idx: 0, loss: 0.15106496214866638
idx: 100, loss: 0.1727137267589569
idx: 200, loss: 0.152151957154274
idx: 300, loss: 0.11541634798049927
idx: 400, loss: 0.4183972179889679
accuracy: 0.96
idx: 0, loss: 0.1316715031862259
idx: 100, loss: 0.055190905928611755
idx: 200, loss: 0.1544199436903
idx: 300, loss: 0.08166653662919998
idx: 400, loss: 0.2765331268310547
accuracy: 0.98
idx: 0, loss: 0.09350304305553436
idx: 100, loss: 0.051850125193595886
idx: 200, loss: 0.16075889766216278
idx: 300, loss: 0.08697907626628876
idx: 400, loss: 0.23928365111351013
accuracy: 0.97


### Loading OOD Dataset for MNIST - FashionMNIST

In [154]:
# The FashionMNIST test split is the out-of-distribution set for the MNIST model.
fashion_test_dataset = FashionMNIST(
    root='data',
    train=False,
    download=True,
    transform=ToTensor(),
)
fashion_test_loader = DataLoader(fashion_test_dataset, batch_size=batch_size)

In [157]:
# ODIN hyper-parameters: softmax temperature and input-perturbation magnitude.
temperature = 1000
noise = 0.0014
postprocessor = ODINPostprocessor(temperature=temperature, noise=noise)

In [158]:
# Score the in-distribution (MNIST test) set first, then the OOD (FashionMNIST) set.
id_pred, id_conf, id_gt = postprocessor.inference(model, test_loader)
ood_pred, ood_conf, ood_gt = postprocessor.inference(model, fashion_test_loader)

In [159]:
# Every OOD sample gets label -1 regardless of its FashionMNIST class.
ood_gt = np.full_like(ood_gt, -1)

# ID samples first, then OOD, for predictions, confidences and labels alike.
pred = np.concatenate((id_pred, ood_pred))
conf = np.concatenate((id_conf, ood_conf))
label = np.concatenate((id_gt, ood_gt))

ood_metrics = metrics.compute_all_metrics(conf, label, pred)

In [160]:
print_formatted_metrics(ood_metrics, dataset_name='fashion_mnist');

FPR@95: 9.65, AUROC: 98.02 AUPR_IN: 98.27, AUPR_OUT: 97.70
CCR: 14.22, 48.70, 80.68, 94.02, ACC: 97.41
──────────────────────────────────────────────────────────────────────


### iWildCam/WILDS Dataset

In [77]:
from wilds import get_dataset
from wilds.common.data_loaders import get_train_loader
import torchvision.transforms as transforms

In [78]:
# Fetch (on first run) and open the iWildCam dataset from the WILDS benchmark.
dataset = get_dataset(dataset="iwildcam", root_dir="data", download=True)

In [80]:
# Shared preprocessing for every split: resize to 448x448, then convert to a tensor.
iwildcam_transform = transforms.Compose(
    [transforms.Resize((448, 448)), transforms.ToTensor()]
)

# The ResNet is trained on the "train" split; training on "id_test" would
# leak the in-distribution test images into training.
train_data = dataset.get_subset("train", transform=iwildcam_transform)

# In-distribution test split, kept separate for ID evaluation.
id_test_data = dataset.get_subset("id_test", transform=iwildcam_transform)

8154