In [1]:
import os
import cv2
from PIL import Image
import torch
from models import *
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torch.utils.data as data
import torchvision.datasets as dset
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torchvision import datasets
from torch.autograd import Variable
import matplotlib.pyplot as plt
from sklearn import metrics


def plot_hist(*args, bins=[100, 100], labels=['cifar10', 'svhn'], xlim=[0, 10]):
    """Overlay density-normalised histograms of several score collections.

    Args:
        *args: Score collections. The first one is drawn as the
            in-distribution reference (black, alpha 0.9); every following one
            is drawn as an out-of-distribution set (alpha 0.5).
        bins: One ``bins`` value for ``plt.hist`` per entry of ``args``.
        labels: One legend label per entry of ``args``.
        xlim: Two-element sequence with the lower and upper x-axis limits.

    Raises:
        NotImplementedError: If ``args``, ``bins`` and ``labels`` do not all
            have the same length.
    """
    if not (len(args) == len(bins) == len(labels)):
        print('Oops! GRADS, BINS and LABELS must have the same length !')
        raise NotImplementedError

    plt.figure(figsize=(12, 9))

    for position, (scores, n_bins, name) in enumerate(zip(args, bins, labels)):
        values = torch.tensor(scores).detach().cpu().numpy()
        if position == 0:
            # The in-distribution reference gets a fixed, more opaque style.
            plt.hist(values, bins=n_bins, density=True, alpha=0.9, color='black', label=name)
        else:
            plt.hist(values, bins=n_bins, density=True, alpha=0.5, label=name)

    out_dist_names = list(labels[1:len(args)])
    plt.xlim(xlim[0], xlim[1])
    plt.title(f'In-Dist : {labels[0]}  /  Out-Dist : {out_dist_names}')
    plt.grid(True)
    plt.legend()
    plt.show()

def AUROC(*args, labels=['cifar10', 'svhn'], plot=True):
    """Compute (and optionally plot) the ROC AUC separating two score sets.

    Args:
        *args: ``args[0]`` holds the in-distribution scores and ``args[1]`` the
            out-of-distribution scores; further positional values are ignored.
        labels: Names of the two sets, used only in the plot title.
        plot: When true, draw the ROC curve in a new figure.

    Returns:
        The area under the ROC curve as returned by ``metrics.auc``.
    """
    in_scores, out_scores = args[0], args[1]
    scores = np.concatenate((in_scores, out_scores))
    # In-distribution samples are tagged 1 and out-of-distribution samples 0;
    # pos_label=0 below makes the out-of-distribution set the positive class.
    truth = np.concatenate((np.ones(len(in_scores)), np.zeros(len(out_scores))))

    fpr, tpr, _ = metrics.roc_curve(truth, scores, pos_label=0)
    rocauc = metrics.auc(fpr, tpr)

    title = f'In-dist : {labels[0]}  /  Out-dist : {labels[1]} \n AUC for Gradient Norm is: {rocauc:.6f}'
    if not plot:
        return rocauc

    plt.figure(figsize=(12, 9))
    plt.plot(fpr, tpr)
    plt.title(title)
    plt.grid(True)
    plt.show()
    plt.close()
    return rocauc

In [2]:
class OPT():
    """Small options holder consumed by the ``test_loader_*`` helpers.

    Attributes set on every instance: ``dataroot`` ('../data'), ``imageSize``
    (32) and ``workers`` (0). ``nc`` (channel count) is set to 3 for 'cifar10'
    and 1 for 'fmnist'; for any other ``train_dist`` it is left unset.
    """

    def __init__(self, train_dist):
        self.dataroot, self.imageSize, self.workers = '../data', 32, 0
        for known_dist, channels in (('cifar10', 3), ('fmnist', 1)):
            if train_dist == known_dist:
                self.nc = channels
                break


