In [1]:
from os import listdir

import time
import json
import copy

import torch

from torchvision.datasets import mnist, FashionMNIST, CIFAR10, CIFAR100
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torch.optim import SGD, Adam
from torch.nn import Module
from torch import nn
from torch.nn import CrossEntropyLoss
from torchvision.models.resnet import Bottleneck, ResNet
from torchvision import datasets, models, transforms

from wilds import get_dataset
from wilds.common.data_loaders import get_train_loader
import torchvision.transforms as transforms

import numpy as np
import pandas as pd

from openood.evaluators import metrics

In [2]:
# Show the CUDA version reported by this PyTorch build (environment sanity check).
print(torch.version.cuda)

10.2


In [3]:
# Module-level compute device used by the model builders, the post-processors and
# the training loop below: the first CUDA GPU when available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [4]:
!pwd

/home/rdr2143/oodn-final-project/OpenOOD-nndl


### Supported Activation Functions

We consider three activation functions: ReLU, Softplus, and Swish. *Note: depending on the compute resources available, experiments may be run for only a subset of them.*

In [5]:
def get_activation_fn(activation):
    """Build a fresh activation module for the given name.

    Args:
        activation: Name of the activation; one of ``'relu'``, ``'softplus'``
            or ``'swish'``.

    Returns:
        A newly constructed ``nn.Module`` implementing the activation. A new
        instance is created on every call.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
            Previously ``None`` was returned, which only failed later (inside
            ``nn.Sequential`` or at forward time) with an unrelated error.
    """
    if activation == 'relu':
        return nn.ReLU()
    if activation == 'softplus':
        return nn.Softplus()
    if activation == 'swish':
        # torch.nn has no ``Swish`` class; Swish (x * sigmoid(x)) is ``nn.SiLU``.
        return nn.SiLU()
    raise ValueError(
        f"Unsupported activation '{activation}'; expected 'relu', 'softplus' or 'swish'"
    )

### LeNet

In [6]:
class LeNet(nn.Module):
    """LeNet-style CNN whose non-linearity is chosen by name.

    Three convolutional blocks feed a 120 -> 84 linear layer followed by the
    final 84 -> ``num_classes`` classifier. The flattened output of the third
    block must therefore contain exactly 120 features per sample.

    Args:
        num_classes: Number of output logits.
        num_channel: Number of input image channels.
        activation: Activation name forwarded to ``get_activation_fn``.
    """

    def __init__(self, num_classes, num_channel=3, activation='relu'):
        super().__init__()
        self.num_classes = num_classes
        # Width of the penultimate feature vector (output of ``classifier1``).
        self.feature_size = 84

        first_conv = nn.Conv2d(in_channels=num_channel, out_channels=6,
                               kernel_size=5, stride=1, padding=2)
        self.block1 = nn.Sequential(first_conv,
                                    get_activation_fn(activation),
                                    nn.MaxPool2d(kernel_size=2))

        second_conv = nn.Conv2d(in_channels=6, out_channels=16,
                                kernel_size=5, stride=1)
        self.block2 = nn.Sequential(second_conv,
                                    get_activation_fn(activation),
                                    nn.MaxPool2d(kernel_size=2))

        third_conv = nn.Conv2d(in_channels=16, out_channels=120,
                               kernel_size=5, stride=1)
        self.block3 = nn.Sequential(third_conv, get_activation_fn(activation))

        self.classifier1 = nn.Linear(in_features=120, out_features=84)
        # Kept under the name ``relu`` even though it may hold another activation.
        self.relu = get_activation_fn(activation)
        self.fc = nn.Linear(in_features=84, out_features=num_classes)

    def get_fc(self):
        """Return the final layer's ``(weight, bias)`` as NumPy arrays on the CPU."""
        head = self.fc
        weight = head.weight.cpu().detach().numpy()
        bias = head.bias.cpu().detach().numpy()
        return weight, bias

    def _extract(self, x):
        """Run everything before ``fc`` and return all four intermediate features."""
        out1 = self.block1(x)
        out2 = self.block2(out1)
        out3 = self.block3(out2)
        # Flatten to (batch, features) for the linear layers.
        out3 = out3.view(out3.shape[0], -1)
        penultimate = self.relu(self.classifier1(out3))
        return out1, out2, out3, penultimate

    def forward(self, x, return_feature=False, return_feature_list=False):
        """Compute class logits, optionally together with intermediate features.

        Args:
            x: Input batch.
            return_feature: When true, also return the penultimate feature.
                Takes precedence over ``return_feature_list``.
            return_feature_list: When true, also return the list
                ``[block1, block2, flattened block3, penultimate]``.

        Returns:
            The logits, or a ``(logits, extra)`` tuple when a flag is set.
        """
        out1, out2, out3, penultimate = self._extract(x)
        logits_cls = self.fc(penultimate)
        if return_feature:
            return logits_cls, penultimate
        if return_feature_list:
            return logits_cls, [out1, out2, out3, penultimate]
        return logits_cls

    def forward_threshold(self, x, threshold):
        """Compute logits after clipping the penultimate feature at ``threshold``."""
        penultimate = self._extract(x)[-1]
        clipped = penultimate.clip(max=threshold)
        return self.fc(clipped)

### ResNet50

In [7]:
def set_parameter_requires_grad(model):
    """Freeze every parameter outside ``layer4`` and ``fc``.

    Parameters whose name starts with ``layer4`` or ``fc`` are left untouched
    (they are not explicitly re-enabled); all others get
    ``requires_grad = False``.

    Args:
        model: Module whose ``named_parameters()`` are inspected in place.
    """
    untouched_prefixes = ('layer4', 'fc')
    for param_name, tensor in model.named_parameters():
        if param_name.startswith(untouched_prefixes):
            continue
        tensor.requires_grad = False

def get_resnet_model(activation_function_type, n_classes, use_pretrained=True):
    """Build a ResNet-50 with swapped activations and a new classification head.

    Args:
        activation_function_type: Activation name passed to
            ``set_activation_function``.
        n_classes: Number of outputs of the replacement ``fc`` layer.
        use_pretrained: Forwarded to ``models.resnet50(pretrained=...)``. When
            true, every parameter outside ``layer4``/``fc`` is frozen.

    Returns:
        The configured model, moved to the module-level ``device``.
    """
    resnet_model = models.resnet50(pretrained=use_pretrained)

    # If we use pretrained weights, freeze the early layers. The helper takes
    # only the model; the previous call passed an undefined ``feature_extract``
    # as a second argument and therefore always crashed on this path.
    if use_pretrained:
        set_parameter_requires_grad(resnet_model)

    set_activation_function(resnet_model, activation_function_type)

    # Replace the head last so the new layer is created with trainable
    # parameters regardless of the freezing above.
    num_ftrs = resnet_model.fc.in_features
    resnet_model.fc = nn.Linear(num_ftrs, n_classes)
    resnet_model.to(device)
    return resnet_model

def set_activation_function(resnet_model, activation_function_type):
    """Replace the ``relu`` attribute of a ResNet stem and of every residual block.

    The stem activation (``resnet_model.relu``) and the ``relu`` attribute of
    each block in ``layer1`` .. ``layer4`` receive their own fresh module from
    ``get_activation_fn``. Iterating over the blocks instead of indexing them
    one by one gives the same result for ResNet-50 and does not depend on the
    number of blocks per stage.

    Args:
        resnet_model: Model exposing ``relu`` and ``layer1`` .. ``layer4``,
            where each layer is an iterable of blocks.
        activation_function_type: Activation name for ``get_activation_fn``.

    Returns:
        The same ``resnet_model`` instance, modified in place.
    """
    resnet_model.relu = get_activation_fn(activation_function_type)

    stages = (resnet_model.layer1, resnet_model.layer2,
              resnet_model.layer3, resnet_model.layer4)
    for stage in stages:
        for block in stage:
            block.relu = get_activation_fn(activation_function_type)

    return resnet_model

