### Exploring The Impact of Optimizers and Activation Functions On OODN 

In [159]:
import time

import torch

from torchvision.datasets import mnist, FashionMNIST
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torch.optim import SGD, Adam
from torch.nn import Module
from torch import nn
from torch.nn import CrossEntropyLoss
from torchvision.models.resnet import Bottleneck, ResNet

import numpy as np

from openood.evaluators import metrics

In [61]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


### Supported Activation Functions

For activation functions, we are considering ReLU, Softplus, Swish. *Note that we may conduct experiments for a subset based on the compute resources available*

In [4]:
def get_activation_fn(activation):
    if activation == 'relu':
        return nn.ReLU()
    elif activation == 'softplus':
        return nn.Softplus()
    elif activation == 'swish':
        return nn.Swish()
    return None

### Supported Networks

Currently, we support LeNet and ResNet50.

### LeNet

In [5]:
class LeNet(nn.Module):
    def __init__(self, num_classes, num_channel=3, activation='relu'):
        super(LeNet, self).__init__()
        self.num_classes = num_classes
        self.feature_size = 84
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=num_channel,
                      out_channels=6,
                      kernel_size=5,
                      stride=1,
                      padding=2), get_activation_fn(activation), nn.MaxPool2d(kernel_size=2))

        self.block2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
             get_activation_fn(activation), nn.MaxPool2d(kernel_size=2))

        self.block3 = nn.Sequential(
            nn.Conv2d(in_channels=16,
                      out_channels=120,
                      kernel_size=5,
                      stride=1), get_activation_fn(activation))

        self.classifier1 = nn.Linear(in_features=120, out_features=84)
        self.relu = get_activation_fn(activation)
        self.fc = nn.Linear(in_features=84, out_features=num_classes)

    def get_fc(self):
        fc = self.fc
        return fc.weight.cpu().detach().numpy(), fc.bias.cpu().detach().numpy()

    def forward(self, x, return_feature=False, return_feature_list=False):
        feature1 = self.block1(x)
        feature2 = self.block2(feature1)
        feature3 = self.block3(feature2)
        feature3 = feature3.view(feature3.shape[0], -1)
        feature = self.relu(self.classifier1(feature3))
        logits_cls = self.fc(feature)
        feature_list = [feature1, feature2, feature3, feature]
        if return_feature:
            return logits_cls, feature
        elif return_feature_list:
            return logits_cls, feature_list
        else:
            return logits_cls

    def forward_threshold(self, x, threshold):
        feature1 = self.block1(x)
        feature2 = self.block2(feature1)
        feature3 = self.block3(feature2)
        feature3 = feature3.view(feature3.shape[0], -1)
        feature = self.relu(self.classifier1(feature3))
        feature = feature.clip(max=threshold)
        logits_cls = self.fc(feature)

        return logits_cls

### ResNet50

In [200]:
class ResNet50(ResNet):
    def __init__(self,
                 block=Bottleneck,
                 layers=[3, 4, 6, 3],
                 num_classes=1000):
        super(ResNet50, self).__init__(block=block,
                                       layers=layers,
                                       num_classes=num_classes)
        self.feature_size = 2048


    def forward(self, x, return_feature=False, return_feature_list=False):
        feature1 = self.relu(self.bn1(self.conv1(x)))
        feature1 = self.maxpool(feature1)
        feature2 = self.layer1(feature1)
        feature3 = self.layer2(feature2)
        feature4 = self.layer3(feature3)
        feature5 = self.layer4(feature4)
        feature5 = self.avgpool(feature5)
        feature = feature5.view(feature5.size(0), -1)
        logits_cls = self.fc(feature)

        feature_list = [feature1, feature2, feature3, feature4, feature5]
        if return_feature:
            return logits_cls, feature
        elif return_feature_list:
            return logits_cls, feature_list
        else:
            return logits_cls

    def forward_threshold(self, x, threshold):
        feature1 = self.relu(self.bn1(self.conv1(x)))
        feature1 = self.maxpool(feature1)
        feature2 = self.layer1(feature1)
        feature3 = self.layer2(feature2)
        feature4 = self.layer3(feature3)
        feature5 = self.layer4(feature4)
        feature5 = self.avgpool(feature5)
        feature = feature5.clip(max=threshold)
        feature = feature.view(feature.size(0), -1)
        logits_cls = self.fc(feature)

        return logits_cls

    def get_fc(self):
        fc = self.fc
        return fc.weight.cpu().detach().numpy(), fc.bias.cpu().detach().numpy()
    
