In [None]:
# -*- coding: utf-8 -*-
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

"""
Created on Sat Sep 19 20:55:56 2015

@author: liangshiyu
"""

from __future__ import print_function
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
#import matplotlib.pyplot as plt
import numpy as np
import time
from scipy import misc
import random

In [None]:
# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Seed the Python, NumPy and PyTorch RNGs with the same value and ask cuDNN
# for deterministic kernels so re-running the notebook repeats its numbers.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True

In [None]:
def tpr95(id, ood, id_fi, ood_fi, start, end, gap):
    """False-positive rate at ~95% true-positive rate, for ODIN and ODIN_FI.

    For every threshold ``delta`` in ``np.arange(start, end, gap)``:
    - TPR is the fraction of in-distribution scores ``>= delta``.
    - FPR is the fraction of out-of-distribution scores ``> delta``.

    The FPRs of all thresholds whose TPR lies in [0.9495, 0.9505] are
    averaged. The ``(id, ood)`` and ``(id_fi, ood_fi)`` pairs are evaluated
    independently over the same thresholds.

    Args:
        id, ood: in- / out-of-distribution scores, where higher scores
            count as in-distribution.
        id_fi, ood_fi: the same for the ODIN_FI variant.
        start, end, gap: threshold sweep, passed to ``np.arange``.

    Returns:
        ``(fpr, fpr_fi)`` as fractions in [0, 1]. An entry is ``nan`` when no
        threshold of the sweep yields a TPR of ~95%. Without this guard the
        average would divide by zero.
    """
    thresholds = np.arange(start, end, gap)

    def fpr_at_tpr95(in_scores, out_scores):
        # Average FPR over every threshold whose TPR is ~95% for this pair.
        n_in = float(len(in_scores))
        n_out = float(len(out_scores))
        fpr_sum = 0.0
        total = 0.0
        for delta in thresholds:
            tpr = np.sum(in_scores >= delta) / n_in
            if 0.9495 <= tpr <= 0.9505:
                fpr_sum += np.sum(out_scores > delta) / n_out
                total += 1
        # Number of thresholds that contributed to the average.
        print("total", total)
        if total == 0:
            # The sweep never hits TPR ~= 95%, so the metric is undefined.
            return float('nan')
        return fpr_sum / total

    return fpr_at_tpr95(id, ood), fpr_at_tpr95(id_fi, ood_fi)

In [None]:
def auroc(id, ood, id_fi, ood_fi, start, end, gap):
    """Area under the ROC curve for ODIN and ODIN_FI scores.

    The curve is integrated as a sum of rectangles over the thresholds
    ``np.arange(start, end, gap)``:
    - TPR is the fraction of ``id`` scores ``>= delta``.
    - FPR is the fraction of ``ood`` scores ``>= delta``.

    Each pair of score arrays is evaluated independently over the same
    thresholds.

    Args:
        id, ood: in- / out-of-distribution scores, where higher scores
            count as in-distribution.
        id_fi, ood_fi: the same for the ODIN_FI variant.
        start, end, gap: threshold sweep, passed to ``np.arange``.

    Returns:
        ``(auroc, auroc_fi)`` as fractions in [0, 1].

    Raises:
        ValueError: if the threshold sweep is empty.
    """
    thresholds = np.arange(start, end, gap)
    if len(thresholds) == 0:
        raise ValueError(
            f"empty threshold sweep: np.arange({start!r}, {end!r}, {gap!r})")

    def area(in_scores, out_scores):
        n_in = float(len(in_scores))
        n_out = float(len(out_scores))
        area_sum = 0.0
        prev_fpr = 1.0
        for delta in thresholds:
            tpr = np.sum(in_scores >= delta) / n_in
            fpr = np.sum(out_scores >= delta) / n_out
            # Rectangle between the previous and the current FPR at height TPR.
            area_sum += (prev_fpr - fpr) * tpr
            prev_fpr = fpr
        # Close the curve down to FPR = 0 using the last threshold's point.
        area_sum += fpr * tpr
        return area_sum

    return area(id, ood), area(id_fi, ood_fi)