def TEST_loader(train_dist, target_dist, batch_size=10, shuffle=False):
    """Return the test loader for the given 'train_dist' / 'target_dist' pair.

    ``train_dist`` is 'cifar10' or 'fmnist'. It selects the normalisation
    passed to the loader and the ``OPT`` instance (and hence channel count)
    the loader is built with.

    ``target_dist`` names the in-distribution or out-of-distribution set.

    If train_dist is 'cifar10', target_dist should be one of:
        cifar10 (test), svhn (test), celeba (test), lsun (test),
        cifar100 (test), mnist (test), fmnist (test), kmnist (test),
        omniglot (eval), notmnist (small), trafficsign, noise, constant

    If train_dist is 'fmnist', target_dist should be one of:
        fmnist (test), svhn (test), celeba (test), lsun (test),
        cifar10 (test), cifar100 (test), mnist (test), kmnist (test),
        omniglot (eval), notmnist (small), noise, constant

    Args:
        train_dist: Name of the training distribution.
        target_dist: Name of the dataset to load.
        batch_size: Forwarded to the selected loader.
        shuffle: Forwarded to the selected loader.

    Raises:
        NotImplementedError: If the pair is not one of those listed above.
    """
    gray_norm = [transforms.Normalize((0.48,), (0.2,))]
    rgb_norm = [transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]
    no_match = "Oops! Such match of ID & OOD doesn't exist!"

    # Targets available for both training distributions.
    shared_targets = {
        'cifar10': test_loader_cifar10,
        'svhn': test_loader_svhn,
        'celeba': test_loader_celeba,
        'lsun': test_loader_lsun,
        'cifar100': test_loader_cifar100,
        'mnist': test_loader_mnist,
        'fmnist': test_loader_fmnist,
        'kmnist': test_loader_kmnist,
        'omniglot': test_loader_omniglot,
        'notmnist': test_loader_notmnist,
        'noise': test_loader_noise,
        'constant': test_loader_constant,
    }

    if train_dist == 'cifar10':
        normalisation = rgb_norm
        # The traffic-sign set is only offered against the CIFAR-10 models.
        builders = dict(shared_targets, trafficsign=test_loader_trafficsign)
    elif train_dist == 'fmnist':
        normalisation = gray_norm
        builders = shared_targets
    else:
        raise NotImplementedError(no_match)

    opt = OPT(train_dist)
    build = builders.get(target_dist)
    if build is None:
        raise NotImplementedError(no_match)
    return build(opt, normalisation, batch_size, shuffle)
        

def rgb_to_gray(x):
    """Average ``x`` over dimension 0, keeping that dimension with size 1."""
    return x.mean(dim=0, keepdim=True)


def gray_to_rgb(x):
    """Tile ``x`` three times along dimension 0 (other dimensions unchanged)."""
    tiling = (3, 1, 1)
    return x.repeat(*tiling)


def test_loader_cifar10(opt, preprocess, batch_size, shuffle):
    if opt.nc == 1:
        preprocess = [rgb_to_gray] + preprocess
    dataset_cifar10 = dset.CIFAR10(
        root=opt.dataroot,
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.Resize((opt.imageSize)),
            transforms.ToTensor(),
        ] + preprocess),
    )
    test_loader_cifar10 = data.DataLoader(
        dataset_cifar10,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=int(opt.workers),
    )
    return test_loader_cifar10


def test_loader_svhn(opt, preprocess, batch_size, shuffle):
    if opt.nc == 1:
        preprocess = [rgb_to_gray] + preprocess
    dataset_svhn = dset.SVHN(
        root=opt.dataroot,
        split='test',
        download=True,
        transform=transforms.Compose([
            transforms.Resize((opt.imageSize,opt.imageSize)),
            transforms.ToTensor(),
        ] + preprocess),
    )
    test_loader_svhn = data.DataLoader(
        dataset_svhn,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=int(opt.workers),
    )
    return test_loader_svhn

    
def test_loader_celeba(opt, preprocess, batch_size, shuffle):
    if opt.nc == 1:
        preprocess = [rgb_to_gray] + preprocess
    class CelebA(data.Dataset):
        def __init__(self, db_path, transform=None):
            super(CelebA, self).__init__()
            self.db_path = db_path
            elements = os.listdir(self.db_path)
            self.total_path = [self.db_path + '/' + element for element in elements]
            self.transform = transform

        def __len__(self):
            return len(self.total_path)

        def __getitem__(self, index):
            current_path = self.total_path[index]
            img = cv2.imread(current_path)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(img)
            img = self.transform(img)
            return img

    transform=transforms.Compose([
        transforms.Resize((opt.imageSize, opt.imageSize)),
        transforms.ToTensor(),
    ] + preprocess)

    celeba = CelebA('../../data/celeba/archive', transform=transform)
    test_loader_celeba = data.DataLoader(
        celeba,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
    )
    return test_loader_celeba