In [8]:
def get_model(config):
    """Instantiate the network described by ``config``.

    Args:
        config: Mapping with the keys ``"activation_function_type"``,
            ``"network"`` (``"lenet"`` or ``"resnet50"``) and ``"n_classes"``;
            ``'pretrained'`` is additionally read for ``"resnet50"``.

    Returns:
        A ``LeNet`` built for single-channel input, or the model returned by
        ``get_resnet_model``.

    Raises:
        Exception: If ``config["network"]`` names an unsupported architecture.
    """
    activation = config["activation_function_type"]
    architecture = config["network"]
    class_count = config["n_classes"]

    if architecture == "lenet":
        return LeNet(num_classes=class_count, num_channel=1, activation=activation)

    if architecture == "resnet50":
        return get_resnet_model(activation, class_count, config['pretrained'])

    raise Exception("Currently we only support lenet or resnet50")

### Supported Post-Hoc OODN Processors

In [9]:
class OODPostprocessor():
    """Base class for post-hoc scoring: runs a model over a loader and collects results.

    Subclasses implement ``postprocess(net, data)`` and return a
    ``(pred, conf)`` pair of per-sample tensors for the batch.
    """

    def postprocess(self, net: nn.Module, data):
        """Score one batch; must be overridden by subclasses."""
        raise NotImplementedError("Subclasses must implement postprocess(net, data)")

    def inference(self, net: nn.Module, data_loader: DataLoader):
        """Apply ``postprocess`` to every batch and gather the results.

        Args:
            net: Model handed to ``postprocess``.
            data_loader: Iterable of batches whose first element is the input
                tensor and whose second element is the label tensor.

        Returns:
            Tuple ``(pred_list, conf_list, label_list)`` of NumPy arrays;
            predictions and labels are cast to ``int``.
        """
        pred_list, conf_list, label_list = [], [], []
        # A distinct name for the batch counter: the original inner loop reused
        # ``idx`` and shadowed it.
        for batch_idx, loaded_data in enumerate(data_loader):
            data, label = loaded_data[0], loaded_data[1]
            if batch_idx % 50 == 0:
                print(f'Performing inference on batch: {batch_idx}')
            pred, conf = self.postprocess(net, data.to(device))
            # Convert each batch in one go instead of element by element.
            pred_list.extend(pred.detach().cpu().tolist())
            conf_list.extend(conf.detach().cpu().tolist())
            label_list.extend(label.tolist())

        # convert values into numpy array
        pred_list = np.array(pred_list, dtype=int)
        conf_list = np.array(conf_list)
        label_list = np.array(label_list, dtype=int)

        return pred_list, conf_list, label_list

In [10]:
class ODINPostprocessor(OODPostprocessor):
    """Confidence scoring with temperature scaling and a signed-gradient input perturbation.

    Args:
        temperature: Divisor applied to the logits before the softmax.
        noise: Magnitude of the perturbation added to the input.
    """

    def __init__(self, temperature, noise):
        # ``super(OODPostprocessor)`` merely created an unbound super object and
        # initialised nothing; call the parent initialiser properly.
        super().__init__()
        self.temperature = temperature
        self.noise = noise

    def postprocess(self, net: nn.Module, data):
        """Score one batch.

        Args:
            net: Classifier; switched to eval mode here.
            data: Input batch. Its ``requires_grad`` flag is set so the loss
                gradient with respect to the input can be taken.

        Returns:
            Tuple ``(pred, conf)``: arg-max class and maximum softmax
            probability of the perturbed, temperature-scaled output.
        """
        net.eval()
        data.requires_grad = True
        output = net(data)

        # Calculating the perturbation we need to add, that is,
        # the sign of gradient of cross entropy loss w.r.t. input
        criterion = nn.CrossEntropyLoss()

        # The model's own predictions serve as the targets.
        labels = output.detach().argmax(axis=1)

        # Using temperature scaling
        output = output / self.temperature

        loss = criterion(output, labels)
        loss.backward()

        # Reduce the gradient to its sign in {-1, +1} (zero maps to +1).
        gradient = torch.ge(data.grad.detach(), 0)
        gradient = (gradient.float() - 0.5) * 2

        # Per-channel scaling values taken from original code
        gradient[:, 0] = (gradient[:, 0]) / (63.0 / 255.0)
        if gradient.shape[1] == 3:
            gradient[:, 1] = (gradient[:, 1]) / (62.1 / 255.0)
            gradient[:, 2] = (gradient[:, 2]) / (66.7 / 255.0)

        # Adding small perturbations to images (step against the gradient sign)
        tempInputs = torch.add(data.detach(), gradient, alpha=-self.noise)
        output = net(tempInputs)
        output = output / self.temperature

        # Calculating the confidence after adding perturbations; the row maximum
        # is subtracted first for a numerically stable softmax.
        nnOutput = output.detach()
        nnOutput = nnOutput - nnOutput.max(dim=1, keepdims=True).values
        nnOutput = nnOutput.exp() / nnOutput.exp().sum(dim=1, keepdims=True)

        conf, pred = nnOutput.max(dim=1)

        return pred, conf

