### Exploring The Impact of Optimizers and Activation Functions On OODN 

In [75]:
from os import listdir

import time
import json
import copy

import torch

from torchvision.datasets import mnist, FashionMNIST, CIFAR10, CIFAR100
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torch.optim import SGD, Adam
from torch.nn import Module
from torch import nn
from torch.nn import CrossEntropyLoss
from torchvision.models.resnet import Bottleneck, ResNet
from torchvision import datasets, models, transforms

from wilds import get_dataset
from wilds.common.data_loaders import get_train_loader
import torchvision.transforms as transforms

import numpy as np
import pandas as pd

from openood.evaluators import metrics

pd.set_option('display.max_rows', 500)

In [9]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [55]:
MODELS_DIR = '/home/rdr2143/oodn-final-project/OpenOOD-nndl/models/'

In [61]:
class LeNet(nn.Module):
    def __init__(self, num_classes, num_channel=3, activation='relu'):
        super(LeNet, self).__init__()
        self.num_classes = num_classes
        self.feature_size = 84
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=num_channel,
                      out_channels=6,
                      kernel_size=5,
                      stride=1,
                      padding=2), get_activation_fn(activation), nn.MaxPool2d(kernel_size=2))

        self.block2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
             get_activation_fn(activation), nn.MaxPool2d(kernel_size=2))

        self.block3 = nn.Sequential(
            nn.Conv2d(in_channels=16,
                      out_channels=120,
                      kernel_size=5,
                      stride=1), get_activation_fn(activation))

        self.classifier1 = nn.Linear(in_features=120, out_features=84)
        self.relu = get_activation_fn(activation)
        self.fc = nn.Linear(in_features=84, out_features=num_classes)

    def get_fc(self):
        fc = self.fc
        return fc.weight.cpu().detach().numpy(), fc.bias.cpu().detach().numpy()

    def forward(self, x, return_feature=False, return_feature_list=False):
        feature1 = self.block1(x)
        feature2 = self.block2(feature1)
        feature3 = self.block3(feature2)
        feature3 = feature3.view(feature3.shape[0], -1)
        feature = self.relu(self.classifier1(feature3))
        logits_cls = self.fc(feature)
        feature_list = [feature1, feature2, feature3, feature]
        if return_feature:
            return logits_cls, feature
        elif return_feature_list:
            return logits_cls, feature_list
        else:
            return logits_cls

    def forward_threshold(self, x, threshold):
        feature1 = self.block1(x)
        feature2 = self.block2(feature1)
        feature3 = self.block3(feature2)
        feature3 = feature3.view(feature3.shape[0], -1)
        feature = self.relu(self.classifier1(feature3))
        feature = feature.clip(max=threshold)
        logits_cls = self.fc(feature)

        return logits_cls

### Supported Activation Functions

For activation functions, we are considering ReLU, Softplus, Swish. *Note that we may conduct experiments for a subset based on the compute resources available*

### Supported Post-Hoc OODN Processors

#### The first post processor we consider is ODIN

In [27]:
class OODPostprocessor():
    
    def inference(self, net: nn.Module, data_loader: DataLoader):
        pred_list, conf_list, label_list = [], [], []
        for idx, loaded_data in enumerate(data_loader):
            data, label = loaded_data[0], loaded_data[1]
            if idx % 50 == 0:
                print(f'Performing inference on batch: {idx}')
            pred, conf = self.postprocess(net, data.to(device))
            for idx in range(len(data)):
                pred_list.append(pred[idx].tolist())
                conf_list.append(conf[idx].tolist())
                label_list.append(label[idx].tolist())

        # convert values into numpy array
        pred_list = np.array(pred_list, dtype=int)
        conf_list = np.array(conf_list)
        label_list = np.array(label_list, dtype=int)

        return pred_list, conf_list, label_list