def test_loader_lsun(opt, preprocess, batch_size, shuffle):
    if opt.nc == 1:
        preprocess = [rgb_to_gray] + preprocess
    class LSUN(data.Dataset):
        def __init__(self, db_path, categories=['bedroom', 'bridge', 'church_outdoor', 'classroom', 'conference_room', 'dining_room', 'kitchen', 'living_room', 'restaurant', 'tower'], transform=None):
            super(LSUN, self).__init__()
            self.total_path = []
            for i in range(len(categories)):
                self.db_path = db_path + '/' + categories[i] + '_val'
                elements = os.listdir(self.db_path)
                self.total_path += [self.db_path + '/' + element for element in elements if element[-4:] == '.jpg']
            self.transform = transform

        def __len__(self):
            return len(self.total_path)

        def __getitem__(self, index):
            current_path = self.total_path[index]
            img = cv2.imread(current_path)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(img)
            img = self.transform(img)
            return img

    transform = transforms.Compose([
       # transforms.Resize(opt.imageSize), # Then the size will be H x 32 or 32 x W (32 is smaller)
        transforms.CenterCrop(opt.imageSize),
        transforms.ToTensor(),
    ] + preprocess)

    lsun = LSUN('../../data/LSUN_test', transform=transform)
    test_loader_lsun = data.DataLoader(
        lsun,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
    )
    return test_loader_lsun


def test_loader_cifar100(opt, preprocess, batch_size, shuffle):
    if opt.nc == 1:
        preprocess = [rgb_to_gray] + preprocess
    dataset_cifar100 = dset.CIFAR100(
        root=opt.dataroot,
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.Resize((opt.imageSize)),
            transforms.ToTensor(),
        ] + preprocess),
    )
    test_loader_cifar100 = data.DataLoader(
        dataset_cifar100,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=int(opt.workers),
    )
    return test_loader_cifar100


def test_loader_mnist(opt, preprocess, batch_size, shuffle):
    if opt.nc == 3:
        preprocess = [gray_to_rgb] + preprocess
    dataset_mnist = dset.MNIST(
        root=opt.dataroot,
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.Resize((opt.imageSize, opt.imageSize)),
            transforms.ToTensor(),
        ] + preprocess),
    )
    test_loader_mnist = data.DataLoader(
        dataset_mnist,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=int(opt.workers),
    )
    return test_loader_mnist

    
def test_loader_fmnist(opt, preprocess, batch_size, shuffle):
    if opt.nc == 3:
        preprocess = [gray_to_rgb] + preprocess
    dataset_fmnist = dset.FashionMNIST(
        root=opt.dataroot,
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.Resize((opt.imageSize)),
            transforms.ToTensor(),
        ] + preprocess),
    )
    test_loader_fmnist = data.DataLoader(
        dataset_fmnist,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=int(opt.workers),
    )
    return test_loader_fmnist

    
def test_loader_kmnist(opt, preprocess, batch_size, shuffle):
    if opt.nc == 3:
        preprocess = [gray_to_rgb] + preprocess
    dataset_kmnist = dset.KMNIST(
        root=opt.dataroot,
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.Resize((opt.imageSize, opt.imageSize)),
            transforms.ToTensor(),
        ] + preprocess),
    )
    test_loader_kmnist = data.DataLoader(
        dataset_kmnist,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=int(opt.workers),
    )
    return test_loader_kmnist

    
def test_loader_omniglot(opt, preprocess, batch_size, shuffle):
    if opt.nc == 3:
        preprocess = [gray_to_rgb] + preprocess
    dataset_omniglot = dset.Omniglot(
        root=opt.dataroot, 
        background=False,
        download=True,
        transform=transforms.Compose([
            transforms.Resize((opt.imageSize, opt.imageSize)),
            transforms.ToTensor(),
        ] + preprocess),
    )
    test_loader_omniglot = data.DataLoader(
        dataset_omniglot,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=int(opt.workers),
    )
    return test_loader_omniglot

    