In [11]:
class MCDPostprocessor(OODPostprocessor):
    """Average the softmax output over several stochastic forward passes.

    The model is run in training mode (batch-norm layers kept in eval mode) so
    that layers which behave randomly in training mode - presumably dropout,
    given the class name - differ between passes.

    Args:
        samples: Number of forward passes to average; must be at least 1.

    Raises:
        ValueError: If ``samples`` is smaller than 1.
    """

    def __init__(self, samples: int = 30):
        # ``super(OODPostprocessor)`` was a no-op; initialise the parent properly.
        super().__init__()
        if samples < 1:
            # With zero passes there is nothing to average and the original
            # code failed later with an opaque TypeError.
            raise ValueError(f"samples must be >= 1, got {samples}")
        self.samples = samples  #: number :math:`N` of samples

    def postprocess(self, model: torch.nn.Module, x: torch.Tensor) -> tuple:
        """Score one batch.

        Args:
            model: Classifier. If it is in eval mode it is temporarily switched
                to training mode and restored to eval mode afterwards; a model
                already in training mode is used as is.
            x: Input batch.

        Returns:
            Tuple ``(pred, conf)``: arg-max class and maximum of the averaged
            softmax probabilities.
        """
        mode_switch = not model.training
        if mode_switch:
            model.train()

            for mod in model.modules():
                # reset batch norm layers.
                # TODO: are there other layers?
                if isinstance(mod, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                    mod.train(False)

        try:
            results = None
            with torch.no_grad():
                for _ in range(self.samples):
                    output = model(x).softmax(dim=1)
                    if results is None:
                        # Allocate next to the output instead of relying on the
                        # module-level ``device``.
                        results = torch.zeros_like(output)
                    results += output
            results /= self.samples
        finally:
            # Restore eval mode even when a forward pass raises.
            if mode_switch:
                model.eval()

        conf, pred = results.max(dim=1)

        return pred, conf

In [12]:
def get_postprocessor(postprocessor_type="odin"):
    """Create the post-processor registered under ``postprocessor_type``.

    Args:
        postprocessor_type: ``"odin"`` (temperature 1000, noise 0.0014) or
            ``"mcd"`` (30 samples).

    Returns:
        A new ``ODINPostprocessor`` or ``MCDPostprocessor``.

    Raises:
        ValueError: For any other name. Previously this surfaced as an
            ``UnboundLocalError`` on the return statement.
    """
    if postprocessor_type == "odin":
        return ODINPostprocessor(1000, 0.0014)
    if postprocessor_type == "mcd":
        return MCDPostprocessor(30)
    raise ValueError(
        f"Unsupported postprocessor_type '{postprocessor_type}'; expected 'odin' or 'mcd'"
    )

### Supported Out of Distribution Detection Metrics

What metrics do we specifically care about here?

**FPR@95** measures the false positive rate (FPR) when the true positive rate (TPR) is
equal to 95%. Lower scores indicate better performance.

**AUROC** measures the area under the
Receiver Operating Characteristic (ROC) curve, which displays the relationship between TPR and
FPR. The area under the ROC curve can be interpreted as the probability that a positive ID example
will have a higher detection score than a negative OOD example.

**AUPR** measures the area under
the Precision-Recall (PR) curve. The PR curve is created by plotting precision versus recall. As
with AUROC, we treat ID samples as the positive class, so this score corresponds to what some
works call the AUPR-In metric.

In [13]:
def calculate_oodn_metrics(model, postprocessor_type, id_test_loader, ood_test_loader, ood_name):
    """Score ID and OOD test data with one post-processor and report the metrics.

    Args:
        model: Trained network passed to the post-processor.
        postprocessor_type: Name understood by ``get_postprocessor``.
        id_test_loader: Loader of in-distribution test batches.
        ood_test_loader: Loader of out-of-distribution test batches.
        ood_name: Dataset name stored in the returned record.

    Returns:
        The dictionary produced by ``print_and_get_formatted_metrics``.
    """
    scorer = get_postprocessor(postprocessor_type)

    id_pred, id_conf, id_gt = scorer.inference(model, id_test_loader)
    ood_pred, ood_conf, ood_gt = scorer.inference(model, ood_test_loader)

    # Every OOD sample gets the label -1, whatever its label in its own dataset.
    ood_gt = np.full_like(ood_gt, -1)

    all_pred = np.concatenate([id_pred, ood_pred])
    all_conf = np.concatenate([id_conf, ood_conf])
    all_label = np.concatenate([id_gt, ood_gt])

    ood_metrics = metrics.compute_all_metrics(all_conf, all_label, all_pred)
    return print_and_get_formatted_metrics(ood_metrics, ood_name)

def print_and_get_formatted_metrics(metrics, dataset_name):
    """Print the metrics as percentages and return them as formatted strings.

    Args:
        metrics: Sequence of exactly nine fractions in the order FPR,
            AUROC, AUPR-in, AUPR-out, CCR_4, CCR_3, CCR_2, CCR_1, accuracy.
        dataset_name: Value stored under the ``'dataset'`` key.

    Returns:
        Dict with ``'dataset'`` plus one two-decimal percentage string per metric.
    """
    (fpr, auroc, aupr_in, aupr_out,
     ccr_4, ccr_3, ccr_2, ccr_1, accuracy) = metrics

    def as_percent(fraction):
        return '{:.2f}'.format(100 * fraction)

    labelled_values = (
        ('FPR@95', fpr),
        ('AUROC', auroc),
        ('AUPR_IN', aupr_in),
        ('AUPR_OUT', aupr_out),
        ('CCR_4', ccr_4),
        ('CCR_3', ccr_3),
        ('CCR_2', ccr_2),
        ('CCR_1', ccr_1),
        ('ACC', accuracy),
    )
    write_content = {'dataset': dataset_name}
    for key, fraction in labelled_values:
        write_content[key] = as_percent(fraction)

    # print ood metric results
    summary_line = 'FPR@95: {:.2f}, AUROC: {:.2f}'.format(100 * fpr, 100 * auroc)
    print(summary_line, end=' ', flush=True)

    aupr_line = 'AUPR_IN: {:.2f}, AUPR_OUT: {:.2f}'.format(100 * aupr_in, 100 * aupr_out)
    print(aupr_line, flush=True)

    ccr_line = 'CCR: {:.2f}, {:.2f}, {:.2f}, {:.2f},'.format(
        ccr_4 * 100, ccr_3 * 100, ccr_2 * 100, ccr_1 * 100)
    print(ccr_line, end=' ', flush=True)

    print('ACC: {:.2f}'.format(accuracy * 100), flush=True)
    print(u'\u2500' * 70, flush=True)
    return write_content

def load_results_into_df(dir_path):
    """Collect the per-trial results stored as JSON files in a directory.

    Each file must hold a mapping ``trial -> record`` where the record has the
    keys ``optimizer_type``, ``activation_function_type``,
    ``postprocessor_type`` and ``AUROC``.

    Args:
        dir_path: Directory containing the result files; a trailing slash is
            optional.

    Returns:
        ``pandas.DataFrame`` with one row per trial and the columns
        ``optimizer_type``, ``activation_function_type``,
        ``postprocessor_type``, ``trial`` and ``AUROC`` (as ``float``).
    """
    from os.path import isfile, join as join_path

    columns = ['optimizer_type', 'activation_function_type', 'postprocessor_type', 'trial', 'AUROC']
    all_results = []
    # Sorted so the row order does not depend on the file system.
    for file_name in sorted(listdir(dir_path)):
        # Join properly: plain concatenation only worked with a trailing slash.
        fp = join_path(dir_path, file_name)
        if not isfile(fp):
            # Skip sub-directories (e.g. notebook checkpoint folders).
            continue
        # Context manager: the original never closed the file handles.
        with open(fp) as f:
            data = json.load(f)
        for trial, results in data.items():
            all_results.append([
                results['optimizer_type'],
                results['activation_function_type'],
                results['postprocessor_type'],
                trial,
                float(results['AUROC'])
            ])
    return pd.DataFrame(all_results, columns=columns)

In [25]:
def get_optimizer(model, config):
    """Create the optimizer selected by ``config['optimizer_type']``.

    Args:
        model: Module whose parameters are optimised.
        config: Mapping providing ``'lr'``, ``'momentum'``, ``'weight_decay'``
            and ``'optimizer_type'``. All four keys are read regardless of the
            optimizer; ``Adam`` only receives the learning rate.

    Returns:
        A ``torch.optim.SGD`` or ``torch.optim.Adam`` instance.

    Raises:
        Exception: If ``optimizer_type`` is neither ``'SGD'`` nor ``'Adam'``.
    """
    trainable = model.parameters()
    learning_rate = config['lr']
    sgd_momentum = config['momentum']
    sgd_weight_decay = config['weight_decay']
    kind = config['optimizer_type']

    print(f'Getting optimizer for type: {kind}...')

    if kind == 'SGD':
        return SGD(trainable,
                   lr=learning_rate,
                   momentum=sgd_momentum,
                   weight_decay=sgd_weight_decay)

    if kind == 'Adam':
        return Adam(trainable, lr=learning_rate)

    raise Exception("Invalid optimizer_type provided, only SGD and Adam are supported currently")

def get_wilds_loader(dataset, split, batch_size):
    """Build a loader over one split of a WILDS dataset.

    Images are resized to 448x448 and converted to tensors.

    Args:
        dataset: Dataset object providing ``get_subset``.
        split: Name of the split to load.
        batch_size: Batch size of the loader.

    Returns:
        The loader created by ``get_train_loader("standard", ...)`` with four
        worker processes.
    """
    preprocessing = transforms.Compose(
        [transforms.Resize((448, 448)), transforms.ToTensor()]
    )
    # NOTE: a ``frac=0.1`` argument to ``get_subset`` was left commented out in
    # the original code and is still not applied.
    subset = dataset.get_subset(split, transform=preprocessing)

    # Prepare the standard data loader
    return get_train_loader("standard", subset, batch_size=batch_size, num_workers=4)

def get_data_loaders(config):
    """Create the train, in-distribution test and OOD test loaders for an experiment.

    Args:
        config: Mapping with ``"dataset_name"``, ``"dataset_type"`` and
            ``'batch_size'``. ``dataset_type == "wilds"`` loads the named WILDS
            dataset; otherwise ``dataset_name`` must be ``"cifar"`` (CIFAR-10
            with CIFAR-100 as OOD) or ``"mnist"`` (MNIST with FashionMNIST as
            OOD).

    Returns:
        Dict with the keys ``"train"``, ``"id_test"`` and ``"ood_test"``.

    Raises:
        ValueError: If the dataset is neither a WILDS dataset nor one of the
            supported names (previously an ``UnboundLocalError``).
    """
    data_loaders = {}
    dataset_name = config["dataset_name"]
    dataset_type = config["dataset_type"]
    batch_size = config['batch_size']

    if dataset_type == "wilds":
        # wilds dataset; camelyon17 uses "id_val" as its in-distribution split here.
        wilds_id_test_split = "id_val" if dataset_name == "camelyon17" else "id_test"
        dataset = get_dataset(dataset=dataset_name, download=True)
        data_loaders["train"] = get_wilds_loader(dataset, "train", batch_size)
        data_loaders["ood_test"] = get_wilds_loader(dataset, "test", batch_size)
        data_loaders["id_test"] = get_wilds_loader(dataset, wilds_id_test_split, batch_size)
        # The original bare ``return`` handed ``None`` to the caller.
        return data_loaders

    if dataset_name == "cifar":
        cifar_mean = (0.4914, 0.4822, 0.4465)
        cifar_std = (0.2023, 0.1994, 0.2010)
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(cifar_mean, cifar_std),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cifar_mean, cifar_std),
        ])
        train_dataset = CIFAR10(root='data', download=True, train=True, transform=transform_train)
        test_dataset = CIFAR10(root='data', download=True, train=False, transform=transform_test)
        ood_test_dataset = CIFAR100(root='data', download=True, train=False, transform=transform_test)

    elif dataset_name == "mnist":
        # mnist dataset
        train_dataset = mnist.MNIST(root='data', download=True, train=True, transform=ToTensor())
        test_dataset = mnist.MNIST(root='data', download=True, train=False, transform=ToTensor())
        ood_test_dataset = mnist.FashionMNIST(root='data', download=True, train=False, transform=ToTensor())

    else:
        raise ValueError(
            f"Unsupported dataset_name '{dataset_name}' for dataset_type '{dataset_type}'; "
            "expected dataset_type 'wilds' or dataset_name 'cifar'/'mnist'"
        )

    # Shuffle the training data every epoch; the evaluation loaders keep the
    # dataset order.
    data_loaders["train"] = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    data_loaders["id_test"] = DataLoader(test_dataset, batch_size=batch_size, num_workers=4)
    data_loaders["ood_test"] = DataLoader(ood_test_dataset, batch_size=batch_size, num_workers=4)

    return data_loaders