In [28]:
class ODINPostprocessor(OODPostprocessor):
    def __init__(self, temperature, noise):
        super(OODPostprocessor)
        self.temperature = temperature
        self.noise = noise
        
    def postprocess(self, net: nn.Module, data):
        net.eval()
        data.requires_grad = True
        output = net(data)

        # Calculating the perturbation we need to add, that is,
        # the sign of gradient of cross entropy loss w.r.t. input
        criterion = nn.CrossEntropyLoss()

        labels = output.detach().argmax(axis=1)

        # Using temperature scaling
        output = output / self.temperature

        loss = criterion(output, labels)
        loss.backward()

        # Normalizing the gradient to binary in {0, 1}
        gradient = torch.ge(data.grad.detach(), 0)
        gradient = (gradient.float() - 0.5) * 2

        # Scaling values taken from original code       
        gradient[:, 0] = (gradient[:, 0]) / (63.0 / 255.0)
        if gradient.shape[1] == 3:
            gradient[:, 1] = (gradient[:, 1]) / (62.1 / 255.0)
            gradient[:, 2] = (gradient[:, 2]) / (66.7 / 255.0)

        # Adding small perturbations to images
        tempInputs = torch.add(data.detach(), gradient, alpha=-self.noise)
        output = net(tempInputs)
        output = output / self.temperature

        # Calculating the confidence after adding perturbations
        nnOutput = output.detach()
        nnOutput = nnOutput - nnOutput.max(dim=1, keepdims=True).values
        nnOutput = nnOutput.exp() / nnOutput.exp().sum(dim=1, keepdims=True)

        conf, pred = nnOutput.max(dim=1)

        return pred, conf

#### We consider the Maximum Classifier Discrepancy Post OODN method

https://arxiv.org/pdf/1712.02560.pdf

