### Steps:

* Train a LeNet-5 model on MNIST
* Apply the ODIN postprocessor with MNIST as in-distribution and FashionMNIST as OOD data
* Compute the OOD evaluation metrics
* Load the iWildCam dataset
* Train a ResNet-50 on it
* Apply the ODIN postprocessor

In [112]:
import os
import time

import numpy as np
import torch
from torch import nn
from torch.nn import CrossEntropyLoss, Module
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader
from torchvision.datasets import mnist, FashionMNIST
from torchvision.models.resnet import Bottleneck, ResNet
from torchvision.transforms import ToTensor

from openood.evaluators import metrics

In [61]:
# Prefer the first GPU when one is available; later cells move models and batches to ``device``.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

cuda:0


### Supported Activation Functions

We consider three activation functions: ReLU, Softplus, and Swish (implemented in PyTorch as `nn.SiLU`). *Note that, depending on the compute resources available, we may run experiments for only a subset of them.*

In [4]:
def get_activation_fn(activation):
    """Return a new activation module for the given name.

    Args:
        activation: one of ``'relu'``, ``'softplus'`` or ``'swish'``.

    Returns:
        A freshly constructed ``nn.Module``. Every call creates a new instance,
        so the result can be registered in several places of one network.

    Raises:
        ValueError: if ``activation`` is not a supported name.
    """
    if activation == 'relu':
        return nn.ReLU()
    if activation == 'softplus':
        return nn.Softplus()
    if activation == 'swish':
        # PyTorch has no ``nn.Swish``; Swish (beta=1) is provided as ``nn.SiLU``.
        return nn.SiLU()
    # Fail fast: a None activation would only blow up later inside forward().
    raise ValueError(
        f"Unsupported activation {activation!r}; expected 'relu', 'softplus' or 'swish'")

### Supported Networks

Currently, we support LeNet and ResNet50.

In [5]:
class LeNet(nn.Module):
    """LeNet-5 style CNN with a configurable activation function.

    Args:
        num_classes: number of output logits.
        num_channel: number of input image channels (1 for MNIST).
        activation: name passed to ``get_activation_fn``.

    NOTE(review): ``classifier1`` expects 120 flattened features, which holds
    for 28x28 inputs (e.g. MNIST); other sizes presumably need a different
    first Linear layer.
    """

    def __init__(self, num_classes, num_channel=3, activation='relu'):
        super().__init__()
        self.num_classes = num_classes
        # Width of the penultimate feature that feeds ``self.fc``.
        self.feature_size = 84

        conv1 = nn.Conv2d(in_channels=num_channel, out_channels=6,
                          kernel_size=5, stride=1, padding=2)
        self.block1 = nn.Sequential(conv1, get_activation_fn(activation),
                                    nn.MaxPool2d(kernel_size=2))

        conv2 = nn.Conv2d(in_channels=6, out_channels=16,
                          kernel_size=5, stride=1)
        self.block2 = nn.Sequential(conv2, get_activation_fn(activation),
                                    nn.MaxPool2d(kernel_size=2))

        conv3 = nn.Conv2d(in_channels=16, out_channels=120,
                          kernel_size=5, stride=1)
        self.block3 = nn.Sequential(conv3, get_activation_fn(activation))

        self.classifier1 = nn.Linear(in_features=120, out_features=84)
        # Named ``relu`` but holds whichever activation was requested.
        self.relu = get_activation_fn(activation)
        self.fc = nn.Linear(in_features=84, out_features=num_classes)

    def get_fc(self):
        """Return the final linear layer's ``(weight, bias)`` as NumPy arrays."""
        weight = self.fc.weight.cpu().detach().numpy()
        bias = self.fc.bias.cpu().detach().numpy()
        return weight, bias

    def _trunk(self, x):
        """Return (block1 out, block2 out, flattened block3 out, penultimate feature)."""
        out1 = self.block1(x)
        out2 = self.block2(out1)
        out3 = self.block3(out2)
        flat3 = out3.view(out3.shape[0], -1)
        penultimate = self.relu(self.classifier1(flat3))
        return out1, out2, flat3, penultimate

    def forward(self, x, return_feature=False, return_feature_list=False):
        """Return logits, optionally with the penultimate feature or all stage features."""
        stages = self._trunk(x)
        penultimate = stages[-1]
        logits = self.fc(penultimate)
        if return_feature:
            return logits, penultimate
        if return_feature_list:
            return logits, list(stages)
        return logits

    def forward_threshold(self, x, threshold):
        """Return logits computed from penultimate features clipped at ``threshold``."""
        penultimate = self._trunk(x)[-1]
        return self.fc(penultimate.clip(max=threshold))

In [108]:
class ResNet50(ResNet):
    """torchvision ResNet that can also return intermediate features.

    ``forward`` optionally returns the pooled penultimate feature or the list
    of stage outputs. ``forward_threshold`` clips the pooled features at
    ``threshold`` before the classifier.
    """

    def __init__(self,
                 block=Bottleneck,
                 layers=[3, 4, 6, 3],
                 num_classes=1000):
        super().__init__(block=block, layers=layers, num_classes=num_classes)
        # Width of the pooled feature fed to ``self.fc`` (default Bottleneck setup).
        self.feature_size = 2048

    def _stage_outputs(self, x):
        """Return [stem, layer1, layer2, layer3, avgpool(layer4)] activations."""
        stem = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        out1 = self.layer1(stem)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        pooled = self.avgpool(self.layer4(out3))
        return [stem, out1, out2, out3, pooled]

    def forward(self, x, return_feature=False, return_feature_list=False):
        """Return logits, optionally with the flattened pooled feature or all stage outputs."""
        stages = self._stage_outputs(x)
        pooled = stages[-1]
        feature = pooled.view(pooled.size(0), -1)
        logits_cls = self.fc(feature)
        if return_feature:
            return logits_cls, feature
        if return_feature_list:
            return logits_cls, stages
        return logits_cls

    def forward_threshold(self, x, threshold):
        """Return logits computed from pooled features clipped at ``threshold``."""
        clipped = self._stage_outputs(x)[-1].clip(max=threshold)
        return self.fc(clipped.view(clipped.size(0), -1))

    def get_fc(self):
        """Return the final linear layer's ``(weight, bias)`` as NumPy arrays."""
        weight = self.fc.weight.cpu().detach().numpy()
        bias = self.fc.bias.cpu().detach().numpy()
        return weight, bias
    
def get_resnet_model(activation_function_type, num_classes=None):
    """Build a ResNet50 with its activations replaced and move it to ``device``.

    Args:
        activation_function_type: name understood by ``get_activation_fn``;
            applied to every activation via ``set_activation_function``.
        num_classes: number of output classes. If None, falls back to the
            notebook-global ``dataset.n_classes``, which must then be defined
            before calling.

    Returns:
        The ResNet50 model on ``device``.
    """
    if num_classes is None:
        num_classes = dataset.n_classes  # notebook-global; see docstring
    resnet_model = ResNet50(num_classes=num_classes)
    # The requested activation used to be ignored; apply it to the whole network.
    resnet_model = set_activation_function(resnet_model, activation_function_type)
    resnet_model.to(device)
    return resnet_model

def set_activation_function(resnet_model, activation_function_type):
    """Replace the activations of a torchvision-style ResNet in place.

    The stem's ``relu`` and the ``relu`` attribute of every block in
    ``layer1``..``layer4`` each receive a fresh module from
    ``get_activation_fn``. Iterating over the stages avoids hard-coding the
    ResNet-50 block counts [3, 4, 6, 3], so other depths work too.

    Args:
        resnet_model: model exposing ``relu`` and ``layer1``..``layer4``.
        activation_function_type: name understood by ``get_activation_fn``.

    Returns:
        The same ``resnet_model`` object, modified in place.
    """
    resnet_model.relu = get_activation_fn(activation_function_type)
    stages = (resnet_model.layer1, resnet_model.layer2,
              resnet_model.layer3, resnet_model.layer4)
    for stage in stages:
        for block in stage:
            block.relu = get_activation_fn(activation_function_type)
    return resnet_model

In [134]:
def get_model(config):
    """Instantiate the network named in ``config``.

    Args:
        config: dict with ``"network"`` (``"lenet"`` or ``"resnet50"``) and
            ``"activation_function_type"``. For LeNet, the optional keys
            ``"num_classes"`` (default 10) and ``"num_channels"`` (default 1)
            override the MNIST defaults.

    Returns:
        The constructed model.

    Raises:
        ValueError: for an unsupported network.
    """
    activation_function_type = config["activation_function_type"]
    network_type = config["network"]

    if network_type == "lenet":
        return LeNet(num_classes=config.get("num_classes", 10),
                     num_channel=config.get("num_channels", 1),
                     activation=activation_function_type)
    if network_type == "resnet50":
        return get_resnet_model(activation_function_type)
    raise ValueError("Currently we only support lenet or resnet50")
    

### Supported Post-Hoc OODN Processors

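#### The first postprocessor we consider is ODIN

ODIN scores each input by its maximum softmax probability after temperature scaling and a small input perturbation that raises the confidence of the predicted class.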

In [145]:
class ODINPostprocessor:
    """ODIN post-hoc OOD scorer: temperature scaling plus input perturbation.

    For each batch, the input is nudged against the sign of the gradient of the
    temperature-scaled cross-entropy loss, taken w.r.t. the network's own
    predicted label. The confidence score is the maximum softmax probability of
    the perturbed, temperature-scaled output.

    Args:
        temperature: softmax temperature applied to the logits.
        noise: perturbation magnitude (step size on the gradient sign).
        input_std: per-channel divisors applied to the gradient sign, taken from
            the original ODIN code. Only the first ``len(input_std)`` channels
            are scaled, and the default scales channel 0 only. For 3-channel
            inputs, pass ``(63.0 / 255.0, 62.1 / 255.0, 66.7 / 255.0)``.
    """

    def __init__(self, temperature, noise, input_std=(63.0 / 255.0,)):
        self.temperature = temperature
        self.noise = noise
        self.input_std = tuple(input_std)

    def postprocess(self, net: nn.Module, data):
        """Score one batch and return ``(pred, conf)``, one entry per sample.

        Neither ``data`` nor the parameter gradients of ``net`` are modified.
        """
        net.eval()
        # Use a fresh leaf tensor: on CPU ``.to(device)`` returns the caller's
        # tensor, so flagging ``data`` itself would leak requires_grad back.
        inputs = data.detach().clone().requires_grad_(True)
        output = net(inputs)

        # Use the network's own prediction as the label.
        labels = output.detach().argmax(dim=1)
        criterion = nn.CrossEntropyLoss()
        loss = criterion(output / self.temperature, labels)
        # Only the input gradient is needed. autograd.grad avoids populating
        # (and accumulating into) the parameters' .grad buffers on every batch.
        (input_grad,) = torch.autograd.grad(loss, inputs)

        # Map the gradient sign to {-1, +1} (zero maps to +1).
        gradient = (torch.ge(input_grad, 0).float() - 0.5) * 2

        # Per-channel scaling values taken from the original ODIN code.
        for channel, std in enumerate(self.input_std[:gradient.shape[1]]):
            gradient[:, channel] = gradient[:, channel] / std

        # Step against the loss gradient, i.e. towards higher confidence in the
        # predicted class. The scoring pass needs no autograd graph.
        perturbed = torch.add(inputs.detach(), gradient, alpha=-self.noise)
        with torch.no_grad():
            scaled_logits = net(perturbed) / self.temperature
        probs = torch.softmax(scaled_logits, dim=1)

        conf, pred = probs.max(dim=1)
        return pred, conf

    def inference(self, net: nn.Module, data_loader: DataLoader, device=None):
        """Run ``postprocess`` over every batch of ``data_loader``.

        Each batch is indexed as ``batch[0]`` (inputs) and ``batch[1]``
        (labels). Any further items are ignored.

        Args:
            net: model to score with.
            data_loader: any iterable of batches, e.g. a ``DataLoader``.
            device: where each input batch is moved. Defaults to the device of
                ``net``'s first parameter, or CPU if it has none.

        Returns:
            ``(pred, conf, label)`` NumPy arrays with one entry per sample,
            of dtypes int, float64 and int.
        """
        if device is None:
            first_param = next(net.parameters(), None)
            device = first_param.device if first_param is not None else torch.device('cpu')

        preds, confs, labels = [], [], []
        for batch_idx, batch in enumerate(data_loader):
            data, label = batch[0], batch[1]
            if batch_idx % 5 == 0:
                print(f'Performing inference on batch: {batch_idx}')
            pred, conf = self.postprocess(net, data.to(device))
            # Collect whole batches rather than calling .tolist() per element.
            preds.append(pred.cpu().numpy())
            confs.append(conf.cpu().numpy())
            labels.append(torch.as_tensor(label).cpu().numpy())

        if not preds:
            return np.array([], dtype=int), np.array([]), np.array([], dtype=int)
        return (np.concatenate(preds).astype(int),
                np.concatenate(confs).astype(np.float64),
                np.concatenate(labels).astype(int))

In [None]:
def get_postprocessor(postprocessor_type="odin", temperature=1000, noise=0.0014):
    """Build a post-hoc OOD postprocessor by name.

    Args:
        postprocessor_type: currently only ``"odin"``.
        temperature: ODIN softmax temperature (default 1000).
        noise: ODIN perturbation magnitude (default 0.0014).

    Returns:
        The postprocessor instance.

    Raises:
        ValueError: for an unsupported ``postprocessor_type``. Previously this
            surfaced as an UnboundLocalError.
    """
    if postprocessor_type == "odin":
        return ODINPostprocessor(temperature, noise)
    raise ValueError(
        f"Unsupported postprocessor_type {postprocessor_type!r}; only 'odin' is available")

### Supported Out of Distribution Detection Metrics

What metrics do we specifically care about here?

**FPR@95** measures the false positive rate (FPR) when the true positive rate (TPR) is
equal to 95%. Lower scores indicate better performance. 

**AUROC** measures the area under the
Receiver Operating Characteristic (ROC) curve, which displays the relationship between TPR and
FPR. The area under the ROC curve can be interpreted as the probability that a positive ID example
will have a higher detection score than a negative OOD example. 

**AUPR** measures the area under
the Precision-Recall (PR) curve, which plots precision against recall. As with
AUROC, we treat ID samples as positives, so this score corresponds to the AUPR-In metric
used in some works.

In [109]:
def calculate_oodn_metrics(model, postprocessor_type, id_test_loader, ood_test_loader, ood_name):
    """Score ID and OOD test sets with a post-hoc processor and print the metrics.

    ID samples keep their ground-truth labels. Every OOD sample is relabelled
    -1 to mark it as out-of-distribution before calling
    ``metrics.compute_all_metrics``.
    """
    processor = get_postprocessor(postprocessor_type)
    id_pred, id_conf, id_gt = processor.inference(model, id_test_loader)
    ood_pred, ood_conf, ood_gt = processor.inference(model, ood_test_loader)

    ood_gt = np.full_like(ood_gt, -1)

    all_pred = np.concatenate([id_pred, ood_pred])
    all_conf = np.concatenate([id_conf, ood_conf])
    all_label = np.concatenate([id_gt, ood_gt])
    ood_metrics = metrics.compute_all_metrics(all_conf, all_label, all_pred)

    print_formatted_metrics(ood_metrics, ood_name)

def print_formatted_metrics(metrics, dataset_name):
    """Print OOD-detection metrics as percentages, under a ``dataset_name`` header.

    Args:
        metrics: 9-element sequence in the order FPR@95, AUROC, AUPR_IN,
            AUPR_OUT, CCR_4, CCR_3, CCR_2, CCR_1, ACC. Each value is multiplied
            by 100 for display (presumably they are fractions in [0, 1]).
        dataset_name: name printed on the header line so results from several
            runs can be told apart.
    """
    [fpr, auroc, aupr_in, aupr_out,
     ccr_4, ccr_3, ccr_2, ccr_1, accuracy] = metrics

    print('Dataset: {}'.format(dataset_name), flush=True)
    print('FPR@95: {:.2f}, AUROC: {:.2f}'.format(100 * fpr, 100 * auroc),
          end=' ',
          flush=True)
    print('AUPR_IN: {:.2f}, AUPR_OUT: {:.2f}'.format(
        100 * aupr_in, 100 * aupr_out),
          flush=True)
    print('CCR: {:.2f}, {:.2f}, {:.2f}, {:.2f},'.format(
        ccr_4 * 100, ccr_3 * 100, ccr_2 * 100, ccr_1 * 100),
          end=' ',
          flush=True)
    print('ACC: {:.2f}'.format(accuracy * 100), flush=True)
    print(u'\u2500' * 70, flush=True)

### Full OODN Flow On LeNet5

### MNIST Training and In-Distribution Test Dataset

In [168]:
# MNIST train / in-distribution test loaders for the standalone LeNet-5 run below.
# download=True only fetches the files if they are not already under data/,
# so a fresh checkout can run this cell.
train_dataset = mnist.MNIST(root='data', download=True, train=True, transform=ToTensor())
test_dataset = mnist.MNIST(root='data', download=True, train=False, transform=ToTensor())

batch_size = 128

# Reshuffle training batches every epoch (standard for SGD); keep the test order fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

### Training LeNet5 On MNIST

In [153]:
# Train LeNet-5 (softplus activations) on MNIST, reporting test accuracy after each epoch.
model = LeNet(num_classes=10, num_channel=1, activation='softplus')
model.to(device)
sgd = SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0005)
loss_fn = CrossEntropyLoss()
all_epoch = 5
save_every = 25  # checkpoint period in epochs; the final epoch is always saved
os.makedirs('models', exist_ok=True)

for current_epoch in range(all_epoch):
    model.train()
    for idx, (train_x, train_label) in enumerate(train_loader):
        train_x, train_label = train_x.to(device), train_label.to(device)
        sgd.zero_grad()
        predict_y = model(train_x.float())
        loss = loss_fn(predict_y, train_label.long())
        if idx % 100 == 0:
            print('idx: {}, loss: {}'.format(idx, loss.item()))
        loss.backward()
        sgd.step()

    all_correct_num = 0
    all_sample_num = 0
    model.eval()
    # Evaluation needs no autograd graph.
    with torch.no_grad():
        for test_x, test_label in test_loader:
            predict_y = model(test_x.to(device).float()).argmax(dim=-1).cpu()
            all_correct_num += (predict_y == test_label).sum().item()
            all_sample_num += test_label.shape[0]
    acc = all_correct_num / all_sample_num
    print('accuracy: {:.2f}'.format(acc))

    # With all_epoch < save_every the old "% 25" check only ever saved the
    # epoch-0 model; always keep the final one too.
    is_last_epoch = current_epoch == all_epoch - 1
    if current_epoch % save_every == 0 or is_last_epoch:
        torch.save(model, 'models/mnist_{:.2f}.pkl'.format(acc))

idx: 0, loss: 2.423340082168579
idx: 100, loss: 2.3033783435821533
idx: 200, loss: 2.309399366378784
idx: 300, loss: 2.059486150741577
idx: 400, loss: 0.45519348978996277
accuracy: 0.86
idx: 0, loss: 0.42277392745018005
idx: 100, loss: 0.31062746047973633
idx: 200, loss: 0.186587855219841
idx: 300, loss: 0.14288684725761414
idx: 400, loss: 0.527208149433136
accuracy: 0.95
idx: 0, loss: 0.15106496214866638
idx: 100, loss: 0.1727137267589569
idx: 200, loss: 0.152151957154274
idx: 300, loss: 0.11541634798049927
idx: 400, loss: 0.4183972179889679
accuracy: 0.96
idx: 0, loss: 0.1316715031862259
idx: 100, loss: 0.055190905928611755
idx: 200, loss: 0.1544199436903
idx: 300, loss: 0.08166653662919998
idx: 400, loss: 0.2765331268310547
accuracy: 0.98
idx: 0, loss: 0.09350304305553436
idx: 100, loss: 0.051850125193595886
idx: 200, loss: 0.16075889766216278
idx: 300, loss: 0.08697907626628876
idx: 400, loss: 0.23928365111351013
accuracy: 0.97


In [121]:
def get_optimizer(model, config):
    """Build the optimizer named by ``config['optimizer_type']`` for ``model``.

    Args:
        model: module whose parameters are optimized.
        config: dict with ``'lr'`` and ``'optimizer_type'`` (``'SGD'`` or
            ``'Adam'``). ``'weight_decay'`` and, for SGD only, ``'momentum'``
            default to 0.0 when absent.

    Returns:
        The configured ``torch.optim`` optimizer.

    Raises:
        ValueError: for an unsupported optimizer type.
    """
    lr = config['lr']
    weight_decay = config.get('weight_decay', 0.0)
    optimizer_type = config['optimizer_type']

    print(f'Getting optimizer for type: {optimizer_type}...')
    if optimizer_type == 'SGD':
        return SGD(model.parameters(),
                   lr=lr,
                   momentum=config.get('momentum', 0.0),
                   weight_decay=weight_decay)
    if optimizer_type == 'Adam':
        # Momentum does not apply to Adam; its betas keep the library defaults.
        return Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    raise ValueError("Invalid optimizer_type provided, only SGD and Adam are supported currently")
        
def get_data_loaders(config):
    """Build train, in-distribution test and OOD test loaders for a dataset.

    Args:
        config: dict with ``"dataset_name"`` (``"iwildcam"``, ``"globalwheat"``
            or ``"mnist"``) and ``"batch_size"``.

    Returns:
        dict with keys ``"train"``, ``"id_test"`` and ``"ood_test"``.

    Raises:
        ValueError: for an unsupported ``dataset_name``.
    """
    dataset_name = config["dataset_name"]
    batch_size = config['batch_size']
    data_loaders = {}

    if dataset_name in ("iwildcam", "globalwheat"):
        # WILDS datasets: the "test" split is used as OOD, "id_test" as in-distribution.
        # NOTE(review): get_dataset / get_wilds_loader are not imported in the
        # visible notebook. Presumably ``wilds.get_dataset`` and a helper defined
        # elsewhere; confirm they are in scope before running.
        dataset = get_dataset(dataset=dataset_name, download=False)
        data_loaders["train"] = get_wilds_loader(dataset, "train", batch_size)
        data_loaders["ood_test"] = get_wilds_loader(dataset, "test", batch_size)
        data_loaders["id_test"] = get_wilds_loader(dataset, "id_test", batch_size)
    elif dataset_name == "mnist":
        # MNIST is in-distribution; the FashionMNIST test set serves as OOD.
        train_dataset = mnist.MNIST(root='data', download=True, train=True, transform=ToTensor())
        test_dataset = mnist.MNIST(root='data', download=True, train=False, transform=ToTensor())
        fashion_test_dataset = mnist.FashionMNIST(root='data', download=True, train=False, transform=ToTensor())

        # Reshuffle training batches each epoch; keep evaluation order fixed.
        data_loaders["train"] = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        data_loaders["id_test"] = DataLoader(test_dataset, batch_size=batch_size)
        data_loaders["ood_test"] = DataLoader(fashion_test_dataset, batch_size=batch_size)
    else:
        # Returning {} here would only fail later with an opaque KeyError.
        raise ValueError(
            f"Unsupported dataset_name {dataset_name!r}; expected 'iwildcam', 'globalwheat' or 'mnist'")

    return data_loaders

In [149]:
def train_resnet_model_given_opti_activation_fn(config):
    """Train the network described by ``config`` and pickle it under ``models/``.

    Despite the name, this trains whatever ``get_model`` builds (LeNet or
    ResNet50) with the optimizer from ``get_optimizer``.

    Args:
        config: dict providing ``data_loaders["train"]``, ``"epochs"``,
            ``"network"``, ``"dataset_name"`` and ``"version"``, plus the keys
            read by ``get_model`` and ``get_optimizer``.

    Returns:
        The trained model, left on ``device``.
    """
    train_loader = config["data_loaders"]["train"]

    # Model with the configured activation functions.
    model = get_model(config)
    model.to(device)

    optimizer = get_optimizer(model, config)
    loss_fn = CrossEntropyLoss()

    for current_epoch in range(config['epochs']):
        model.train()
        print('Training epoch: {}'.format(current_epoch))
        for idx, batch in enumerate(train_loader):
            # Only the first two items are used; extra items (presumably WILDS
            # metadata) are ignored.
            train_x, train_label = batch[0].to(device), batch[1].to(device)
            optimizer.zero_grad()
            predict_y = model(train_x.float())
            loss = loss_fn(predict_y, train_label.long())
            if idx % 100 == 0:
                print('idx: {}, loss: {}'.format(idx, loss.item()))
            loss.backward()
            optimizer.step()

    # torch.save does not create missing directories.
    os.makedirs("models", exist_ok=True)
    model_name = f"models/{config['network']}_{config['dataset_name']}_{config['version']}.pkl"
    torch.save(model, model_name)
    return model

def run_full_oodn_pipeline(config):
    """Train the configured model, then print ODIN OOD metrics on its test loaders."""
    trained_model = train_resnet_model_given_opti_activation_fn(config)
    loaders = config["data_loaders"]
    calculate_oodn_metrics(trained_model, "odin",
                           loaders["id_test"], loaders["ood_test"],
                           config["dataset_name"])

In [152]:
# Experiment settings. "version" timestamps the saved checkpoint file name.
config = dict(
    batch_size=128,
    dataset_name="mnist",
    epochs=10,
    version=time.time(),
    lr=0.1,
    momentum=0.9,
    weight_decay=0.0005,
    optimizer_type="SGD",
    activation_function_type="relu",
    network="lenet",
)
config["data_loaders"] = get_data_loaders(config)

In [153]:
# Train the model described by ``config``; the trained model is also pickled under models/.
model = train_resnet_model_given_opti_activation_fn(config=config)

Getting optimizer for type: SGD...
Training epoch: 0
idx: 0, loss: 2.312520980834961
idx: 100, loss: 0.6028854846954346
idx: 200, loss: 0.19040091335773468
idx: 300, loss: 0.12135443836450577
idx: 400, loss: 0.35170498490333557
Training epoch: 1
idx: 0, loss: 0.1288277804851532
idx: 100, loss: 0.08615870028734207
idx: 200, loss: 0.09843085706233978
idx: 300, loss: 0.0776451826095581
idx: 400, loss: 0.14757603406906128
Training epoch: 2
idx: 0, loss: 0.12689444422721863
idx: 100, loss: 0.03204474598169327
idx: 200, loss: 0.10500169545412064
idx: 300, loss: 0.06263567507266998
idx: 400, loss: 0.19584523141384125
Training epoch: 3
idx: 0, loss: 0.04322836920619011
idx: 100, loss: 0.04996313899755478
idx: 200, loss: 0.08130590617656708
idx: 300, loss: 0.08349289745092392
idx: 400, loss: 0.15449418127536774
Training epoch: 4
idx: 0, loss: 0.05667595937848091
idx: 100, loss: 0.060950856655836105
idx: 200, loss: 0.10066015273332596
idx: 300, loss: 0.047552477568387985
idx: 400, loss: 0.168061

In [154]:
# Score the ID test loader against the OOD test loader with ODIN and print the metrics.
loaders = config["data_loaders"]
calculate_oodn_metrics(model, "odin", loaders["id_test"], loaders["ood_test"], config["dataset_name"])

Performing inference on batch: 0
Performing inference on batch: 5
Performing inference on batch: 10
Performing inference on batch: 15
Performing inference on batch: 20
Performing inference on batch: 25
Performing inference on batch: 30
Performing inference on batch: 35
Performing inference on batch: 40
Performing inference on batch: 45
Performing inference on batch: 50
Performing inference on batch: 55
Performing inference on batch: 60
Performing inference on batch: 65
Performing inference on batch: 70
Performing inference on batch: 75
Performing inference on batch: 0
Performing inference on batch: 5
Performing inference on batch: 10
Performing inference on batch: 15
Performing inference on batch: 20
Performing inference on batch: 25
Performing inference on batch: 30
Performing inference on batch: 35
Performing inference on batch: 40
Performing inference on batch: 45
Performing inference on batch: 50
Performing inference on batch: 55
Performing inference on batch: 60
Performing inferen