In [26]:
def train_resnet_model_given_opti_activation_fn(config):
    """Train the model described by ``config`` and save it to disk.

    Despite the name, the architecture is whatever ``get_model(config)``
    returns.

    Args:
        config: Mapping with ``"data_loaders"`` (containing ``"train"``),
            ``'epochs'``, ``'model_name'`` (output path) and the keys needed by
            ``get_model`` and ``get_optimizer``.

    Returns:
        The trained model, which is also written to ``config['model_name']``
        with ``torch.save``.
    """
    import os

    # Create the output directory before training so that a missing folder
    # does not make ``torch.save`` fail only after all epochs have run.
    model_dir = os.path.dirname(config['model_name'])
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    # get the train loader
    train_loader = config["data_loaders"]["train"]

    # get the model with the replaced activation functions
    model = get_model(config)
    model.to(device)

    # get the optimizer (SGD or Adam, depending on the config)
    optimizer = get_optimizer(model, config)

    loss_fn = CrossEntropyLoss()
    for current_epoch in range(config['epochs']):
        tic = time.time()
        per_batch_time = time.time()
        model.train()
        print('Training epoch: {}'.format(current_epoch))
        for idx, loader_data in enumerate(train_loader):
            train_x, train_label = loader_data[0].to(device), loader_data[1].to(device)
            optimizer.zero_grad()
            predict_y = model(train_x.float())
            loss = loss_fn(predict_y, train_label.long())
            if idx % 100 == 0:
                # The loss is already a scalar (default mean reduction).
                print('idx: {}, loss: {} time take: {}'.format(idx, loss.item(), time.time() - per_batch_time))
                per_batch_time = time.time()
            loss.backward()
            optimizer.step()
        print(f"epoch {current_epoch} time taken: {time.time()-tic}s")

    # NOTE(review): this pickles the whole module object rather than a
    # state_dict, so loading requires the same class definitions to be importable.
    torch.save(model, config['model_name'])

    return model

def run_full_oodn_pipeline(config):
    """Train ``config["trials"]`` models, evaluate OOD detection for each and save the metrics.

    For every trial the model path is stored in ``config['model_name']`` (the
    config is modified in place), the model is trained, and its metrics are
    computed on the ``"id_test"`` and ``"ood_test"`` loaders.

    Args:
        config: Experiment configuration; must contain ``"trials"``,
            ``"data_loaders"``, ``'results_dir'`` and the dataset, network,
            post-processor, activation and optimizer names.

    Returns:
        Dict mapping the trial index to its metrics record. The same dict is
        written as JSON into ``config['results_dir']``.
    """
    import os

    # Create the results directory up front: otherwise a missing folder is only
    # noticed when writing the JSON file, after every trial has been trained.
    os.makedirs(config['results_dir'], exist_ok=True)

    # Named ``trial_metrics`` so the imported ``metrics`` module is not shadowed.
    trial_metrics = {}
    for i in range(config["trials"]):
        model_name = f"models/{config['dataset_name']}_{config['network']}_{config['postprocessor_type']}_{config['activation_function_type']}_{config['optimizer_type']}_{i}.pkl"
        print(f'Running model: {model_name}...')
        config['model_name'] = model_name
        # train model
        model = train_resnet_model_given_opti_activation_fn(config)
        # calculate oodn metrics
        record = calculate_oodn_metrics(model,
                                        config['postprocessor_type'],
                                        config["data_loaders"]["id_test"],
                                        config["data_loaders"]["ood_test"],
                                        config["dataset_name"])
        # Tag the record with the experiment settings (read by load_results_into_df).
        record['optimizer_type'] = config['optimizer_type']
        record['activation_function_type'] = config['activation_function_type']
        record['postprocessor_type'] = config['postprocessor_type']
        trial_metrics[i] = record

    experiment_name = f"{config['results_dir']}/{config['dataset_name']}_{config['network']}_{config['postprocessor_type']}_{config['activation_function_type']}_{config['optimizer_type']}.json"
    with open(experiment_name, 'w') as fp:
        # Integer trial keys are written as strings by ``json.dump``.
        json.dump(trial_metrics, fp)
    return trial_metrics

### Study 2: Resnet, CIFAR-10 (ID), CIFAR-100 (OOD)

#### Study 2(a): Adam + ReLU + Odin