In [None]:
def auprIn(id, ood, id_fi, ood_fi, start, end, gap):
    """Area under the precision-recall curve with in-distribution as positive.

    For each threshold ``delta`` in ``np.arange(start, end, gap)``:
    - ``tp`` is the fraction of ``id`` scores ``>= delta``.
    - ``fp`` is the fraction of ``ood`` scores ``>= delta``.
    - precision is ``tp / (tp + fp)`` and recall is ``tp``.

    Thresholds where nothing is predicted positive (``tp + fp == 0``) are
    skipped.

    NOTE(review): precision is built from per-set rates rather than raw
    counts, so it equals count-based precision only when
    ``len(id) == len(ood)``.

    Args:
        id, ood: in- / out-of-distribution scores.
        id_fi, ood_fi: the same for the ODIN_FI variant.
        start, end, gap: threshold sweep, passed to ``np.arange``.

    Returns:
        ``(aupr_in, aupr_in_fi)``. Each pair is evaluated with its own state.
        An entry is ``nan`` when no threshold has any positive prediction.
    """
    thresholds = np.arange(start, end, gap)

    def area(in_scores, out_scores):
        n_in = float(len(in_scores))
        n_out = float(len(out_scores))
        area_sum = 0.0
        prev_recall = 1.0
        precision = recall = None
        for delta in thresholds:
            tp = np.sum(in_scores >= delta) / n_in
            fp = np.sum(out_scores >= delta) / n_out
            if tp + fp == 0:
                # Nothing is predicted positive: precision is undefined here.
                continue
            precision = tp / (tp + fp)
            recall = tp
            area_sum += (prev_recall - recall) * precision
            prev_recall = recall
        if precision is None:
            # No usable threshold; do not borrow state from another pair.
            return float('nan')
        # Close the curve down to recall = 0 with the last precision.
        area_sum += recall * precision
        return area_sum

    return area(id, ood), area(id_fi, ood_fi)

In [None]:
def auprOut(id, ood, id_fi, ood_fi, start, end, gap):
    """Area under the precision-recall curve with out-of-distribution as positive.

    A sample counts as flagged OOD when its score is ``< delta``. Thresholds
    are swept downward over ``np.arange(end, start, -gap)``. For each one:
    - ``tp`` is the fraction of ``ood`` scores below ``delta``.
    - ``fp`` is the fraction of ``id`` scores below ``delta``.
    - precision is ``tp / (tp + fp)`` and recall is ``tp``.

    The sweep stops at the first threshold where nothing is flagged.

    NOTE(review): precision is built from per-set rates rather than raw
    counts, so it equals count-based precision only when
    ``len(id) == len(ood)``.

    Args:
        id, ood: in- / out-of-distribution scores.
        id_fi, ood_fi: the same for the ODIN_FI variant.
        start, end, gap: threshold sweep bounds and step.

    Returns:
        ``(aupr_out, aupr_out_fi)``. Each pair is evaluated with its own state.
        An entry is ``nan`` when the very first threshold flags nothing.
    """
    thresholds = np.arange(end, start, -gap)

    def area(in_scores, out_scores):
        n_in = float(len(in_scores))
        n_out = float(len(out_scores))
        area_sum = 0.0
        prev_recall = 1.0
        precision = recall = None
        for delta in thresholds:
            fp = np.sum(in_scores < delta) / n_in
            tp = np.sum(out_scores < delta) / n_out
            if tp + fp == 0:
                # Lower thresholds can only flag fewer samples, so stop here.
                break
            precision = tp / (tp + fp)
            recall = tp
            area_sum += (prev_recall - recall) * precision
            prev_recall = recall
        if precision is None:
            # No usable threshold; do not borrow state from another pair.
            return float('nan')
        # Close the curve down to recall = 0 with the last precision.
        area_sum += recall * precision
        return area_sum

    return area(id, ood), area(id_fi, ood_fi)

In [None]:
def detection(id, ood, id_fi, ood_fi, start, end, gap):
    """Minimum detection error for ODIN and ODIN_FI scores.

    For each threshold ``delta`` in ``np.arange(start, end, gap)`` the
    detection error is the mean of two rates:
    - the fraction of in-distribution scores below ``delta`` (missed ID
      samples);
    - the fraction of out-of-distribution scores above ``delta``.

    The smallest such error over the sweep is returned, starting from 1.0,
    which is also the result for an empty sweep.

    Args:
        id, ood: in- / out-of-distribution scores.
        id_fi, ood_fi: the same for the ODIN_FI variant.
        start, end, gap: threshold sweep, passed to ``np.arange``.

    Returns:
        ``(error, error_fi)`` as fractions in [0, 1].
    """
    thresholds = np.arange(start, end, gap)

    def best_error(in_scores, out_scores):
        n_in = float(len(in_scores))
        n_out = float(len(out_scores))
        lowest = 1.0
        for threshold in thresholds:
            miss_rate = np.sum(in_scores < threshold) / n_in
            false_alarm = np.sum(out_scores > threshold) / n_out
            lowest = np.minimum(lowest, (miss_rate + false_alarm) / 2.0)
        return lowest

    return best_error(id, ood), best_error(id_fi, ood_fi)