def get_resnet_model(activation_function_type, n_classes):
    resnet_model = ResNet50(num_classes=n_classes)
    resnet_model.to(device)    
    return resnet_model

def set_activation_function(resnet_model, activation_function_type):
    resnet_model.relu = get_activation_fn(activation_function_type)
    resnet_model.layer1[0].relu = get_activation_fn(activation_function_type)
    resnet_model.layer1[1].relu = get_activation_fn(activation_function_type)
    resnet_model.layer1[2].relu = get_activation_fn(activation_function_type)

    resnet_model.layer2[0].relu = get_activation_fn(activation_function_type)
    resnet_model.layer2[1].relu = get_activation_fn(activation_function_type)
    resnet_model.layer2[2].relu = get_activation_fn(activation_function_type)
    resnet_model.layer2[3].relu = get_activation_fn(activation_function_type)

    resnet_model.layer3[0].relu = get_activation_fn(activation_function_type)
    resnet_model.layer3[1].relu = get_activation_fn(activation_function_type)
    resnet_model.layer3[2].relu = get_activation_fn(activation_function_type)
    resnet_model.layer3[3].relu = get_activation_fn(activation_function_type)
    resnet_model.layer3[4].relu = get_activation_fn(activation_function_type)
    resnet_model.layer3[5].relu = get_activation_fn(activation_function_type)


    resnet_model.layer4[0].relu = get_activation_fn(activation_function_type)
    resnet_model.layer4[1].relu = get_activation_fn(activation_function_type)
    resnet_model.layer4[2].relu = get_activation_fn(activation_function_type)
    
    return resnet_model

In [201]:
def get_model(config):
    activation_function_type = config["activation_function_type"]
    network_type = config["network"]
    n_classes = config["n_classes"]
    
    if network_type == "lenet":
        model =  LeNet(num_classes=n_classes, num_channel=1, activation=activation_function_type)
    elif network_type == "resnet50":
        model = get_resnet_model(activation_function_type, n_classes)
    else:
        raise Exception("Currently we only support lenet or resnet50")
    
    return model
    

### Supported Post-Hoc OODN Processors

#### The first post processor we consider is ODIN

In [167]:
class ODINPostprocessor():
    def __init__(self, temperature, noise):
        self.temperature = temperature
        self.noise = noise
        
    def postprocess(self, net: nn.Module, data):
        net.eval()
        data.requires_grad = True
        output = net(data)

        # Calculating the perturbation we need to add, that is,
        # the sign of gradient of cross entropy loss w.r.t. input
        criterion = nn.CrossEntropyLoss()

        labels = output.detach().argmax(axis=1)

        # Using temperature scaling
        output = output / self.temperature

        loss = criterion(output, labels)
        loss.backward()

        # Normalizing the gradient to binary in {0, 1}
        gradient = torch.ge(data.grad.detach(), 0)
        gradient = (gradient.float() - 0.5) * 2

        # Scaling values taken from original code       
        gradient[:, 0] = (gradient[:, 0]) / (63.0 / 255.0)
        if gradient.shape[1] == 3:
            gradient[:, 1] = (gradient[:, 1]) / (62.1 / 255.0)
            gradient[:, 2] = (gradient[:, 2]) / (66.7 / 255.0)

        # Adding small perturbations to images
        tempInputs = torch.add(data.detach(), gradient, alpha=-self.noise)
        output = net(tempInputs)
        output = output / self.temperature

        # Calculating the confidence after adding perturbations
        nnOutput = output.detach()
        nnOutput = nnOutput - nnOutput.max(dim=1, keepdims=True).values
        nnOutput = nnOutput.exp() / nnOutput.exp().sum(dim=1, keepdims=True)

        conf, pred = nnOutput.max(dim=1)

        return pred, conf
    
    def inference(self, net: nn.Module, data_loader: DataLoader):
        pred_list, conf_list, label_list = [], [], []
        for idx, loaded_data in enumerate(data_loader):
            data, label = loaded_data[0], loaded_data[1]
            if idx % 5 == 0:
                print(f'Performing inference on batch: {idx}')
            pred, conf = self.postprocess(net, data.to(device))
            for idx in range(len(data)):
                pred_list.append(pred[idx].tolist())
                conf_list.append(conf[idx].tolist())
                label_list.append(label[idx].tolist())

        # convert values into numpy array
        pred_list = np.array(pred_list, dtype=int)
        conf_list = np.array(conf_list)
        label_list = np.array(label_list, dtype=int)

        return pred_list, conf_list, label_list