In [29]:
# Experiment 2(a): ResNet-50 trained from scratch with Adam and ReLU, scored with ODIN.
# - "dataset_type" is not "wilds", so get_data_loaders selects the data by
#   "dataset_name" ("cifar": CIFAR-10 as ID data, CIFAR-100 as OOD data).
# - "momentum" and "weight_decay" are read by get_optimizer but only passed to
#   SGD; with "Adam" only "lr" takes effect.
# - "version" does not appear to be read by any function shown in this notebook.
config_cifar_adam_relu_odin = {
    "batch_size": 128,
    "n_classes": 10,
    "dataset_name": "cifar",
    "epochs": 100,
    "version": time.time(),
    "lr": 0.01,
    "momentum": 0.9,
    "weight_decay": 0.0005,
    "optimizer_type": "Adam",
    "activation_function_type": "relu",
    "network": "resnet50",
    "postprocessor_type": "odin",
    "trials": 3,
    "dataset_type": "cifar",
    "results_dir": "cifar10-study",
    "pretrained": False
} 
# Build the loaders once, then train and evaluate all trials with them.
config_cifar_adam_relu_odin["data_loaders"] = get_data_loaders(config_cifar_adam_relu_odin)
run_full_oodn_pipeline(config_cifar_adam_relu_odin)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Running model: models/cifar_resnet50_odin_relu_Adam_0.pkl...
Getting optimizer for type: Adam...
Training epoch: 0
idx: 0, loss: 2.429973840713501 time take: 0.2509191036224365
idx: 100, loss: 3.1445815563201904 time take: 6.605843782424927
idx: 200, loss: 3.0320160388946533 time take: 6.58167576789856
idx: 300, loss: 2.0357301235198975 time take: 6.582582235336304
epoch 0 time taken: 25.993609189987183s
Training epoch: 1
idx: 0, loss: 2.0158276557922363 time take: 0.24560880661010742
idx: 100, loss: 1.7750670909881592 time take: 6.587856769561768
idx: 200, loss: 1.7678837776184082 time take: 6.583316087722778
idx: 300, loss: 1.6960535049438477 time take: 6.652765512466431
epoch 1 time taken: 26.10966920852661s
Training epoch: 2
idx: 0, loss: 1.7007710933685303 time take: 0.45514822006225586
idx: 100, loss: 1.3905045986175537 time take: 6.6520912647247314
idx: 200, loss: 1.

idx: 200, loss: 0.597453773021698 time take: 6.604338645935059
idx: 300, loss: 0.5613834261894226 time take: 6.59797215461731
epoch 25 time taken: 26.084548950195312s
Training epoch: 26
idx: 0, loss: 0.46674519777297974 time take: 0.2747359275817871
idx: 100, loss: 0.45073235034942627 time take: 6.615034341812134
idx: 200, loss: 0.5573427677154541 time take: 6.605945348739624
idx: 300, loss: 0.5646507740020752 time take: 6.600234031677246
epoch 26 time taken: 26.07901382446289s
Training epoch: 27
idx: 0, loss: 0.40483447909355164 time take: 0.2372136116027832
idx: 100, loss: 0.4371560513973236 time take: 6.6735615730285645
idx: 200, loss: 0.6700161099433899 time take: 6.6589274406433105
idx: 300, loss: 0.5210594534873962 time take: 6.6518940925598145
epoch 27 time taken: 26.20836853981018s
Training epoch: 28
idx: 0, loss: 0.5300092101097107 time take: 0.2412855625152588
idx: 100, loss: 0.43567365407943726 time take: 6.72049355506897
idx: 200, loss: 0.5421472787857056 time take: 6.73247

idx: 200, loss: 0.3298856019973755 time take: 6.60926628112793
idx: 300, loss: 0.261945515871048 time take: 6.671645641326904
epoch 51 time taken: 26.160161018371582s
Training epoch: 52
idx: 0, loss: 0.21253448724746704 time take: 0.25517916679382324
idx: 100, loss: 0.2711976170539856 time take: 6.623281955718994
idx: 200, loss: 0.38510289788246155 time take: 6.668866395950317
idx: 300, loss: 0.34865692257881165 time take: 6.691893815994263
epoch 52 time taken: 26.37187385559082s
Training epoch: 53
idx: 0, loss: 0.18931463360786438 time take: 0.2438039779663086
idx: 100, loss: 0.1994004100561142 time take: 6.631709575653076
idx: 200, loss: 0.27953147888183594 time take: 6.607300043106079
idx: 300, loss: 0.30204665660858154 time take: 6.608463287353516
epoch 53 time taken: 26.0798819065094s
Training epoch: 54
idx: 0, loss: 0.24524739384651184 time take: 0.24196624755859375
idx: 100, loss: 0.18782228231430054 time take: 6.711690664291382
idx: 200, loss: 0.371204674243927 time take: 6.689

idx: 200, loss: 0.22645197808742523 time take: 6.610750436782837
idx: 300, loss: 0.20446710288524628 time take: 6.6332597732543945
epoch 77 time taken: 26.188644409179688s
Training epoch: 78
idx: 0, loss: 0.13726526498794556 time take: 0.2570681571960449
idx: 100, loss: 0.08217354863882065 time take: 6.6184821128845215
idx: 200, loss: 0.2705938220024109 time take: 6.608604907989502
idx: 300, loss: 0.19095657765865326 time take: 6.6150596141815186
epoch 78 time taken: 26.09790062904358s
Training epoch: 79
idx: 0, loss: 0.10920466482639313 time take: 0.2664153575897217
idx: 100, loss: 0.2774288058280945 time take: 6.624781608581543
idx: 200, loss: 0.23530609905719757 time take: 6.606082916259766
idx: 300, loss: 0.1965378075838089 time take: 6.607937812805176
epoch 79 time taken: 26.16877269744873s
Training epoch: 80
idx: 0, loss: 0.2221136838197708 time take: 0.23807144165039062
idx: 100, loss: 0.1357584148645401 time take: 6.643655300140381
idx: 200, loss: 0.21189424395561218 time take:

idx: 0, loss: 1.5828386545181274 time take: 0.25194764137268066
idx: 100, loss: 1.444257378578186 time take: 6.684164524078369
idx: 200, loss: 1.5520122051239014 time take: 6.586169958114624
idx: 300, loss: 1.530279278755188 time take: 6.585282802581787
epoch 2 time taken: 26.07736825942993s
Training epoch: 3
idx: 0, loss: 1.42722749710083 time take: 0.2662222385406494
idx: 100, loss: 1.3462961912155151 time take: 6.595289707183838
idx: 200, loss: 1.4635869264602661 time take: 6.589477777481079
idx: 300, loss: 1.854456901550293 time take: 6.588951110839844
epoch 3 time taken: 26.007307529449463s
Training epoch: 4
idx: 0, loss: 1.487541913986206 time take: 0.27080416679382324
idx: 100, loss: 1.3574988842010498 time take: 6.5935564041137695
idx: 200, loss: 1.331739902496338 time take: 6.638465642929077
idx: 300, loss: 1.2641595602035522 time take: 6.650737047195435
epoch 4 time taken: 26.124001502990723s
Training epoch: 5
idx: 0, loss: 1.29438316822052 time take: 0.23894691467285156
idx:

idx: 100, loss: 0.41825008392333984 time take: 6.608471632003784
idx: 200, loss: 0.6130558848381042 time take: 6.5978217124938965
idx: 300, loss: 0.5527329444885254 time take: 6.604417085647583
epoch 28 time taken: 26.013185024261475s
Training epoch: 29
idx: 0, loss: 0.44136032462120056 time take: 0.2511000633239746
idx: 100, loss: 0.45198312401771545 time take: 6.611998081207275
idx: 200, loss: 0.5535292625427246 time take: 6.601720809936523
idx: 300, loss: 0.48668721318244934 time take: 6.614667177200317
epoch 29 time taken: 26.20002245903015s
Training epoch: 30
idx: 0, loss: 0.44375431537628174 time take: 0.22037196159362793
idx: 100, loss: 0.4258122444152832 time take: 6.620441198348999
idx: 200, loss: 0.5633814334869385 time take: 6.603837013244629
idx: 300, loss: 0.5744293332099915 time take: 6.60388970375061
epoch 30 time taken: 26.033058404922485s
Training epoch: 31
idx: 0, loss: 0.38152432441711426 time take: 0.25177931785583496
idx: 100, loss: 0.467026025056839 time take: 6.7