def test_loader_notmnist(opt, preprocess, batch_size, shuffle):
    if opt.nc == 1:
        preprocess = [rgb_to_gray] + preprocess
    class notMNIST(data.Dataset):
        def __init__(self, db_path, transform=None):
            super(notMNIST, self).__init__()
            self.db_path = db_path
            self.total_path = []
            alphabets = os.listdir(self.db_path)
            for alphabet in alphabets:
                path = self.db_path + '/' + alphabet
                elements = os.listdir(path)
                self.total_path += [path + '/' + element for element in elements]
            self.transform = transform

        def __len__(self):
            return len(self.total_path)

        def __getitem__(self, index):
            current_path = self.total_path[index]
            img = cv2.imread(current_path)
            img = Image.fromarray(img)
            img = self.transform(img)
            return img

    transform=transforms.Compose([
        transforms.Resize((opt.imageSize, opt.imageSize)),
        transforms.ToTensor(),
    ] + preprocess)

    notmnist = notMNIST('../../data/notMNIST_small/', transform=transform)
    test_loader_notmnist = data.DataLoader(
        notmnist,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
    )
    return test_loader_notmnist

    
def test_loader_trafficsign(opt, preprocess, batch_size, shuffle):
    if opt.nc == 1:
        preprocess = [rgb_to_gray] + preprocess
    class trafficsign(data.Dataset):
        def __init__(self, db_path, transform=None):
            super(trafficsign, self).__init__()
            self.db_path = db_path
            elements = os.listdir(self.db_path)
            self.total_path = [self.db_path + '/' + element for element in elements]
            self.transform = transform

        def __len__(self):
            return len(self.total_path)

        def __getitem__(self, index):
            current_path = self.total_path[index]
            img = cv2.imread(current_path)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(img)
            img = self.transform(img)
            return img

    transform = transforms.Compose([
        transforms.Resize((opt.imageSize, opt.imageSize)),
        transforms.ToTensor(),
    ] + preprocess)

    ts = trafficsign('../../data/GTSRB_Final_Test_Images/Final_Test/Images', transform=transform)
    test_loader_trafficsign = data.DataLoader(
        ts,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
    )
    return test_loader_trafficsign

    
def test_loader_noise(opt, preprocess, batch_size, shuffle):
    class Noise(data.Dataset):
        def __init__(self, number=10000, transform=None):
            super(Noise, self).__init__()
            self.transform = transform
            self.number = number
            self.total_data = np.random.randint(0, 256, (self.number, opt.nc, 32, 32))

        def __len__(self):
            return self.number

        def __getitem__(self, index):
            array = torch.tensor(self.total_data[index] / 255).float()
            return self.transform(array)

    transform = transforms.Compose(preprocess)

    noise = Noise(transform=transform)
    test_loader_noise = data.DataLoader(
        noise,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
    )
    return test_loader_noise

    
def test_loader_constant(opt, preprocess, batch_size, shuffle):
    class Constant(data.Dataset):
        def __init__(self, number=10000, transform=None):
            super(Constant, self).__init__()
            self.number = number
            self.total_data = torch.randint(0, 256, (self.number, opt.nc, 1, 1))
            self.transform = transform

        def __len__(self):
            return self.number

        def __getitem__(self, index):
            data = self.total_data[index].float()
            data = data.repeat(32 * 32, 1, 1).reshape((-1, 32, 32)) / 255
            return self.transform(data)

    transform = transforms.Compose(preprocess)

    constant = Constant(transform=transform)
    test_loader_constant = data.DataLoader(
        constant,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
    )
    return test_loader_constant

In [3]:
# Blur-type tags evaluated for each training distribution; aurocTest uses them
# to build checkpoint file names.
models = {
    'cifar10': ['svd28', 'dct28', 'gb', '20p', '4px'],
    'fmnist': ['svd20', 'svd24', 'dct28', 'gb', '20p', '4px'],
}