#### We consider the Maximum Classifier Discrepancy Post OODN method

https://arxiv.org/pdf/1712.02560.pdf

In [None]:
class MCDPostprocessor(BasePostprocessor):
    @torch.no_grad()
    def postprocess(self, net: nn.Module, data: Any):
        logits1, logits2 = net(data, return_double=True)
        score1 = torch.softmax(logits1, dim=1)
        score2 = torch.softmax(logits2, dim=1)
        conf = -torch.sum(torch.abs(score1 - score2), dim=1)
        _, pred = torch.max(score1, dim=1)
        return pred, conf

In [169]:
def get_postprocessor(postprocessor_type="odin"):
    if postprocessor_type == "odin":
        postprocessor = ODINPostprocessor(1000, 0.0014)
    return postprocessor

### Supported Out of Distribution Detection Metrics

What metrics do we specifically care about here?

**FPR@95** measures the false positive rate (FPR) when the true positive rate (TPR) is
equal to 95%. Lower scores indicate better performance. 

**AUROC** measures the area under the
Receiver Operating Characteristic (ROC) curve, which displays the relationship between TPR and
FPR. The area under the ROC curve can be interpreted as the probability that a positive ID example
will have a higher detection score than a negative OOD example. 

**AUPR** measures the area under
the Precision-Recall (PR) curve. The PR curve is created by plotting precision versus recall. Similar
to AUROC, we consider ID samples as positive, so that the score corresponds to the AUPR-In metric
in some works

In [170]:
def calculate_oodn_metrics(model, postprocessor_type, id_test_loader, ood_test_loader, ood_name):
    postprocessor = get_postprocessor(postprocessor_type)
    id_pred, id_conf, id_gt = postprocessor.inference(
                model, id_test_loader)

    ood_pred, ood_conf, ood_gt = postprocessor.inference(
        model, ood_test_loader)

    ood_gt = -1 * np.ones_like(ood_gt)  # hard set to -1 as ood
    pred = np.concatenate([id_pred, ood_pred])
    conf = np.concatenate([id_conf, ood_conf])
    label = np.concatenate([id_gt, ood_gt])
    ood_metrics = metrics.compute_all_metrics(conf, label, pred)

    print_formatted_metrics(ood_metrics, ood_name)

def print_formatted_metrics(metrics, dataset_name):
    [fpr, auroc, aupr_in, aupr_out,
     ccr_4, ccr_3, ccr_2, ccr_1, accuracy] \
     = metrics

    write_content = {
        'dataset': dataset_name,
        'FPR@95': '{:.2f}'.format(100 * fpr),
        'AUROC': '{:.2f}'.format(100 * auroc),
        'AUPR_IN': '{:.2f}'.format(100 * aupr_in),
        'AUPR_OUT': '{:.2f}'.format(100 * aupr_out),
        'CCR_4': '{:.2f}'.format(100 * ccr_4),
        'CCR_3': '{:.2f}'.format(100 * ccr_3),
        'CCR_2': '{:.2f}'.format(100 * ccr_2),
        'CCR_1': '{:.2f}'.format(100 * ccr_1),
        'ACC': '{:.2f}'.format(100 * accuracy)
    }

    fieldnames = list(write_content.keys())

    # print ood metric results
    print('FPR@95: {:.2f}, AUROC: {:.2f}'.format(100 * fpr, 100 * auroc),
          end=' ',
          flush=True)
    print('AUPR_IN: {:.2f}, AUPR_OUT: {:.2f}'.format(
        100 * aupr_in, 100 * aupr_out),
          flush=True)
    print('CCR: {:.2f}, {:.2f}, {:.2f}, {:.2f},'.format(
        ccr_4 * 100, ccr_3 * 100, ccr_2 * 100, ccr_1 * 100),
          end=' ',
          flush=True)
    print('ACC: {:.2f}'.format(accuracy * 100), flush=True)
    print(u'\u2500' * 70, flush=True)