idx: 100, loss: 0.1937101036310196 time take: 6.616108417510986
idx: 200, loss: 0.3188554644584656 time take: 6.603422164916992
idx: 300, loss: 0.450793981552124 time take: 6.60664701461792
epoch 54 time taken: 26.077478885650635s
Training epoch: 55
idx: 0, loss: 0.18588122725486755 time take: 0.25312280654907227
idx: 100, loss: 0.28052905201911926 time take: 6.671067476272583
idx: 200, loss: 0.6574788689613342 time take: 6.61028265953064
idx: 300, loss: 0.41383013129234314 time take: 6.603945970535278
epoch 55 time taken: 26.127240896224976s
Training epoch: 56
idx: 0, loss: 0.36356326937675476 time take: 0.2443833351135254
idx: 100, loss: 0.27254390716552734 time take: 6.695274353027344
idx: 200, loss: 0.4674862027168274 time take: 6.730895280838013
idx: 300, loss: 0.2915031611919403 time take: 6.628110647201538
epoch 56 time taken: 26.285008192062378s
Training epoch: 57
idx: 0, loss: 0.28718307614326477 time take: 0.2659947872161865
idx: 100, loss: 0.2894574999809265 time take: 6.645

idx: 100, loss: 0.15891627967357635 time take: 6.612596750259399
idx: 200, loss: 0.22067543864250183 time take: 6.641980171203613
idx: 300, loss: 0.30777692794799805 time take: 6.608111381530762
epoch 80 time taken: 26.117828607559204s
Training epoch: 81
idx: 0, loss: 0.2319861799478531 time take: 0.23457717895507812
idx: 100, loss: 0.15951944887638092 time take: 6.638277530670166
idx: 200, loss: 0.20585092902183533 time take: 6.654615640640259
idx: 300, loss: 0.20347203314304352 time take: 6.778892278671265
epoch 81 time taken: 26.372536182403564s
Training epoch: 82
idx: 0, loss: 0.19607262313365936 time take: 0.26544857025146484
idx: 100, loss: 0.1327524334192276 time take: 6.623323202133179
idx: 200, loss: 0.21612513065338135 time take: 6.649046182632446
idx: 300, loss: 0.23467859625816345 time take: 6.734400749206543
epoch 82 time taken: 26.317383527755737s
Training epoch: 83
idx: 0, loss: 0.20301899313926697 time take: 0.23670077323913574
idx: 100, loss: 0.18460744619369507 time t

epoch 4 time taken: 26.045965909957886s
Training epoch: 5
idx: 0, loss: 1.418537974357605 time take: 0.25160694122314453
idx: 100, loss: 1.340538501739502 time take: 6.595491170883179
idx: 200, loss: 1.34032142162323 time take: 6.584538459777832
idx: 300, loss: 1.3569982051849365 time take: 6.582293510437012
epoch 5 time taken: 25.981970071792603s
Training epoch: 6
idx: 0, loss: 1.370997667312622 time take: 0.2662355899810791
idx: 100, loss: 1.27835214138031 time take: 6.605718374252319
idx: 200, loss: 1.3348937034606934 time take: 6.582807302474976
idx: 300, loss: 1.2628141641616821 time take: 6.581418037414551
epoch 6 time taken: 26.008046627044678s
Training epoch: 7
idx: 0, loss: 1.2738956212997437 time take: 0.26305389404296875
idx: 100, loss: 1.1675077676773071 time take: 6.639991521835327
idx: 200, loss: 1.228621482849121 time take: 6.5817954540252686
idx: 300, loss: 1.1231369972229004 time take: 6.5856122970581055
epoch 7 time taken: 26.040780067443848s
Training epoch: 8
idx: 0,

idx: 0, loss: 0.3683387339115143 time take: 0.260467529296875
idx: 100, loss: 0.4372905492782593 time take: 6.616334438323975
idx: 200, loss: 0.5869937539100647 time take: 6.593849420547485
idx: 300, loss: 0.5844462513923645 time take: 6.597427606582642
epoch 31 time taken: 26.043166637420654s
Training epoch: 32
idx: 0, loss: 0.5211504697799683 time take: 0.23636245727539062
idx: 100, loss: 0.4015485644340515 time take: 6.608212232589722
idx: 200, loss: 0.5532274842262268 time take: 6.615742444992065
idx: 300, loss: 0.41155144572257996 time take: 6.595170259475708
epoch 32 time taken: 26.038853406906128s
Training epoch: 33
idx: 0, loss: 0.3876838684082031 time take: 0.24454259872436523
idx: 100, loss: 0.406494140625 time take: 6.612250566482544
idx: 200, loss: 0.5751639604568481 time take: 6.598246097564697
idx: 300, loss: 0.4406026601791382 time take: 6.598705053329468
epoch 33 time taken: 26.092898845672607s
Training epoch: 34
idx: 0, loss: 0.3477293848991394 time take: 0.26321244239

idx: 0, loss: 0.18580280244350433 time take: 0.2576570510864258
idx: 100, loss: 0.28699538111686707 time take: 6.617734909057617
idx: 200, loss: 0.3588714599609375 time take: 6.610241889953613
idx: 300, loss: 0.3063720166683197 time take: 6.660557985305786
epoch 57 time taken: 26.143787384033203s
Training epoch: 58
idx: 0, loss: 0.19715920090675354 time take: 0.25107359886169434
idx: 100, loss: 0.21232835948467255 time take: 6.619342565536499
idx: 200, loss: 0.3535938262939453 time take: 6.61928129196167
idx: 300, loss: 0.28335851430892944 time take: 6.613996267318726
epoch 58 time taken: 26.09619450569153s
Training epoch: 59
idx: 0, loss: 0.25357508659362793 time take: 0.25415706634521484
idx: 100, loss: 0.16154161095619202 time take: 6.747373580932617
idx: 200, loss: 0.36901170015335083 time take: 6.7278053760528564
idx: 300, loss: 0.170962855219841 time take: 6.615602254867554
epoch 59 time taken: 26.36773443222046s
Training epoch: 60
idx: 0, loss: 0.22231513261795044 time take: 0.2

idx: 0, loss: 0.17460007965564728 time take: 0.25932788848876953
idx: 100, loss: 0.13622137904167175 time take: 6.632730960845947
idx: 200, loss: 0.18141227960586548 time take: 6.61474871635437
idx: 300, loss: 0.38020291924476624 time take: 6.610538005828857
epoch 83 time taken: 26.12343430519104s
Training epoch: 84
idx: 0, loss: 0.22359351813793182 time take: 0.2522265911102295
idx: 100, loss: 0.10664388537406921 time take: 6.626541376113892
idx: 200, loss: 0.23136796057224274 time take: 6.726781606674194
idx: 300, loss: 0.20070664584636688 time take: 6.621631622314453
epoch 84 time taken: 26.217286586761475s
Training epoch: 85
idx: 0, loss: 0.08223766833543777 time take: 0.2227795124053955
idx: 100, loss: 0.09068890661001205 time take: 6.72458291053772
idx: 200, loss: 0.17354312539100647 time take: 6.652517318725586
idx: 300, loss: 0.20919522643089294 time take: 6.623441934585571
epoch 85 time taken: 26.27639412879944s
Training epoch: 86
idx: 0, loss: 0.23266887664794922 time take: 0