In [None]:
def metric(model_name='', id_dataset='', ood_dataset='', file_path='',
           start=None, end=None, num_steps=100000):
    """Load ODIN / ODIN_FI confidence files and report OOD-detection metrics.

    Reads four comma-separated files from ``file_path``:
    ``confidence_Our_In.txt``, ``confidence_Our_Out.txt`` and their ``_fi``
    variants. Column 2 of each file is used as the detection score. The
    function prints a comparison table.

    Args:
        model_name: architecture name, only used in the printed report.
        id_dataset: in-distribution dataset. ``'cifar10'`` and
            ``'cifar100'`` select a built-in threshold sweep range.
        ood_dataset: out-of-distribution dataset name, only printed.
        file_path: directory containing the confidence files. The default
            ``''`` means the current working directory.
        start, end: optional explicit threshold sweep range. Any bound that
            is given overrides the per-dataset default. Both are required
            for other ``id_dataset`` values.
        num_steps: number of threshold steps between ``start`` and ``end``.

    Returns:
        ``(auroc, fpr95, detection_error, auroc_fi, fpr95_fi,
        detection_error_fi)``, each a percentage rounded to 2 decimals.

    Raises:
        ValueError: if no sweep range is known for ``id_dataset`` and
            ``start``/``end`` were not both supplied.
    """
    import os

    # Built-in threshold sweep ranges, keyed by in-distribution dataset.
    default_ranges = {
        'cifar10': (0.1, 0.12),
        'cifar100': (0.01, 0.0104),
    }
    if start is None or end is None:
        if id_dataset not in default_ranges:
            raise ValueError(
                f"No default threshold range for id_dataset={id_dataset!r}; "
                f"expected one of {sorted(default_ranges)} or explicit start/end.")
        default_start, default_end = default_ranges[id_dataset]
        start = default_start if start is None else start
        end = default_end if end is None else end
    gap = (end - start) / num_steps

    def load_scores(file_name):
        # ndmin=2 keeps a single-row file 2-D so the column slice still works.
        # os.path.join avoids resolving to "/<file>" when file_path is ''.
        table = np.loadtxt(os.path.join(file_path, file_name), delimiter=',', ndmin=2)
        return table[:, 2]

    id_scores = load_scores('confidence_Our_In.txt')
    ood_scores = load_scores('confidence_Our_Out.txt')
    id_scores_fi = load_scores('confidence_Our_In_fi.txt')
    ood_scores_fi = load_scores('confidence_Our_Out_fi.txt')

    sweep = (start, end, gap)
    scores = (id_scores, ood_scores, id_scores_fi, ood_scores_fi)
    fprNew, fprNew_fi = tpr95(*scores, *sweep)
    errorNew, errorNew_fi = detection(*scores, *sweep)
    aurocNew, aurocNew_fi = auroc(*scores, *sweep)
    auprinNew, auprinNew_fi = auprIn(*scores, *sweep)
    auproutNew, auproutNew_fi = auprOut(*scores, *sweep)

    print("{:31}{:>22}".format("Neural network architecture:", model_name))
    print("{:31}{:>22}".format("In-distribution dataset:", id_dataset))
    print("{:31}{:>22}".format("Out-of-distribution dataset:", ood_dataset))
    print("")
    print("{:>34}{:>19}".format("ODIN", "ODIN_FI"))
    print("{:20}{:13.1f}%{:>18.1f}%".format("AUROC:",aurocNew*100, aurocNew_fi*100))
    print("{:20}{:13.1f}%{:>18.1f}% ".format("FPR at TPR 95%:",fprNew*100, fprNew_fi*100))
    print("{:20}{:13.1f}%{:>18.1f}%".format("Detection error:",errorNew*100, errorNew_fi*100))
    # AUPR is computed above but not reported; uncomment to print it.
    # print("{:20}{:13.1f}%{:>18.1f}%".format("AUPR In:", auprinNew*100, auprinNew_fi*100))
    # print("{:20}{:13.1f}%{:>18.1f}%".format("AUPR Out:", auproutNew*100, auproutNew_fi*100))
    print("")

    return (round(aurocNew*100, 2), round(fprNew*100, 2), round(errorNew*100, 2),
            round(aurocNew_fi*100, 2), round(fprNew_fi*100, 2), round(errorNew_fi*100, 2))