In [191]:
def get_optimizer(model, config):
    params = model.parameters()
    lr = config['lr']
    momentum = config['momentum']
    weight_decay = config['weight_decay']
    optimizer_type = config['optimizer_type']
    
    print(f'Getting optimizer for type: {optimizer_type}...')
    if optimizer_type == 'SGD':
        return SGD(params, 
              lr=lr, 
              momentum=momentum,
              weight_decay=weight_decay)
    elif optimizer_type == 'Adam':
        return Adam(params, 
                    lr=lr, 
                    weight_decay=weight_decay)
    else:
        raise Exception("Invalid optimizer_type provided, only SGD and Adam are supported currently")
        
def get_data_loaders(config):
    data_loaders = {}
    dataset_name = config["dataset_name"]
    dataset_type = config["dataset_type"]
    batch_size = config['batch_size']
    
    if dataset_type == "wilds":
        # wilds dataset
        dataset = get_dataset(dataset=dataset_name, download=True)
        data_loaders["train"] = get_wilds_loader(dataset, "train", batch_size)
        data_loaders["ood_test"] = get_wilds_loader(dataset, "test", batch_size)
        data_loaders["id_test"] = get_wilds_loader(dataset, "id_test", batch_size)
    elif dataset_name == "mnist":
        # mnist dataset
        train_dataset = mnist.MNIST(root='data', download=False, train=True, transform=ToTensor())
        test_dataset = mnist.MNIST(root='data', download=False, train=False, transform=ToTensor())
        fashion_test_dataset = mnist.FashionMNIST(root='data', download=True,train=False,transform=ToTensor())

        data_loaders["train"] = DataLoader(train_dataset, batch_size=batch_size)
        data_loaders["id_test"] = DataLoader(test_dataset, batch_size=batch_size)
        data_loaders["ood_test"] = DataLoader(fashion_test_dataset, batch_size=batch_size)
    
    return data_loaders

In [172]:
def train_resnet_model_given_opti_activation_fn(config):
    # get the train loader
    train_loader = config["data_loaders"]["train"]
    
    # get the resnet model with the replaced activation functions
    model = get_model(config)
    model.to(device)
    
    # get the optimizer
    sgd = get_optimizer(model, config)
    
    loss_fn = CrossEntropyLoss()

    for current_epoch in range(config['epochs']):
        model.train()
        print('Training epoch: {}'.format(current_epoch))
        for idx, (loader_data) in enumerate(train_loader):
            train_x, train_label = loader_data[0].to(device), loader_data[1].to(device)
            sgd.zero_grad()
            predict_y = model(train_x.float())
            loss = loss_fn(predict_y, train_label.long())
            if idx % 100 == 0:
                print('idx: {}, loss: {}'.format(idx, loss.sum().item()))
            loss.backward()
            sgd.step()
    
    model_name = f"models/{config['network']}_{config['dataset_name']}_{config['version']}.pkl"
    torch.save(model, model_name)
    return model

def run_full_oodn_pipeline(config):
    # train model
    model = train_resnet_model_given_opti_activation_fn(config)
    # calculate oodn metrics
    calculate_oodn_metrics(model,
                           "odin", 
                           config["data_loaders"]["id_test"], 
                           config["data_loaders"]["ood_test"], 
                           config["dataset_name"])

### Study 1: LeNet5, MNIST as the ID Dataset, FashionMNIST as the OOD Dataset

#### Study 1(a): Combination: SGD + ReLU

In [195]:
config_1a = {
    "batch_size": 128,
    "n_classes": 10,
    "dataset_name": "mnist",
    "epochs": 10,
    "version": time.time(),
    "lr": 0.1,
    "momentum": 0.9,
    "weight_decay": 0.0005,
    "optimizer_type": "SGD",
    "activation_function_type": "relu",
    "network": "lenet",
    "postprocessor_type": "odin",
    "dataset_type": "mnist"
}
config_1a["data_loaders"] = get_data_loaders(config_1a)
run_full_oodn_pipeline(config_1a)