{0: {'dataset': 'cifar',
  'FPR@95': '74.84',
  'AUROC': '82.09',
  'AUPR_IN': '83.53',
  'AUPR_OUT': '78.47',
  'CCR_4': '0.15',
  'CCR_3': '3.41',
  'CCR_2': '20.23',
  'CCR_1': '55.44',
  'ACC': '85.13',
  'optimizer_type': 'Adam',
  'activation_function_type': 'relu',
  'postprocessor_type': 'odin'},
 1: {'dataset': 'cifar',
  'FPR@95': '79.59',
  'AUROC': '77.06',
  'AUPR_IN': '76.79',
  'AUPR_OUT': '73.75',
  'CCR_4': '0.01',
  'CCR_3': '0.02',
  'CCR_2': '9.41',
  'CCR_1': '41.89',
  'ACC': '75.03',
  'optimizer_type': 'Adam',
  'activation_function_type': 'relu',
  'postprocessor_type': 'odin'},
 2: {'dataset': 'cifar',
  'FPR@95': '74.17',
  'AUROC': '81.73',
  'AUPR_IN': '83.12',
  'AUPR_OUT': '78.55',
  'CCR_4': '2.71',
  'CCR_3': '4.16',
  'CCR_2': '18.66',
  'CCR_1': '54.58',
  'ACC': '84.94',
  'optimizer_type': 'Adam',
  'activation_function_type': 'relu',
  'postprocessor_type': 'odin'}}

#### Study 2 (b.) Adam + Softplus + odin

In [None]:
# Experiment settings for Study 2 (b.): Adam optimizer, Softplus activation and
# the "odin" postprocessor on the "cifar" dataset. Keys are inserted in the same
# order as a single literal would give, so anything that iterates or serialises
# the config sees an unchanged layout.
config_cifar_adam_softplus_odin = {}

# Dataset and training-schedule settings.
config_cifar_adam_softplus_odin.update({
    "batch_size": 128,
    "n_classes": 10,
    "dataset_name": "cifar",
    "epochs": 100,
    "version": time.time(),  # wall-clock timestamp taken when this cell runs
})

# Optimizer hyper-parameters.
# NOTE(review): "momentum" is supplied together with "optimizer_type": "Adam";
# whether the pipeline reads it for Adam is not visible from this cell -- confirm.
config_cifar_adam_softplus_odin.update({
    "lr": 0.01,
    "momentum": 0.9,
    "weight_decay": 0.0005,
    "optimizer_type": "Adam",
})

# Model, postprocessor and run bookkeeping.
config_cifar_adam_softplus_odin.update({
    "activation_function_type": "softplus",
    "network": "resnet50",
    "postprocessor_type": "odin",
    "trials": 3,
    "dataset_type": "cifar",
    "results_dir": "cifar10-study",
    "pretrained": False,
})

# Build the loaders from the config, store them back on the config under
# "data_loaders", then hand the completed config to the pipeline. The pipeline
# call stays the last expression so its return value is still the cell output.
config_cifar_adam_softplus_odin["data_loaders"] = get_data_loaders(
    config_cifar_adam_softplus_odin
)
run_full_oodn_pipeline(config_cifar_adam_softplus_odin)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Running model: models/cifar_resnet50_odin_softplus_Adam_0.pkl...
Getting optimizer for type: Adam...
Training epoch: 0
idx: 0, loss: 2.5337820053100586 time take: 0.2714509963989258
idx: 100, loss: 3.479654312133789 time take: 6.675965785980225
idx: 200, loss: 3.0372395515441895 time take: 6.696977138519287
idx: 300, loss: 2.558212995529175 time take: 6.654980897903442
epoch 0 time taken: 26.292356967926025s
Training epoch: 1
idx: 0, loss: 3.0308470726013184 time take: 0.24722886085510254
idx: 100, loss: 2.7717363834381104 time take: 6.645073890686035
idx: 200, loss: 2.4414756298065186 time take: 6.613676309585571
idx: 300, loss: 2.1312203407287598 time take: 6.613912582397461
epoch 1 time taken: 26.11988925933838s
Training epoch: 2
idx: 0, loss: 1.9130369424819946 time take: 0.23324036598205566
idx: 100, loss: 2.139862298965454 time take: 6.635451078414917
idx: 200, loss: 

idx: 200, loss: 0.6064569354057312 time take: 6.618139028549194
idx: 300, loss: 0.495060533285141 time take: 6.640892028808594
epoch 25 time taken: 26.242170333862305s
Training epoch: 26
idx: 0, loss: 0.441956490278244 time take: 0.25998806953430176
idx: 100, loss: 0.511660635471344 time take: 6.6779091358184814
idx: 200, loss: 0.6620965003967285 time take: 6.609328985214233
idx: 300, loss: 0.48995476961135864 time take: 6.6252381801605225
epoch 26 time taken: 26.173784255981445s
Training epoch: 27
idx: 0, loss: 0.4562523066997528 time take: 0.244065523147583
idx: 100, loss: 0.3325934112071991 time take: 6.6302971839904785
idx: 200, loss: 0.6206389665603638 time take: 6.615665674209595
idx: 300, loss: 0.5455806255340576 time take: 6.609037637710571
epoch 27 time taken: 26.11711287498474s
Training epoch: 28
idx: 0, loss: 0.5060380101203918 time take: 0.30780720710754395
idx: 100, loss: 0.44938430190086365 time take: 16.358985662460327
idx: 200, loss: 0.5723460912704468 time take: 13.471

idx: 200, loss: 0.31630343198776245 time take: 6.704124689102173
idx: 300, loss: 0.2205219268798828 time take: 6.7269980907440186
epoch 51 time taken: 26.419057607650757s
Training epoch: 52
idx: 0, loss: 0.19029900431632996 time take: 0.22911357879638672
idx: 100, loss: 0.14726689457893372 time take: 6.63547158241272
idx: 200, loss: 0.23855534195899963 time take: 6.621695280075073
idx: 300, loss: 0.21549226343631744 time take: 6.650598526000977
epoch 52 time taken: 26.155381441116333s
Training epoch: 53
idx: 0, loss: 0.18259534239768982 time take: 0.26244068145751953
idx: 100, loss: 0.1870933622121811 time take: 6.681147813796997
idx: 200, loss: 0.27151888608932495 time take: 6.625962018966675
idx: 300, loss: 0.24140478670597076 time take: 6.613759517669678
epoch 53 time taken: 26.179369688034058s
Training epoch: 54
idx: 0, loss: 0.14228695631027222 time take: 0.24442529678344727
idx: 100, loss: 0.25541624426841736 time take: 6.639346122741699
idx: 200, loss: 0.28225845098495483 time t

idx: 100, loss: 0.13588443398475647 time take: 6.622119426727295
idx: 200, loss: 0.2570956349372864 time take: 6.611584424972534
idx: 300, loss: 0.14335092902183533 time take: 6.614312410354614
epoch 77 time taken: 26.122233629226685s
Training epoch: 78
idx: 0, loss: 0.09631645679473877 time take: 0.24411225318908691
idx: 100, loss: 0.06118485704064369 time take: 6.6219470500946045
idx: 200, loss: 0.21508173644542694 time take: 6.6093573570251465
idx: 300, loss: 0.19602982699871063 time take: 6.613054513931274
epoch 78 time taken: 26.08490252494812s
Training epoch: 79
idx: 0, loss: 0.11522193998098373 time take: 0.24671530723571777
idx: 100, loss: 0.07653247565031052 time take: 6.651439666748047
idx: 200, loss: 0.12243413180112839 time take: 6.61201548576355
idx: 300, loss: 0.09151208400726318 time take: 6.61209511756897
epoch 79 time taken: 26.11755895614624s
Training epoch: 80
idx: 0, loss: 0.16937515139579773 time take: 0.2577834129333496
idx: 100, loss: 0.14408943057060242 time tak