In [29]:
class MCDPostprocessor(OODPostprocessor):
    def __init__(self, samples: int = 30):
        super(OODPostprocessor)
        self.samples = samples  #: number :math:`N` of samples

    def postprocess(self, model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        mode_switch = False
        if not model.training:
            mode_switch = True

            model.train()

            for mod in model.modules():
                # reset batch norm layers.
                # TODO: are there other layers?
                if isinstance(mod, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                    mod.train(False)

        results = None
        with torch.no_grad():
            for i in range(self.samples):
                output = model(x).softmax(dim=1)
                if results is None:
                    results = torch.zeros(size=output.shape).to(device)
                results += output
        results /= self.samples

        if mode_switch:
            model.eval()
        
        conf, pred = results.max(dim=1)

        return pred, conf

In [30]:
def get_postprocessor(postprocessor_type="odin"):
    if postprocessor_type == "odin":
        postprocessor = ODINPostprocessor(1000, 0.0014)
    elif postprocessor_type == "mcd":
        postprocessor = MCDPostprocessor(30)
    return postprocessor

### Supported Out of Distribution Detection Metrics

What metrics do we specifically care about here?

**FPR@95** measures the false positive rate (FPR) when the true positive rate (TPR) is
equal to 95%. Lower scores indicate better performance. 

**AUROC** measures the area under the
Receiver Operating Characteristic (ROC) curve, which displays the relationship between TPR and
FPR. The area under the ROC curve can be interpreted as the probability that a positive ID example
will have a higher detection score than a negative OOD example. 

**AUPR** measures the area under
the Precision-Recall (PR) curve. The PR curve is created by plotting precision versus recall. Similar
to AUROC, we consider ID samples as positive, so that the score corresponds to the AUPR-In metric
in some works

In [31]:
def calculate_oodn_metrics(model, postprocessor_type, id_test_loader, ood_test_loader, ood_name):
    postprocessor = get_postprocessor(postprocessor_type)
    id_pred, id_conf, id_gt = postprocessor.inference(
                model, id_test_loader)

    ood_pred, ood_conf, ood_gt = postprocessor.inference(
        model, ood_test_loader)

    ood_gt = -1 * np.ones_like(ood_gt)  # hard set to -1 as ood
    pred = np.concatenate([id_pred, ood_pred])
    conf = np.concatenate([id_conf, ood_conf])
    label = np.concatenate([id_gt, ood_gt])
    ood_metrics = metrics.compute_all_metrics(conf, label, pred)

    return print_and_get_formatted_metrics(ood_metrics, ood_name)

def print_and_get_formatted_metrics(metrics, dataset_name):
    [fpr, auroc, aupr_in, aupr_out,
     ccr_4, ccr_3, ccr_2, ccr_1, accuracy] \
     = metrics

    write_content = {
        'dataset': dataset_name,
        'FPR@95': '{:.2f}'.format(100 * fpr),
        'AUROC': '{:.2f}'.format(100 * auroc),
        'AUPR_IN': '{:.2f}'.format(100 * aupr_in),
        'AUPR_OUT': '{:.2f}'.format(100 * aupr_out),
        'CCR_4': '{:.2f}'.format(100 * ccr_4),
        'CCR_3': '{:.2f}'.format(100 * ccr_3),
        'CCR_2': '{:.2f}'.format(100 * ccr_2),
        'CCR_1': '{:.2f}'.format(100 * ccr_1),
        'ACC': '{:.2f}'.format(100 * accuracy)
    }

    fieldnames = list(write_content.keys())

    # print ood metric results
    print('FPR@95: {:.2f}, AUROC: {:.2f}'.format(100 * fpr, 100 * auroc),
          end=' ',
          flush=True)
    print('AUPR_IN: {:.2f}, AUPR_OUT: {:.2f}'.format(
        100 * aupr_in, 100 * aupr_out),
          flush=True)
    print('CCR: {:.2f}, {:.2f}, {:.2f}, {:.2f},'.format(
        ccr_4 * 100, ccr_3 * 100, ccr_2 * 100, ccr_1 * 100),
          end=' ',
          flush=True)
    print('ACC: {:.2f}'.format(accuracy * 100), flush=True)
    print(u'\u2500' * 70, flush=True)
    return write_content

def load_results_into_df(dir_path):
    res_files = [dir_path+each for each in listdir(dir_path)]
    all_results = []
    columns = ['optimizer_type', 'activation_function_type', 'postprocessor_type', 'trial', 'AUROC', 'ACC']
    for fp in res_files:
        f = open(fp)
        data = json.load(f)
        for trial, results in data.items():
            all_results.append([
                    results['optimizer_type'],
                    results['activation_function_type'],
                    results['postprocessor_type'],
                    trial,
                    float(results['AUROC']),
                    float(results['ACC'])
                ])
    df = pd.DataFrame(all_results, columns=columns)
    return df

In [44]:
def get_wilds_loader(dataset, split, batch_size):
    d = dataset.get_subset(
        split,
        # frac=0.1,
        transform=transforms.Compose(
            [transforms.Resize((448, 448)), transforms.ToTensor()]
        ),
    )
    # Prepare the standard data loader
    return get_train_loader("standard", d, batch_size=batch_size, num_workers=4)

def get_data_loader_for_oodn_analysis(dataset_name, batch_size):
    if dataset_name == "cifar10":
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        test_dataset = CIFAR10(root='data', download=True, train=False, transform=transform_test)
        return DataLoader(test_dataset, batch_size=batch_size, num_workers=4)
    elif dataset_name == "cifar100":
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        test_dataset = CIFAR100(root='data', download=True, train=False, transform=transform_test)
        return DataLoader(test_dataset, batch_size=batch_size, num_workers=4)
    elif dataset_name == "fashionmnist":
        test_dataset = mnist.FashionMNIST(root='data', download=True,train=False,transform=ToTensor())
        return DataLoader(test_dataset, batch_size=batch_size, num_workers=4)
    elif dataset_name == "mnist":
        test_dataset = mnist.MNIST(root='data', download=True, train=False, transform=ToTensor())
        return DataLoader(test_dataset, batch_size=batch_size, num_workers=4)
    elif dataset_name == "iwildcam_id":
        dataset = get_dataset(dataset="iwildcam", download=True)
        return get_wilds_loader(dataset, "id_test", batch_size)
    elif dataset_name == "iwildcam_ood":
        dataset = get_dataset(dataset="iwildcam", download=True)
        return get_wilds_loader(dataset, "test", batch_size)
    return None

In [58]:
def perform_oodn_analysis_on_preloaded_model(model_path, dataset_name, postprocessor_type, id_test, ood_test):
    print(model_path)
    model = torch.load(model_path)
    return calculate_oodn_metrics(model,
                                   postprocessor_type,
                                   id_test,
                                   ood_test,
                                   dataset_name)

def get_model_details_from_path(path):
    dataset, model_type, postprocessor_type, activation, optimizer, trial = path.split('/')[-1].replace(".pkl", "").split("_")
    return dataset, model_type, activation, optimizer, trial



In [91]:
def calculate_ood_test_and_get_results_row(model_path,
                                           id_test_dataset_name,
                                           ood_test_dataset_name,
                                           postprocessors=['odin'],
                                          batch_size=16):
    # TODO: make this efficient by just reusing data loaders between successive runs
    id_loaders = get_data_loader_for_oodn_analysis(id_test_dataset_name, batch_size)
    ood_loaders = get_data_loader_for_oodn_analysis(ood_test_dataset_name, batch_size)
    dataset, model_type, activation, optimizer, trial = get_model_details_from_path(model_path)
    
    rows = []
    for postprocessor_type in postprocessors:
        stats = perform_oodn_analysis_on_preloaded_model(model_path,
                                                         dataset,
                                                         postprocessor_type, 
                                                         id_loaders,
                                                         ood_loaders)
        rows.append([dataset,
                     model_type, 
                     activation,
                     optimizer,
                     int(trial), 
                     ood_test_dataset_name,
                     postprocessor_type,
                     float(stats['AUROC']),
                     float(stats['ACC'])
                    ])

    return rows

In [95]:
def run_oodn_analysis_for_all_models():
    results = []
    for m_name in listdir(MODELS_DIR):
        if "cifar" in m_name:
            id_name, ood_name = 'cifar10', 'cifar100'
        if "mnist" in m_name:
            if "mcd" in m_name: continue
            id_name, ood_name = 'mnist', 'fashionmnist'
        m_path = MODELS_DIR+m_name
        results.extend(calculate_ood_test_and_get_results_row(m_path,
                                               id_name, 
                                               ood_name,
                                               postprocessors=['odin', 'mcd'], 
                                               batch_size=128))
    return results

In [96]:
columns  = ['Dataset', 'Model Type', 'Activation', 
            'Optimizer', 'Trial',
            'OOD Test Name', 
            'PostProcessor', 
            'AUROC', 'ACC']
df = pd.DataFrame(run_oodn_analysis_for_all_models(), columns=columns)

/home/rdr2143/oodn-final-project/OpenOOD-nndl/models/mnist_lenet_odin_softplus_Adam_2.pkl
Performing inference on batch: 0
Performing inference on batch: 50
Performing inference on batch: 0
Performing inference on batch: 50
FPR@95: 9.58, AUROC: 98.39 AUPR_IN: 98.59, AUPR_OUT: 98.28
CCR: 55.18, 69.09, 84.15, 94.21, ACC: 97.71
──────────────────────────────────────────────────────────────────────
/home/rdr2143/oodn-final-project/OpenOOD-nndl/models/mnist_lenet_odin_softplus_Adam_2.pkl
Performing inference on batch: 0
Performing inference on batch: 50
Performing inference on batch: 0
Performing inference on batch: 50
FPR@95: 18.78, AUROC: 97.16 AUPR_IN: 97.60, AUPR_OUT: 96.84
CCR: 47.59, 59.13, 77.00, 91.65, ACC: 97.71
──────────────────────────────────────────────────────────────────────
Files already downloaded and verified
Files already downloaded and verified
/home/rdr2143/oodn-final-project/OpenOOD-nndl/models/cifar_resnet50_odin_softplus_SGD_2.pkl
Performing inference on batch: 0
Pe

CCR: 54.79, 72.82, 87.49, 94.39, ACC: 98.21
──────────────────────────────────────────────────────────────────────
/home/rdr2143/oodn-final-project/OpenOOD-nndl/models/mnist_lenet_odin_relu_SGD_1.pkl
Performing inference on batch: 0
Performing inference on batch: 50
Performing inference on batch: 0
Performing inference on batch: 50
FPR@95: 1.03, AUROC: 99.62 AUPR_IN: 99.66, AUPR_OUT: 99.58
CCR: 49.92, 86.67, 94.39, 97.63, ACC: 98.19
──────────────────────────────────────────────────────────────────────
/home/rdr2143/oodn-final-project/OpenOOD-nndl/models/mnist_lenet_odin_relu_SGD_1.pkl
Performing inference on batch: 0
Performing inference on batch: 50
Performing inference on batch: 0
Performing inference on batch: 50
FPR@95: 4.22, AUROC: 98.93 AUPR_IN: 99.10, AUPR_OUT: 98.79
CCR: 44.01, 79.28, 90.61, 96.16, ACC: 98.19
──────────────────────────────────────────────────────────────────────
Files already downloaded and verified
Files already downloaded and verified
/home/rdr2143/oodn-fina

Performing inference on batch: 50
FPR@95: 90.52, AUROC: 65.58 AUPR_IN: 66.67, AUPR_OUT: 61.92
CCR: 0.00, 0.00, 0.00, 24.61, ACC: 82.83
──────────────────────────────────────────────────────────────────────
Files already downloaded and verified
Files already downloaded and verified
/home/rdr2143/oodn-final-project/OpenOOD-nndl/models/cifar_resnet50_odin_softplus_Adam_0.pkl
Performing inference on batch: 0
Performing inference on batch: 50
Performing inference on batch: 0
Performing inference on batch: 50
FPR@95: 75.67, AUROC: 81.23 AUPR_IN: 82.73, AUPR_OUT: 77.64
CCR: 0.08, 1.76, 20.75, 54.69, ACC: 85.32
──────────────────────────────────────────────────────────────────────
/home/rdr2143/oodn-final-project/OpenOOD-nndl/models/cifar_resnet50_odin_softplus_Adam_0.pkl
Performing inference on batch: 0
Performing inference on batch: 50
Performing inference on batch: 0
Performing inference on batch: 50
FPR@95: 90.59, AUROC: 67.11 AUPR_IN: 68.14, AUPR_OUT: 62.69
CCR: 0.00, 0.00, 0.00, 29.05, A

In [97]:
df.to_csv('oodn_odin_mcd_results_v1.csv')

In [98]:
df.groupby(['Dataset', 'Activation', 'Optimizer', 'PostProcessor'])['AUROC'].describe()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,count,mean,std,min,25%,50%,75%,max
Dataset,Activation,Optimizer,PostProcessor,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
cifar,relu,Adam,mcd,3.0,65.246667,3.021661,61.77,64.25,66.73,66.985,67.24
cifar,relu,Adam,odin,3.0,80.293333,2.805928,77.06,79.395,81.73,81.91,82.09
cifar,relu,SGD,mcd,3.0,70.603333,0.544457,70.01,70.365,70.72,70.9,71.08
cifar,relu,SGD,odin,3.0,78.37,0.922605,77.49,77.89,78.29,78.81,79.33
cifar,softplus,Adam,mcd,3.0,66.416667,0.775005,65.58,66.07,66.56,66.835,67.11
cifar,softplus,Adam,odin,3.0,80.206667,0.887938,79.64,79.695,79.75,80.49,81.23
cifar,softplus,SGD,mcd,3.0,64.21,0.860174,63.54,63.725,63.91,64.545,65.18
cifar,softplus,SGD,odin,3.0,74.6,4.620855,69.46,72.695,75.93,77.17,78.41
mnist,relu,Adam,mcd,3.0,98.606667,0.718076,98.11,98.195,98.28,98.855,99.43
mnist,relu,Adam,odin,3.0,99.47,0.338083,99.26,99.275,99.29,99.575,99.86


In [99]:
df

Unnamed: 0,Dataset,Model Type,Activation,Optimizer,Trial,OOD Test Name,PostProcessor,AUROC,ACC
0,mnist,lenet,softplus,Adam,2,fashionmnist,odin,98.39,97.71
1,mnist,lenet,softplus,Adam,2,fashionmnist,mcd,97.16,97.71
2,cifar,resnet50,softplus,SGD,2,cifar100,odin,78.41,75.58
3,cifar,resnet50,softplus,SGD,2,cifar100,mcd,65.18,76.6
4,cifar,resnet50,relu,SGD,1,cifar100,odin,77.49,78.18
5,cifar,resnet50,relu,SGD,1,cifar100,mcd,70.01,76.36
6,cifar,resnet50,relu,Adam,2,cifar100,odin,81.73,84.94
7,cifar,resnet50,relu,Adam,2,cifar100,mcd,66.73,82.64
8,cifar,resnet50,relu,SGD,0,cifar100,odin,78.29,78.59
9,cifar,resnet50,relu,SGD,0,cifar100,mcd,70.72,77.19