Getting optimizer for type: SGD...
Training epoch: 0
idx: 0, loss: 2.297888994216919
idx: 100, loss: 2.308368682861328
idx: 200, loss: 2.304520606994629
idx: 300, loss: 2.302668333053589
idx: 400, loss: 2.3048624992370605
Training epoch: 1
idx: 0, loss: 2.29297137260437
idx: 100, loss: 2.3060293197631836
idx: 200, loss: 2.3049848079681396
idx: 300, loss: 2.3015921115875244
idx: 400, loss: 2.3056366443634033
Training epoch: 2
idx: 0, loss: 2.293072462081909
idx: 100, loss: 2.306029796600342
idx: 200, loss: 2.3049848079681396
idx: 300, loss: 2.301593780517578
idx: 400, loss: 2.3056366443634033
Training epoch: 3
idx: 0, loss: 2.293072462081909
idx: 100, loss: 2.306029796600342
idx: 200, loss: 2.3049848079681396
idx: 300, loss: 2.3015921115875244
idx: 400, loss: 2.3056366443634033
Training epoch: 4
idx: 0, loss: 2.293072462081909
idx: 100, loss: 2.306029796600342
idx: 200, loss: 2.3049848079681396
idx: 300, loss: 2.3015921115875244
idx: 400, loss: 2.3056318759918213
Training epoch: 5
idx: 

#### 1 (b) SGD + SoftPlus

In [194]:
config_1b = {
    "batch_size": 128,
    "dataset_name": "mnist",
    "n_classes": 10,
    "epochs": 10,
    "version": time.time(),
    "lr": 0.1,
    "momentum": 0.9,
    "weight_decay": 0.0005,
    "optimizer_type": "SGD",
    "activation_function_type": "softplus",
    "network": "lenet",
    "postprocessor_type": "odin",
    "dataset_type": "mnist"
}
config_1b["data_loaders"] = get_data_loaders(config_1b)
run_full_oodn_pipeline(config_1b)

Getting optimizer for type: SGD...
Training epoch: 0
idx: 0, loss: 2.376993417739868
idx: 100, loss: 2.3054940700531006
idx: 200, loss: 1.3798028230667114
idx: 300, loss: 0.3979167938232422
idx: 400, loss: 0.3049400746822357
Training epoch: 1
idx: 0, loss: 0.09874945133924484
idx: 100, loss: 0.1565113067626953
idx: 200, loss: 0.16769079864025116
idx: 300, loss: 0.07754665613174438
idx: 400, loss: 0.22025398910045624
Training epoch: 2
idx: 0, loss: 0.12479684501886368
idx: 100, loss: 0.08266792446374893
idx: 200, loss: 0.12717781960964203
idx: 300, loss: 0.06364835053682327
idx: 400, loss: 0.16900688409805298
Training epoch: 3
idx: 0, loss: 0.09916403144598007
idx: 100, loss: 0.03581171855330467
idx: 200, loss: 0.14071375131607056
idx: 300, loss: 0.0668136402964592
idx: 400, loss: 0.15953058004379272
Training epoch: 4
idx: 0, loss: 0.10201045125722885
idx: 100, loss: 0.034900370985269547
idx: 200, loss: 0.1346389651298523
idx: 300, loss: 0.06171699985861778
idx: 400, loss: 0.11651628464

#### 1(c) Adam + ReLU

In [193]:
config_1c = {
    "batch_size": 128,
    "dataset_name": "mnist",
    "n_classes": 10,
    "epochs": 10,
    "version": time.time(),
    "lr": 0.01,
    "momentum": 0.9,
    "weight_decay": 0.0005,
    "optimizer_type": "Adam",
    "activation_function_type": "relu",
    "network": "lenet",
    "postprocessor_type": "odin",
    "dataset_type": "mnist"
}
config_1c["data_loaders"] = get_data_loaders(config_1c)
run_full_oodn_pipeline(config_1c)