epoch 1 time taken: 37.970566749572754s
Training epoch: 2
idx: 0, loss: 2.0488224029541016 time take: 0.4266514778137207
idx: 100, loss: 1.9244788885116577 time take: 11.499085664749146
idx: 200, loss: 1.9873977899551392 time take: 16.24790120124817
idx: 300, loss: 1.9369984865188599 time take: 16.11076807975769
epoch 2 time taken: 60.09878134727478s
Training epoch: 3
idx: 0, loss: 2.0564136505126953 time take: 0.3454399108886719
idx: 100, loss: 1.6248329877853394 time take: 17.435948610305786
idx: 200, loss: 1.623598337173462 time take: 17.435144424438477
idx: 300, loss: 1.790894865989685 time take: 17.470246076583862
epoch 3 time taken: 68.08431649208069s
Training epoch: 4
idx: 0, loss: 1.759663701057434 time take: 0.3084444999694824
idx: 100, loss: 1.5608335733413696 time take: 17.462621688842773
idx: 200, loss: 1.425879955291748 time take: 17.448249101638794
idx: 300, loss: 1.6865845918655396 time take: 17.416495084762573
epoch 4 time taken: 68.46527147293091s
Training epoch: 5
idx

idx: 0, loss: 0.4421517252922058 time take: 0.24152827262878418
idx: 100, loss: 0.4549048840999603 time take: 6.629119873046875
idx: 200, loss: 0.6421298980712891 time take: 6.614649057388306
idx: 300, loss: 0.573917031288147 time take: 6.614962339401245
epoch 28 time taken: 26.095964908599854s
Training epoch: 29
idx: 0, loss: 0.5052249431610107 time take: 0.24231767654418945
idx: 100, loss: 0.4861232042312622 time take: 6.627577066421509
idx: 200, loss: 0.6224702000617981 time take: 6.616735219955444
idx: 300, loss: 0.5747985243797302 time take: 6.614736080169678
epoch 29 time taken: 26.095654010772705s
Training epoch: 30
idx: 0, loss: 0.5002521872520447 time take: 0.2434849739074707
idx: 100, loss: 0.5111833810806274 time take: 6.691584825515747
idx: 200, loss: 0.6427602171897888 time take: 6.6386449337005615
idx: 300, loss: 0.42303481698036194 time take: 6.61634087562561
epoch 30 time taken: 26.188615560531616s
Training epoch: 31
idx: 0, loss: 0.4584808051586151 time take: 0.2503712

idx: 0, loss: 0.33583879470825195 time take: 0.2572667598724365
idx: 100, loss: 0.271762877702713 time take: 6.617304801940918
idx: 200, loss: 0.3425196707248688 time take: 6.668344497680664
idx: 300, loss: 0.23894907534122467 time take: 6.619114637374878
epoch 54 time taken: 26.15724468231201s
Training epoch: 55
idx: 0, loss: 0.20688733458518982 time take: 0.24809861183166504
idx: 100, loss: 0.21818581223487854 time take: 6.621826171875
idx: 200, loss: 0.2712675631046295 time take: 6.693837642669678
idx: 300, loss: 0.27271249890327454 time take: 6.656576633453369
epoch 55 time taken: 26.21820569038391s
Training epoch: 56
idx: 0, loss: 0.21178734302520752 time take: 0.23725652694702148
idx: 100, loss: 0.11650153994560242 time take: 6.624606132507324
idx: 200, loss: 0.3136546313762665 time take: 6.612377405166626
idx: 300, loss: 0.1502830982208252 time take: 6.614327907562256
epoch 56 time taken: 26.08275008201599s
Training epoch: 57
idx: 0, loss: 0.24572600424289703 time take: 0.263851

epoch 79 time taken: 38.95448398590088s
Training epoch: 80
idx: 0, loss: 0.1085602194070816 time take: 0.32114720344543457
idx: 100, loss: 0.1230112761259079 time take: 15.786655902862549
idx: 200, loss: 0.21541698276996613 time take: 17.449607849121094
idx: 300, loss: 0.19706372916698456 time take: 17.468252897262573
epoch 80 time taken: 66.85022187232971s
Training epoch: 81
idx: 0, loss: 0.2551638185977936 time take: 0.30420899391174316
idx: 100, loss: 0.07676795870065689 time take: 17.473079681396484
idx: 200, loss: 0.2054833173751831 time take: 17.051029205322266
idx: 300, loss: 0.15368445217609406 time take: 17.434929370880127
epoch 81 time taken: 68.12181353569031s
Training epoch: 82
idx: 0, loss: 0.13065050542354584 time take: 0.32121944427490234
idx: 100, loss: 0.15208721160888672 time take: 17.441019773483276
idx: 200, loss: 0.10273505002260208 time take: 17.43311047554016
idx: 300, loss: 0.20175638794898987 time take: 14.056542158126831
epoch 82 time taken: 63.99685549736023s

idx: 200, loss: 1.4984065294265747 time take: 6.61432409286499
idx: 300, loss: 1.503967046737671 time take: 6.629724502563477
epoch 4 time taken: 26.187664031982422s
Training epoch: 5
idx: 0, loss: 1.4906679391860962 time take: 0.23058271408081055
idx: 100, loss: 1.432595133781433 time take: 6.624086618423462
idx: 200, loss: 1.3607155084609985 time take: 6.612660884857178
idx: 300, loss: 1.4277511835098267 time take: 6.621292352676392
epoch 5 time taken: 26.087666749954224s
Training epoch: 6
idx: 0, loss: 1.429185390472412 time take: 0.23420476913452148
idx: 100, loss: 1.2645304203033447 time take: 6.685285329818726
idx: 200, loss: 1.4254693984985352 time take: 6.712737560272217
idx: 300, loss: 1.2264225482940674 time take: 6.617182970046997
epoch 6 time taken: 26.24608612060547s
Training epoch: 7
idx: 0, loss: 1.2745002508163452 time take: 0.26207756996154785
idx: 100, loss: 1.200764775276184 time take: 6.624682188034058
idx: 200, loss: 1.3311946392059326 time take: 6.616963624954224


idx: 200, loss: 0.5546537041664124 time take: 6.643129110336304
idx: 300, loss: 0.4124358594417572 time take: 6.632961988449097
epoch 30 time taken: 26.18768572807312s
Training epoch: 31
idx: 0, loss: 0.3269745707511902 time take: 0.2265770435333252
idx: 100, loss: 0.3166235089302063 time take: 6.628132104873657
idx: 200, loss: 0.5359580516815186 time take: 6.614710569381714
idx: 300, loss: 0.3869037926197052 time take: 6.6439878940582275
epoch 31 time taken: 26.127283811569214s
Training epoch: 32
idx: 0, loss: 0.32809117436408997 time take: 0.24917221069335938
idx: 100, loss: 0.31216713786125183 time take: 6.6151440143585205
idx: 200, loss: 0.5158452391624451 time take: 6.61550498008728
idx: 300, loss: 0.2816767394542694 time take: 6.613265752792358
epoch 32 time taken: 26.093472957611084s
Training epoch: 33
idx: 0, loss: 0.3418561816215515 time take: 0.23473906517028809
idx: 100, loss: 0.283948689699173 time take: 6.67553448677063
idx: 200, loss: 0.4619256854057312 time take: 6.73145