# Out-of-distribution targets scored against each training distribution.
testlists = {
    'cifar10': ['svhn', 'celeba', 'lsun', 'mnist', 'fmnist', 'notmnist', 'noise', 'constant'],
    'fmnist': ['cifar10', 'svhn', 'celeba', 'lsun', 'mnist', 'notmnist', 'noise', 'constant'],
}

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _collect_scores(net, net2, loader, num_batches=100, batch_size=10):
    """Score the first ``num_batches`` batches of ``loader``.

    The score of a sample is the squared difference between ``net``'s output
    and the first element of ``net2``'s output, summed over the last dimension.

    Returns:
        A flat array of ``num_batches * batch_size`` scores. Rows for batches
        the loader never produced stay at zero.
    """
    scores = np.zeros((num_batches, batch_size))
    # Only forward passes are needed, so skip building autograd graphs.
    with torch.no_grad():
        for batch_idx, batch in enumerate(loader):
            # Labelled datasets yield (inputs, targets); the image-only ones
            # yield the input tensor directly. An explicit type check avoids
            # accidentally unpacking a two-row tensor as if it were a pair.
            inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
            inputs = inputs.to(device)
            outputs = net(inputs)
            outputs2, _ = net2(inputs)
            scores[batch_idx] = torch.sum((outputs - outputs2) ** 2, dim=-1).cpu().numpy()
            if batch_idx == num_batches - 1:
                break
    return scores.reshape(-1)


def aurocTest(train_dist, blurtype, testlist, plot=True):
    """Print the AUROC of one checkpoint pair against each OOD set in ``testlist``.

    Loads ``saved_models/net_{train_dist}_{blurtype}.ckpt`` into a
    ``ResNet345`` and ``saved_models/net2_{train_dist}_{blurtype}.ckpt`` into a
    ``ResNet342``, scores the first 100 batches of 10 samples of the
    in-distribution test loader and of every OOD loader, and prints one
    ``{train_dist}_{blurtype}_{oodtype} : {auroc}`` line per OOD set.

    Args:
        train_dist: 'cifar10' or 'fmnist'; selects the loaders and the channel
            count the networks are built with.
        blurtype: Tag used in the checkpoint file names.
        testlist: Iterable of OOD dataset names accepted by ``TEST_loader``.
        plot: When true, also draw the score histograms and the ROC curve.
    """
    id_testloader = TEST_loader(train_dist, train_dist)
    if train_dist == 'cifar10':
        channel = 3
    elif train_dist == 'fmnist':
        channel = 1
    net = ResNet345(channel).to(device)
    net2 = ResNet342(channel).to(device)

    netPATH = f'saved_models/net_{train_dist}_{blurtype}.ckpt'
    net2PATH = f'saved_models/net2_{train_dist}_{blurtype}.ckpt'

    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    net.load_state_dict(torch.load(netPATH, map_location=device))
    net2.load_state_dict(torch.load(net2PATH, map_location=device))

    net.eval()
    net2.eval()

    id_loss = _collect_scores(net, net2, id_testloader)

    for oodtype in testlist:
        ood_testloader = TEST_loader(train_dist, oodtype)
        ood_loss = _collect_scores(net, net2, ood_testloader)
        if plot:
            plot_hist(id_loss, ood_loss, labels=[train_dist, oodtype])
        auroc = AUROC(id_loss, ood_loss, labels=[train_dist, oodtype], plot=plot)
        print(f'{train_dist}_{blurtype}_{oodtype} : {auroc:.3f}')

In [5]:
# Sweep every (training distribution, blur type) checkpoint pair over that
# distribution's OOD list; aurocTest prints one AUROC line per OOD set.
print('ID_blurtype_OOD : aurocScore')
for k, blurtypes in models.items():
    for b in blurtypes:
        aurocTest(k, b, testlists[k], plot=False)

ID_blurtype_OOD : aurocScore
Files already downloaded and verified
Downloading http://ufldl.stanford.edu/housenumbers/test_32x32.mat to ../data\test_32x32.mat


0it [00:00, ?it/s]

KeyboardInterrupt: 