Getting optimizer for type: Adam...
Training epoch: 0
idx: 0, loss: 2.2955210208892822
idx: 100, loss: 0.1766306608915329
idx: 200, loss: 0.1108548566699028
idx: 300, loss: 0.10212960094213486
idx: 400, loss: 0.3330700695514679
Training epoch: 1
idx: 0, loss: 0.08501432090997696
idx: 100, loss: 0.17185954749584198
idx: 200, loss: 0.09786579757928848
idx: 300, loss: 0.08116117119789124
idx: 400, loss: 0.22271095216274261
Training epoch: 2
idx: 0, loss: 0.05359487608075142
idx: 100, loss: 0.05140208080410957
idx: 200, loss: 0.06912437081336975
idx: 300, loss: 0.07847478985786438
idx: 400, loss: 0.18062497675418854
Training epoch: 3
idx: 0, loss: 0.076697438955307
idx: 100, loss: 0.10458298027515411
idx: 200, loss: 0.1426263004541397
idx: 300, loss: 0.11266268044710159
idx: 400, loss: 0.18627998232841492
Training epoch: 4
idx: 0, loss: 0.07869116961956024
idx: 100, loss: 0.10094181448221207
idx: 200, loss: 0.05744136869907379
idx: 300, loss: 0.06453729420900345
idx: 400, loss: 0.113194130

#### 1(d) Adam + Softplus

In [196]:
config_1d = {
    "batch_size": 128,
    "dataset_name": "mnist",
    "n_classes": 10,
    "epochs": 10,
    "version": time.time(),
    "lr": 0.01,
    "momentum": 0.9,
    "weight_decay": 0.0005,
    "optimizer_type": "Adam",
    "activation_function_type": "softplus",
    "network": "lenet",
    "postprocessor_type": "odin",
    "dataset_type": "mnist"
}
config_1d["data_loaders"] = get_data_loaders(config_1d)
run_full_oodn_pipeline(config_1d)

Getting optimizer for type: Adam...
Training epoch: 0
idx: 0, loss: 2.27415132522583
idx: 100, loss: 2.308807849884033
idx: 200, loss: 0.7072268724441528
idx: 300, loss: 0.18639841675758362
idx: 400, loss: 0.3177991807460785
Training epoch: 1
idx: 0, loss: 0.11948294937610626
idx: 100, loss: 0.06955642253160477
idx: 200, loss: 0.11207076907157898
idx: 300, loss: 0.07860124111175537
idx: 400, loss: 0.22918754816055298
Training epoch: 2
idx: 0, loss: 0.07632873207330704
idx: 100, loss: 0.05131511762738228
idx: 200, loss: 0.10534033179283142
idx: 300, loss: 0.08326361328363419
idx: 400, loss: 0.19460807740688324
Training epoch: 3
idx: 0, loss: 0.05745804309844971
idx: 100, loss: 0.07289870083332062
idx: 200, loss: 0.08870033919811249
idx: 300, loss: 0.1120220348238945
idx: 400, loss: 0.21049143373966217
Training epoch: 4
idx: 0, loss: 0.08756529539823532
idx: 100, loss: 0.06747358292341232
idx: 200, loss: 0.15172752737998962
idx: 300, loss: 0.09393009543418884
idx: 400, loss: 0.1624705791

### Study 2: Resnet50, WILDS iwildcam as the Dataset

In [None]:
config_2a = {
    "batch_size": 16,
    "n_classes": 2,
    "dataset_name": "camelyon17",
    "epochs": 10,
    "version": time.time(),
    "lr": 0.01,
    "momentum": 0.9,
    "weight_decay": 0.0005,
    "optimizer_type": "SGD",
    "activation_function_type": "softplus",
    "network": "resnet50",
    "postprocessor_type": "odin",
    "dataset_type": "wilds"
}
config_2a["data_loaders"] = get_data_loaders(config_2a)
# run_full_oodn_pipeline(config_1d)

Downloading dataset to data/camelyon17_v1.0...
You can also download the dataset manually at https://wilds.stanford.edu/downloads.
Downloading https://worksheets.codalab.org/rest/bundles/0xe45e15f39fb54e9d9e919556af67aabe/contents/blob/ to data/camelyon17_v1.0/archive.tar.gz


  0%|          | 0/10658709504 [00:00<?, ?Byte/s]

In [None]:
run_full_oodn_pipeline(config_2a)