<center><h1>ResNet: MNIST (Leave-one-out)</h1></center>

## Imports

In [1]:
from __future__ import division,print_function

%matplotlib inline
%load_ext autoreload
%autoreload 2

import csv
import math
import os
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from torch.autograd import Variable, grad
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader
from torchvision import datasets, transforms as trn
from tqdm import tqdm_notebook as tqdm

import calculate_log as callog
# Digit class left out of the in-distribution targets (removed from the target
# list in the dataset config cell) and used to build the checkpoint path.
out_target = 1
# 'scratch' selects the "MNIST_new_scratch_out_*" checkpoint; any other value
# selects the "MNIST_new_pretrained_out_*" checkpoint.
model_type = 'scratch'


# import warnings
# warnings.filterwarnings('ignore')

In [2]:
torch.cuda.set_device(0) #Select the GPU

## Model definition

In [3]:
from PIL import Image

class SaigeDataset3(torch.utils.data.Dataset):
    """Image dataset described by a split file; every image is converted to RGB.

    The split file ``<split_root>/<split>.txt`` holds one relative image path
    per line in the form ``<class_name>/<file_name>``. Only entries whose class
    name appears in ``targets`` are kept, and the label of an entry is the
    position of its class name inside ``targets``.
    """

    def __init__(self, data_root, split_root, dataset, split, transform, targets):
        """
            data_root(str) : Root directory of datasets (e.g. "/home/sr2/HDD2/Openset/")
            split_root(str) : Root directory of split file (e.g. "/home/sr2/Hyeokjun/OOD-saige/datasets/data_split/")
            dataset(str) : dataset name
            split(str) : ['train', 'valid', 'test']
            transform(torchvision transform) : image transform
            targets(list of str) : using targets
        """
        self.data_root = data_root
        self.split_root = split_root
        self.dataset = dataset
        self.transform = transform
        self.targets = targets
        self.data_list = []
        # `with` guarantees the split file is closed even if parsing fails.
        with open(os.path.join(split_root, split + ".txt"), "r") as split_file:
            for line in split_file:
                # Strip only the line terminator, so a final line without a
                # trailing newline keeps its last character.
                entry = line.rstrip("\n")
                if not entry:
                    continue  # tolerate blank lines
                [target, _] = entry.split("/")
                # Transform target: class name -> index inside `targets`
                if target in targets:
                    self.data_list.append((targets.index(target), entry))

    def _load(self, position):
        """Open the image at ``position`` and return ``(transformed_rgb_image, label)``."""
        (target, fpath) = self.data_list[position]
        img = Image.open(os.path.join(self.data_root, fpath))
        return self.transform(img.convert("RGB")), target

    def __getitem__(self, idx):
        """Return one sample for an ``int`` index, or a list of samples otherwise.

        * ``int``  -> ``(image, label)``.
        * anything else (slice, other integer-like index) -> list of
          ``[image.unsqueeze(0), tensor([label])]`` pairs. The index is resolved
          against the real dataset length, so slices are clamped, negative
          bounds work, and an empty selection yields ``[]``.
        """
        if isinstance(idx, int):
            return self._load(idx)

        positions = range(len(self.data_list))[idx]
        if isinstance(positions, int):
            positions = [positions]
        total_img = []
        for position in positions:
            img, target = self._load(position)
            total_img.append([torch.unsqueeze(img, dim=0), torch.tensor([target])])
        return total_img

    def data(self):
        # NOTE(review): `img` is not defined in this scope, so calling this
        # method raises NameError. The intended return value cannot be
        # determined from this file; left unchanged.
        return img.data

    def __len__(self):
        """Number of samples kept after filtering by ``targets``."""
        return len(self.data_list)

class SaigeDataset(torch.utils.data.Dataset):
    """Image dataset described by a split file; images keep their native mode.

    The split file ``<split_root>/<split>.txt`` holds one relative image path
    per line in the form ``<class_name>/<file_name>``. Only entries whose class
    name appears in ``targets`` are kept, and the label of an entry is the
    position of its class name inside ``targets``.
    """

    def __init__(self, data_root, split_root, dataset, split, transform, targets):
        """
            data_root(str) : Root directory of datasets (e.g. "/home/sr2/HDD2/Openset/")
            split_root(str) : Root directory of split file (e.g. "/home/sr2/Hyeokjun/OOD-saige/datasets/data_split/")
            dataset(str) : dataset name
            split(str) : ['train', 'valid', 'test']
            transform(torchvision transform) : image transform
            targets(list of str) : using targets
        """
        self.data_root = data_root
        self.split_root = split_root
        self.dataset = dataset
        self.transform = transform
        self.targets = targets

        self.data_list = []
        # `with` guarantees the split file is closed even if parsing fails.
        with open(os.path.join(split_root, split + ".txt"), "r") as split_file:
            for line in split_file:
                # Strip only the line terminator, so a final line without a
                # trailing newline keeps its last character.
                entry = line.rstrip("\n")
                if not entry:
                    continue  # tolerate blank lines
                [target, _] = entry.split("/")
                # Transform target: class name -> index inside `targets`
                if target in targets:
                    self.data_list.append((targets.index(target), entry))

    def _load(self, position):
        """Open the image at ``position`` and return ``(transformed_image, label)``."""
        (target, fpath) = self.data_list[position]
        img = Image.open(os.path.join(self.data_root, fpath))
        return self.transform(img), target

    def __getitem__(self, idx):
        """Return one sample for an ``int`` index, or a list of samples otherwise.

        * ``int``  -> ``(image, label)``.
        * anything else (slice, other integer-like index) -> list of
          ``[image.unsqueeze(0), tensor([label])]`` pairs. The index is resolved
          against the real dataset length, so slices are clamped, negative
          bounds work, and an empty selection yields ``[]``.
        """
        if isinstance(idx, int):
            return self._load(idx)

        positions = range(len(self.data_list))[idx]
        if isinstance(positions, int):
            positions = [positions]
        total_img = []
        for position in positions:
            img, target = self._load(position)
            total_img.append([torch.unsqueeze(img, dim=0), torch.tensor([target])])
        return total_img

    def __len__(self):
        """Number of samples kept after filtering by ``targets``."""
        return len(self.data_list)

class MNISTDataset(torch.utils.data.Dataset):
    """MNIST stored as PNG files in one sub-directory per digit under ``./mnist_png``."""

    def __init__(self, split, targets):
        """
            split(str) : 'train' reads ./mnist_png/training, any other value reads ./mnist_png/testing
            targets(list of int) : digit classes to keep; the label of a sample is the
                                   index of its digit inside this list
        """
        if split=='train':
            self.data_root = './mnist_png/training'
        else:
            self.data_root = './mnist_png/testing'
        # Pad 2 pixels on every side, then convert to a tensor
        # (presumably 28x28 -> 32x32 for MNIST; confirm against the PNG files).
        self.transform = trn.Compose([trn.Pad(2),trn.ToTensor()])
        self.targets = targets
        self.data_list = []
        # os.listdir returns entries in arbitrary order; sorting makes the
        # sample order reproducible across machines and file systems.
        for direc in sorted(os.listdir(self.data_root)):
            if not direc.isdigit():
                continue  # skip stray entries such as hidden files
            digit = int(direc)
            if digit not in targets:
                continue
            target = targets.index(digit)  # Transform target
            for image in sorted(os.listdir(os.path.join(self.data_root, direc))):
                self.data_list.append((target, os.path.join(direc, image)))

    def _load(self, position):
        """Open the image at ``position`` and return ``(transformed_rgb_image, label)``."""
        (target, fpath) = self.data_list[position]
        img = Image.open(os.path.join(self.data_root, fpath))
        return self.transform(img.convert("RGB")), target

    def __getitem__(self, idx):
        """Return one sample for an ``int`` index, or a list of samples otherwise.

        * ``int``  -> ``(image, label)``.
        * anything else (slice, other integer-like index) -> list of
          ``[image.unsqueeze(0), tensor([label])]`` pairs. The index is resolved
          against the real dataset length, so slices are clamped, negative
          bounds work, and an empty selection yields ``[]``.
        """
        if isinstance(idx, int):
            return self._load(idx)

        positions = range(len(self.data_list))[idx]
        if isinstance(positions, int):
            positions = [positions]
        total_img = []
        for position in positions:
            img, target = self._load(position)
            total_img.append([torch.unsqueeze(img, dim=0), torch.tensor([target])])
        return total_img

    def data(self):
        # NOTE(review): `img` is not defined in this scope, so calling this
        # method raises NameError. The intended return value cannot be
        # determined from this file; left unchanged.
        return img.data

    def __len__(self):
        """Number of samples kept after filtering by ``targets``."""
        return len(self.data_list)

class TrafficDataset(torch.utils.data.Dataset):
    """GTSRB traffic-sign images, read and transformed eagerly into memory."""

    def __init__(self, split, targets):
        """
            split(str) : 'train' reads ./GTSRB/Final_Training/Images (one sub-directory and
                         one GT-<class>.csv per class), any other value reads
                         ./GTSRB/Final_Test/Images and its GT-final_test.csv
            targets(list of int) : class ids to keep; the label of a sample is the index
                                   of its class id inside this list
        """
        self.split=split
        if self.split=='train':
            self.data_root = './GTSRB/Final_Training/Images'
        else:
            self.data_root = './GTSRB/Final_Test/Images'
        self.targets = targets
        self.transform=trn.Compose([trn.Resize([64,64]),trn.ToTensor()])
        self.data_list = []

        self.images=[]
        self.labels=[]

        if self.split=='train':
            # one annotation file per requested class
            for c in self.targets:
                prefix = self.data_root + '/' + format(c, '05d') + '/' # subdirectory for class
                self._load_annotations(prefix, 'GT-' + format(c, '05d') + '.csv')
        else:
            self._load_annotations(self.data_root + '/', 'GT-final_test.csv')

    def _load_annotations(self, prefix, csv_name):
        """Append every image listed in ``prefix + csv_name`` whose class is in ``targets``."""
        # `with` closes the annotation file even if reading an image fails.
        with open(prefix + csv_name) as gtFile:
            gtReader = csv.reader(gtFile, delimiter=';') # csv parser for annotations file
            next(gtReader) # skip header
            for row in gtReader:
                class_id = int(row[7]) # the 8th column is the label
                if class_id in self.targets:
                    img = Image.fromarray(plt.imread(prefix + row[0])) # the 1st column is the filename
                    self.images.append(self.transform(img))
                    self.labels.append(torch.tensor(self.targets.index(class_id)))

    def __getitem__(self, idx):
        """Return one sample for an ``int`` index, or a list of samples otherwise.

        * ``int``  -> ``(image, label)``.
        * anything else (slice, other integer-like index) -> list of
          ``[image, label]`` pairs. The index is resolved against the real
          dataset length, so slices are clamped and an empty selection yields ``[]``.
        """
        if isinstance(idx, int):
            return self.images[idx], self.labels[idx]

        positions = range(len(self.labels))[idx]
        if isinstance(positions, int):
            positions = [positions]
        return [[self.images[x], self.labels[x]] for x in positions]

    def __len__(self):
        """Number of samples kept after filtering by ``targets``."""
        return len(self.labels)

def getDataLoader(ds_cfg, dl_cfg, split, num_samples=10000):
    """Build the dataset selected by ``ds_cfg['dataset']`` and wrap it in a DataLoader.

    Args:
        ds_cfg (dict): dataset settings. Reads 'dataset', 'targets', 'batch_size',
            'train_transform', 'valid_transform', optionally 'split', and for the
            split-file datasets also 'data_root' and 'split_root'.
        dl_cfg (dict): loader settings; reads 'num_workers' and 'pin_memory'.
        split (str): requested split. It decides shuffling and which transform is
            used; the split actually loaded is ``ds_cfg['split']`` when that key
            is one of 'train', 'valid' or 'test'.
        num_samples (int): unused; kept so existing callers keep working.

    Returns:
        torch.utils.data.DataLoader over the selected dataset.
    """
    # Shuffling and transform follow the *requested* split, before any override.
    train = split == 'train'
    transform = ds_cfg['train_transform'] if train else ds_cfg['valid_transform']

    if ds_cfg.get('split') in ('train', 'valid', 'test'):
        split = ds_cfg['split']

    name = ds_cfg['dataset']
    if name in ['Traffic']:
        dataset = TrafficDataset(split=split, targets=ds_cfg['targets'])
    elif name in ['MNIST']:
        dataset = MNISTDataset(split=split, targets=ds_cfg['targets'])
    else:
        # Split-file datasets: the SDI sets keep the native image mode, every
        # other name goes through the RGB-converting variant.
        if name in ['SDI/34Ah','SDI/37Ah','SDI/60Ah']:
            dataset_cls = SaigeDataset
        else:
            dataset_cls = SaigeDataset3
        dataset = dataset_cls(data_root=ds_cfg['data_root'],
                              split_root=ds_cfg['split_root'],
                              dataset=name,
                              split=split,
                              transform=transform,
                              targets=ds_cfg['targets'])

    loader = DataLoader(dataset,
                        batch_size=ds_cfg['batch_size'],
                        shuffle=train,
                        num_workers=dl_cfg['num_workers'],
                        pin_memory=dl_cfg['pin_memory'])
    print('Dataset {} ready.'.format(ds_cfg['dataset']))
    return loader

print('Done')

    

Done


In [4]:

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Bias-free 3x3 convolution whose padding equals its dilation."""
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)

def conv1x1(in_planes, out_planes, stride=1):
    """Bias-free 1x1 (pointwise) convolution."""
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise

class BasicBlock(nn.Module):
    """Two-convolution residual block that reports its activations to ``torch_model``.

    ``forward`` calls the module-level ``torch_model.record`` six times per
    block, so a global named ``torch_model`` must exist before the block is run.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        """Create the block; only ``groups=1``, ``base_width=64`` and ``dilation<=1`` are accepted."""
        super(BasicBlock, self).__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # When stride != 1 the spatial reduction happens in conv1 and in the
        # optional `downsample` branch.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the block to ``x`` and record six intermediate tensors.

        The tensors recorded after ``bn2``, after the residual addition and
        after the final ReLU are the same tensor object: ``+=`` and the
        in-place ReLU both overwrite it, so all three recorded references
        end up holding the post-ReLU values.
        """
        record = torch_model.record
        shortcut = x

        out = self.conv1(x)
        record(out)
        out = self.relu(self.bn1(out))
        record(out)
        out = self.conv2(out)
        record(out)
        out = self.bn2(out)
        record(out)

        if self.downsample is not None:
            shortcut = self.downsample(x)

        out += shortcut
        record(out)

        out = self.relu(out)
        record(out)

        return out

class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block; the output has ``planes * expansion`` channels."""
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        """Create the three convolutions; the inner width scales with ``base_width`` and ``groups``."""
        super(Bottleneck, self).__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        out_planes = planes * self.expansion
        width = int(planes * (base_width / 64.)) * groups
        # When stride != 1 the spatial reduction happens in conv2 and in the
        # optional `downsample` branch.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = norm_layer(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(residual_branch(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


class ResNet(nn.Module):
    """ResNet classifier that can also collect intermediate activations.

    The network is a 3x3 stem convolution, batch norm, ReLU and max-pool,
    followed by four residual stages, global average pooling and a linear
    classifier. While ``self.collecting`` is true, every tensor handed to
    ``record`` is stored in ``self.gram_feats``; ``get_min_max``,
    ``get_min_max_real`` and ``get_deviations`` turn those tensors into
    per-layer statistics through the module-level ``G_p`` / ``G_p_entire``
    functions (defined elsewhere in the notebook).
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        """
        Args:
            block: residual block class (``BasicBlock`` or ``Bottleneck``).
            layers: four integers, the number of blocks in each stage.
            num_classes: size of the final linear layer.
            zero_init_residual: zero the last norm weight of every residual branch.
            groups, width_per_group: forwarded to every block.
            replace_stride_with_dilation: ``None`` or three booleans, one per
                strided stage, replacing the stride with dilation.
            norm_layer: normalisation layer factory, ``nn.BatchNorm2d`` by default.

        Raises:
            ValueError: ``replace_stride_with_dilation`` does not have three elements.
        """
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = conv3x3(3, self.inplanes)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        # Recording state; initialised here so the attributes always exist.
        self.collecting = False
        self.gram_feats = []

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks and update ``self.inplanes``."""
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        # A projection shortcut is needed whenever the first block changes the
        # spatial size or the channel count.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        """Stem -> four stages -> average pool -> flatten -> classifier logits."""
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        """Return the classifier logits for ``x``."""
        return self._forward_impl(x)

    def record(self, t):
        """Store ``t`` in ``self.gram_feats`` while collection is switched on."""
        if self.collecting:
            self.gram_feats.append(t)

    def gram_feature_list(self,x):
        """Run ``forward(x)`` with recording on and return the recorded tensors.

        ``collecting`` and ``gram_feats`` are reset in a ``finally`` clause, so
        a failing forward pass cannot leave the model stuck in recording mode.
        """
        self.collecting = True
        self.gram_feats = []
        try:
            self.forward(x)
            return self.gram_feats
        finally:
            self.collecting = False
            self.gram_feats = []

    def _gram_min_max(self, data, power, gram_fn, batch_size=128):
        """Track the element-wise min and max of ``gram_fn(feature, P)`` over ``data``.

        ``data`` is consumed in slices of ``batch_size``; each slice is moved to
        the GPU with ``.cuda()``. Returns ``(mins, maxs)``, both indexed as
        ``[layer][power_index]``, where the reduction runs over dimension 0 of
        the value returned by ``gram_fn``.
        """
        mins = []
        maxs = []
        which_layer=1  # keep every which_layer-th recorded tensor (1 = all)
        for i in range(0,len(data),batch_size):
            batch = data[i:i+batch_size].cuda()
            feat_list = self.gram_feature_list(batch)
            for L1,feat_L in enumerate(feat_list):
                if L1%which_layer!=0:
                    continue
                L=L1//which_layer
                if L==len(mins):
                    # first time this layer is seen: allocate one slot per power
                    mins.append([None]*len(power))
                    maxs.append([None]*len(power))

                for p,P in enumerate(power):
                    g_p = gram_fn(feat_L,P)

                    current_min = g_p.min(dim=0,keepdim=True)[0]
                    current_max = g_p.max(dim=0,keepdim=True)[0]

                    if mins[L][p] is None:
                        mins[L][p] = current_min
                        maxs[L][p] = current_max
                    else:
                        mins[L][p] = torch.min(current_min,mins[L][p])
                        maxs[L][p] = torch.max(current_max,maxs[L][p])

        return mins,maxs

    def get_min_max(self, data, power, batch_size=128):
        """Per-layer, per-power min/max of ``G_p`` over ``data``; see ``_gram_min_max``."""
        return self._gram_min_max(data, power, G_p, batch_size)

    def get_min_max_real(self, data, power, batch_size=128):
        """Same as ``get_min_max`` but the statistic is computed with ``G_p_entire``."""
        return self._gram_min_max(data, power, G_p_entire, batch_size)

    def get_deviations(self,data,power,mins,maxs,batch_size=128):
        """Measure how far ``G_p`` of each sample falls outside ``[mins, maxs]``.

        For every recorded layer the relative shortfall below ``mins[L][p]`` and
        the relative excess above ``maxs[L][p]`` are summed over dimension 1 and
        over all entries of ``power``.

        Returns:
            numpy array with one row per sample and one column per recorded layer.
        """
        deviations = []
        which_layer=1  # keep every which_layer-th recorded tensor (1 = all)
        for i in range(0,len(data),batch_size):
            batch = data[i:i+batch_size].cuda()
            feat_list = self.gram_feature_list(batch)
            batch_deviations = []
            for L1,feat_L in enumerate(feat_list):
                if L1%which_layer==0:
                    L=L1//which_layer
                    dev = 0
                    for p,P in enumerate(power):
                        g_p = G_p(feat_L,P)

                        # 10**-6 keeps the denominators away from zero
                        dev +=  (F.relu(mins[L][p]-g_p)/torch.abs(mins[L][p]+10**-6)).sum(dim=1,keepdim=True)
                        dev +=  (F.relu(g_p-maxs[L][p])/torch.abs(maxs[L][p]+10**-6)).sum(dim=1,keepdim=True)
                    batch_deviations.append(dev.cpu().detach().numpy())
            batch_deviations = np.concatenate(batch_deviations,axis=1)
            deviations.append(batch_deviations)
        deviations = np.concatenate(deviations,axis=0)

        return deviations


def _resnet(arch, block, num_classes, layers, pretrained, progress, **kwargs):
    """Create a ``ResNet`` and, when ``pretrained`` is truthy, load downloaded weights.

    The downloaded classifier head is cut down to its first ``num_classes``
    rows before loading.

    NOTE(review): ``model_urls`` is not defined in the visible part of this
    file, so the pretrained path depends on it being defined elsewhere.
    """
    model = ResNet(block, layers, num_classes, **kwargs)
    if not pretrained:
        return model

    state_dict = torch.hub.load_state_dict_from_url(model_urls[arch],
                                                    progress=progress)
    head_rows = range(num_classes)
    state_dict['fc.weight'] = state_dict['fc.weight'].data[head_rows, :]
    state_dict['fc.bias'] = state_dict['fc.bias'].data[head_rows]
    model.load_state_dict(state_dict)
    return model

def resnet34(pretrained=1, num_classes=9,progress=True, **kwargs):
    """Build a ResNet-34 (``BasicBlock`` stages of 3, 4, 6 and 3 blocks).

    ``pretrained`` counts as enabled only when it compares equal to 1; the
    decision is printed before the model is built.
    """
    use_pretrained = (pretrained == 1)
    if use_pretrained:
        print("*** Pre Trained : True")
    else:
        print("*** Pre Trained : False")
    stage_sizes = [3, 4, 6, 3]
    return _resnet('resnet34', BasicBlock, num_classes, stage_sizes,
                   use_pretrained, progress, **kwargs)
    # model = ResNet_128(BasicBlock, [3,4,6,3], num_classes=cfg['in_dataset']['num_classes'])

# Build an untrained ResNet-34, then restore the checkpoint that matches
# `model_type` and the held-out class `out_target`.
torch_model = resnet34(pretrained=0)

if model_type == 'scratch':
    ckpt_path = './pre_trained/MNIST_new_scratch_out_'+str(out_target)+'/ckpt/checkpoint_epoch_200.pyth'
else:
    ckpt_path = './pre_trained/MNIST_new_pretrained_out_'+str(out_target)+'/ckpt/checkpoint_epoch_40.pyth'

# Load onto the CPU first; the model is moved to the GPU afterwards.
tm = torch.load(ckpt_path, map_location='cpu')
torch_model.load_state_dict(tm['model_state'])

torch_model.cuda()
torch_model.params = list(torch_model.parameters())
torch_model.eval()
print("Done")  

*** Pre Trained : False
Done


## Datasets

<b>In-distribution Datasets</b>

In [5]:
import csv
# In-distribution classes: every digit except the held-out `out_target`.
x = list(range(0, 10))
x.remove(out_target)
print(x)

cfg = {
    'in_dataset': {
        'dataset': 'MNIST',
        'batch_size': 256,
        'targets': x,
        # Transforms stay None here; an earlier variant used
        # trn.Compose([trn.RandomHorizontalFlip(), trn.ToTensor(), trn.Normalize([0.4214,0.4214,0.4214],[0.2355,0.2355,0.2355])])
        'train_transform': None,
        # trn.Compose([trn.ToTensor(), trn.Normalize([0.4214,0.4214,0.4214],[0.2355,0.2355,0.2355])])
        'valid_transform': None,
        'num_classes': len(x),
    },
    'dataloader': {
        'num_workers': 4,
        'pin_memory': True,
    },
}

train_loader = getDataLoader(
    ds_cfg=cfg['in_dataset'],
    dl_cfg=cfg['dataloader'],
    split="train",
)
test_loader = getDataLoader(
    ds_cfg=cfg['in_dataset'],
    dl_cfg=cfg['dataloader'],
    split="valid",
)


[0, 2, 3, 4, 5, 6, 7, 8, 9]
Dataset MNIST ready.
Dataset MNIST ready.


In [6]:
def getDataset(ds_cfg, dl_cfg, split):
    """Build and return the dataset selected by ``ds_cfg['dataset']`` (no DataLoader).

    Args:
        ds_cfg (dict): dataset settings. Reads 'dataset', 'targets',
            'train_transform', 'valid_transform', optionally 'split', and for the
            split-file datasets also 'data_root' and 'split_root'.
        dl_cfg (dict): unused; kept so the signature mirrors ``getDataLoader``.
        split (str): requested split. It decides which transform is used; the
            split actually loaded is ``ds_cfg['split']`` when that key is one of
            'train', 'valid' or 'test'.

    Returns:
        The constructed dataset object.
    """
    # The transform follows the *requested* split, before any override.
    if split == 'train':
        transform = ds_cfg['train_transform']
    else:
        transform = ds_cfg['valid_transform']

    if ds_cfg.get('split') in ('train', 'valid', 'test'):
        split = ds_cfg['split']

    name = ds_cfg['dataset']
    if name in ['Traffic']:
        dataset = TrafficDataset(split=split, targets=ds_cfg['targets'])
    elif name in ['MNIST']:
        dataset = MNISTDataset(split=split, targets=ds_cfg['targets'])
    else:
        # Split-file datasets: the SDI sets keep the native image mode, every
        # other name goes through the RGB-converting variant.
        if name in ['SDI/34Ah','SDI/37Ah','SDI/60Ah']:
            dataset_cls = SaigeDataset
        else:
            dataset_cls = SaigeDataset3
        dataset = dataset_cls(data_root=ds_cfg['data_root'],
                              split_root=ds_cfg['split_root'],
                              dataset=name,
                              split=split,
                              transform=transform,
                              targets=ds_cfg['targets'])

    # No DataLoader is built here: callers only use the dataset itself.
    print('Dataset {} ready.'.format(ds_cfg['dataset']))
    return dataset


# Training split as a dataset object (not a DataLoader).
data_train = getDataset(
    ds_cfg=cfg['in_dataset'],
    dl_cfg=cfg['dataloader'],
    split="train",
)


Dataset MNIST ready.


In [7]:
# Held-out in-distribution split as a dataset object (not a DataLoader).
data = getDataset(
    ds_cfg=cfg['in_dataset'],
    dl_cfg=cfg['dataloader'],
    split="valid",
)

Dataset MNIST ready.


In [8]:
# Classification accuracy of the restored model on the in-distribution test loader.
torch_model.eval()
correct = 0
total = 0
# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for x,y in test_loader:
        x = x.cuda()
        y = y.numpy()
        correct += (y==np.argmax(torch_model(x).detach().cpu().numpy(),axis=1)).sum()
        total += y.shape[0]
print("Accuracy: ",correct/total)


Accuracy:  0.9934574168076706


<b>Out-of-distribution Datasets</b>

## Code for Detecting OODs

<b> Extract predictions for train and test data </b>

In [9]:
# from ipywidgets import IntProgress

def _collect_predictions(model, dataset, batch_size):
    """Run ``model`` over ``dataset`` in order and gather its outputs.

    ``dataset`` is sliced ``batch_size`` samples at a time; the first element
    of every returned sample is stacked, its singleton dimension 1 is squeezed
    away, and the batch is moved to the GPU.

    Returns:
        (preds, confs, logits): three lists with one entry per sample, holding
        the arg-max class, the maximum softmax probability and the logit vector.
    """
    all_preds = []
    all_confs = []
    all_logits = []
    # Inference only: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        for idx in tqdm(range(0,len(dataset),batch_size)):
            batch = torch.squeeze(torch.stack([x[0] for x in dataset[idx:idx+batch_size]]),dim=1).cuda()

            logits = model(batch)
            confs = F.softmax(logits,dim=1).cpu().detach().numpy()
            preds = np.argmax(confs,axis=1)
            logits = (logits.cpu().detach().numpy())

            all_confs.extend(np.max(confs,axis=1))
            all_preds.extend(preds)
            all_logits.extend(logits)
    return all_preds, all_confs, all_logits


batch_size=16

train_preds, train_confs, train_logits = _collect_predictions(torch_model, data_train, batch_size)
print("Done")

test_preds, test_confs, test_logits = _collect_predictions(torch_model, data, batch_size)
print("Done")

HBox(children=(FloatProgress(value=0.0, max=3329.0), HTML(value='')))


Done


HBox(children=(FloatProgress(value=0.0, max=555.0), HTML(value='')))


Done


<b> Code for detecting OODs by identifying anomalies in correlations </b>

In [10]:
import calculate_log as callog

def detect(all_test_deviations,all_ood_deviations, verbose=True, normalize=True, n_trials=10):
    """Average OOD-detection metrics over several random validation splits.

    For each trial (seeds 1..n_trials) 10% of the in-distribution rows are drawn
    as a validation set. Its per-column mean (plus 1e-7 to avoid division by
    zero) is used to scale every column, the scaled deviations are summed per
    sample, and the negated totals are handed to `callog.compute_metric`.

    Args:
        all_test_deviations: 2-D numpy array, one row of deviations per
            in-distribution sample.
        all_ood_deviations: 2-D numpy array with the same number of columns,
            one row per OOD sample.
        verbose: when true, print the averaged metrics via `callog.print_results`.
        normalize: when false, skip the per-column scaling (scale of ones).
        n_trials: number of random splits to average over (default 10, which
            matches the previously hard-coded behaviour).

    Returns:
        dict mapping each metric name returned by `callog.compute_metric` to its
        mean over the trials.
    """
    n_samples = len(all_test_deviations)
    totals = {}
    for seed in range(1, n_trials+1):
        # Seeding makes every trial's split reproducible.
        random.seed(seed)

        validation_indices = random.sample(range(n_samples),int(0.1*n_samples))
        held_out = set(validation_indices)
        test_indices = [idx for idx in range(n_samples) if idx not in held_out]

        validation = all_test_deviations[validation_indices]
        test_deviations = all_test_deviations[test_indices]

        # Per-column scale estimated on the validation rows only.
        scale = validation.mean(axis=0)+10**-7
        if not normalize:
            scale = np.ones_like(scale)
        test_totals = (test_deviations/scale[np.newaxis,:]).sum(axis=1)
        ood_totals = (all_ood_deviations/scale[np.newaxis,:]).sum(axis=1)

        # Scores are negated before being passed on, so larger deviation totals
        # become smaller scores.
        results = callog.compute_metric(-test_totals,-ood_totals)
        for m in results:
            totals[m] = totals.get(m,0)+results[m]

    # Divide by the explicit trial count rather than a leaked loop variable.
    average_results = {m: totals[m]/n_trials for m in totals}
    if verbose:
        callog.print_results(average_results)
    return average_results

def cpu(ob):
    """Replace every element of the nested sequence `ob` with its `.cpu()` result.

    `ob` is modified in place (each `ob[i][j]` is reassigned) and also returned.
    """
    for row in ob:
        for j, item in enumerate(row):
            row[j] = item.cpu()
    return ob

def cuda(ob):
    """Replace every element of the nested sequence `ob` with its `.cuda()` result.

    `ob` is modified in place (each `ob[i][j]` is reassigned) and also returned.
    """
    for row in ob:
        for j, item in enumerate(row):
            row[j] = item.cuda()
    return ob

class Detector:
    """OOD detector built on per-predicted-class min/max feature statistics.

    Workflow: `compute_minmaxs` (training data) -> `compute_test_deviations`
    (in-distribution held-out data) -> `compute_ood_deviations` (one call per
    OOD dataset).

    The methods read several notebook-level globals: `torch_model`,
    `train_preds`, `test_preds`, `test_confs`, `tqdm`, and the `cpu`/`cuda`
    helpers. Those must be defined before the methods are called.
    """

    def __init__(self):
        self.all_test_deviations = None
        # Per-class statistics returned by torch_model.get_min_max, kept on CPU.
        self.mins = {}
        self.maxs = {}
        self.mins_real = {}
        self.maxs_real = {}
        self.gram_inside_ind = dict()
        self.gram_inside_ood = dict()
        # Predicted-class ids that are iterated over (9 in-distribution classes).
        self.classes = range(9)
        # Filled by compute_test_deviations / compute_ood_deviations.
        self.test_classes = None
        self.ood_classes = None
    
    def compute_minmaxs(self,data_train,POWERS=[10]):
        """Store, per predicted class, the min/max statistics of the training data.

        Training samples are grouped by the global `train_preds`; each group is
        passed to `torch_model.get_min_max` and the results are moved to CPU and
        stored in `self.mins` / `self.maxs` under the class id.

        Args:
            data_train: indexable training set; `data_train[i][0][0]` must be
                the input tensor of sample i.
            POWERS: powers forwarded to `torch_model.get_min_max` (not mutated).
        """
        print("Start")
        predicted = np.array(train_preds)
        for PRED in tqdm(self.classes):
            print('current prediction: {}'.format(PRED))
            train_indices = np.where(predicted==PRED)[0]
            print(len(train_indices))
            train_PRED = torch.squeeze(torch.stack([data_train[i][0][0] for i in train_indices]),dim=1)
            mins,maxs = torch_model.get_min_max(train_PRED,power=POWERS)
            self.mins[PRED] = cpu(mins)
            self.maxs[PRED] = cpu(maxs)
            torch.cuda.empty_cache()
    
    def compute_test_deviations(self,data,POWERS=[10]):
        """Compute deviations of in-distribution samples from the stored statistics.

        Samples are grouped by the global `test_preds`. For each group the
        deviations returned by `torch_model.get_deviations` are divided by the
        sample's max-softmax confidence (global `test_confs`). Results are stored
        in `self.all_test_deviations` (rows ordered by class, then by sample
        index; None when nothing was processed) and `self.test_classes`.

        Classes with no predicted sample are skipped, mirroring
        `compute_ood_deviations`; stacking an empty list would otherwise raise.

        Args:
            data: indexable dataset; `data[i][0][0]` must be the input tensor
                of sample i.
            POWERS: powers forwarded to `torch_model.get_deviations`.
        """
        predicted = np.array(test_preds)
        deviation_chunks = []
        test_classes = []
        for PRED in tqdm(self.classes):
            test_indices = np.where(predicted==PRED)[0]
            print(len(test_indices))
            if len(test_indices)==0:
                continue
            test_PRED = torch.squeeze(torch.stack([data[i][0][0] for i in test_indices]),dim=1)
            test_confs_PRED = np.array([test_confs[i] for i in test_indices])
            
            test_classes.extend([PRED]*len(test_indices))
            
            # Move this class's statistics to the GPU only while they are needed.
            mins = cuda(self.mins[PRED])
            maxs = cuda(self.maxs[PRED])
            test_deviations = torch_model.get_deviations(test_PRED,power=POWERS,mins=mins,maxs=maxs)/test_confs_PRED[:,np.newaxis]
            cpu(mins)
            cpu(maxs)
            deviation_chunks.append(test_deviations)
            torch.cuda.empty_cache()
        self.all_test_deviations = np.concatenate(deviation_chunks,axis=0) if deviation_chunks else None
        
        self.test_classes = np.array(test_classes)
    
    def compute_ood_deviations(self,ood,POWERS=[10]):
        """Score an OOD dataset against the stored statistics and report metrics.

        First predicts class and confidence for every OOD sample with
        `torch_model`, then, per predicted class, computes confidence-normalised
        deviations exactly as `compute_test_deviations` does. Finally calls
        `detect` with the in-distribution and OOD deviations.

        Args:
            ood: sliceable and indexable dataset; `ood[i][0]` is stacked for the
                forward pass and `ood[i][0][0]` for the deviation pass.
            POWERS: powers forwarded to `torch_model.get_deviations`.

        Returns:
            (average_results, self.all_test_deviations, all_ood_deviations)
        """
        ood_preds = []
        ood_confs = []
        batch_size=16
        # Inference only: no autograd graph is needed for the predictions.
        with torch.no_grad():
            for idx in range(0,len(ood),batch_size):
                batch = torch.squeeze(torch.stack([x[0] for x in ood[idx:idx+batch_size]]),dim=1).cuda()
                logits = torch_model(batch)
                confs = F.softmax(logits,dim=1).cpu().numpy()
                
                ood_confs.extend(np.max(confs,axis=1))
                ood_preds.extend(np.argmax(confs,axis=1))
        torch.cuda.empty_cache()
        print("Done")
        
        predicted = np.array(ood_preds)
        ood_classes = []
        deviation_chunks = []
        for PRED in tqdm(self.classes):
            ood_indices = np.where(predicted==PRED)[0]
            if len(ood_indices)==0:
                continue
            ood_classes.extend([PRED]*len(ood_indices))
            
            ood_PRED = torch.squeeze(torch.stack([ood[i][0][0] for i in ood_indices]),dim=1)

            ood_confs_PRED =  np.array([ood_confs[i] for i in ood_indices])
            mins = cuda(self.mins[PRED])
            maxs = cuda(self.maxs[PRED])
            ood_deviations = torch_model.get_deviations(ood_PRED,power=POWERS,mins=mins,maxs=maxs)/ood_confs_PRED[:,np.newaxis]
            cpu(self.mins[PRED])
            cpu(self.maxs[PRED])
            deviation_chunks.append(ood_deviations)
            torch.cuda.empty_cache()
        all_ood_deviations = np.concatenate(deviation_chunks,axis=0) if deviation_chunks else None
            
        self.ood_classes = np.array(ood_classes)
        
        average_results = detect(self.all_test_deviations,all_ood_deviations)
        return average_results, self.all_test_deviations, all_ood_deviations


<center><h1> Results </h1></center>

In [11]:
import time

def G_p(ob, p):
    """Row-summed p-th order Gram statistic of a feature tensor.

    The detached input is raised element-wise to the power `p`, flattened to
    (dim0, dim1, -1) and multiplied by its own transpose. Each row of that
    product is summed, and the signed p-th root of the sums is returned with
    shape (dim0, -1).
    """
    n, c = ob.shape[0], ob.shape[1]
    powered = torch.pow(ob.detach(), p).reshape(n, c, -1)
    row_sums = (powered @ powered.transpose(1, 2)).sum(dim=2)
    # Signed root: keep the sign, take the p-th root of the magnitude.
    rooted = torch.sign(row_sums) * row_sums.abs().pow(1/p)
    return rooted.reshape(n, -1)

def G_p_entire(ob, p):
    """Full p-th order Gram matrix of a feature tensor.

    The detached input is raised element-wise to the power `p`, flattened to
    (dim0, dim1, -1) and multiplied by its own transpose. The signed p-th root
    of every entry is returned with shape (dim0, dim1, -1).
    """
    n, c = ob.shape[0], ob.shape[1]
    powered = torch.pow(ob.detach(), p).reshape(n, c, -1)
    gram = powered @ powered.transpose(1, 2)
    # Signed root: keep the sign, take the p-th root of the magnitude.
    rooted = torch.sign(gram) * gram.abs().pow(1/p)
    return rooted.reshape(n, c, -1)

# Build the detector and record, for each predicted class, the min/max statistics
# that torch_model.get_min_max returns on the training data, using powers 1..9.
detector = Detector()
detector.compute_minmaxs(data_train,POWERS=range(1,10))


Start


HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

current prediction: 0
5923
current prediction: 1
5958
current prediction: 2
6131
current prediction: 3
5842
current prediction: 4
5421
current prediction: 5
5918
current prediction: 6
6265
current prediction: 7
5851
current prediction: 8
5949



In [12]:

# Deviations of the in-distribution held-out samples (`data`) from the stored
# per-class statistics, with the same powers (1..9) as in compute_minmaxs.
# The numbers printed below are the per-class sample counts.
detector.compute_test_deviations(data,POWERS=range(1,10))





HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

987
1030
1013
985
894
954
1038
962
1002



In [13]:
# OOD set #1: the MNIST class held out from training (`out_target`).
# Both transforms are left as None; the previously used pipelines were
#   train: trn.Compose([trn.RandomHorizontalFlip(), trn.ToTensor(), trn.Normalize([0.4214,0.4214,0.4214],[0.2355,0.2355,0.2355])])
#   valid: trn.Compose([trn.ToTensor(), trn.Normalize([0.4214,0.4214,0.4214],[0.2355,0.2355,0.2355])])
# Unused path settings kept for reference:
#   data_root  = '/HDD0/Openset/SDI/34Ah'
#   split_root = '/HDD0/Openset/data_split/SDI/34Ah'
cfg['out_dataset'] = {
    'dataset': 'MNIST',
    'batch_size': 128,
    'targets': [out_target],
    'train_transform': None,
    'valid_transform': None,
    'num_classes': len(cfg['in_dataset']['targets']),
}

# Note: this replaces any existing cfg['dataloader'] entry wholesale.
cfg['dataloader'] = {
    'num_workers': 8,
    'pin_memory': True,
}

out_dataset = getDataset(
    ds_cfg=cfg['out_dataset'],
    dl_cfg=cfg['dataloader'],
    split="valid",
)

print("Real Out Target :{}, target {}".format(cfg['out_dataset']['dataset'],cfg['out_dataset']['targets'][0]))
# The result is stored under the name shared by all OOD cells (`c10_results`).
c10_results = detector.compute_ood_deviations(out_dataset, POWERS=range(1,10))


Dataset MNIST ready.
Real Out Target :MNIST, target 1
Done


HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


 TNR    AUROC  DTACC  AUIN   AUOUT 
 99.744 99.800 99.029 99.965 98.469


In [14]:
# OOD set #2: the CIFAR-10 test split, center-cropped to 32x32 and converted to tensors.
print("Out Target : CIFAR-10")
transform_test = trn.Compose([trn.CenterCrop(size=(32, 32)), trn.ToTensor()])

# Materialise the loader as a list of single-sample batches so the detector
# can index and slice it.
cifar10 = list(DataLoader(
    datasets.CIFAR10('data', train=False, transform=transform_test, download=True),
    batch_size=1,
    shuffle=True,
))
c10_results = detector.compute_ood_deviations(cifar10, POWERS=range(1,10))

Out Target : CIFAR-10
Files already downloaded and verified
Done


HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


 TNR    AUROC  DTACC  AUIN   AUOUT 
 100.000 100.000 99.964 99.993 99.995


In [15]:
# OOD set #3: the SVHN test split, center-cropped to 32x32 and converted to tensors.
print("Out Target : SVHN")
transform_test = trn.Compose([trn.CenterCrop(size=(32, 32)), trn.ToTensor()])

# Materialise the loader as a list of single-sample batches so the detector
# can index and slice it.
svhn = list(DataLoader(
    datasets.SVHN('data', split="test", transform=transform_test, download=True),
    batch_size=1,
    shuffle=True,
))
# The result is stored under the name shared by all OOD cells (`c10_results`).
c10_results = detector.compute_ood_deviations(svhn, POWERS=range(1,10))

Out Target : SVHN
Using downloaded and verified file: data/test_32x32.mat
Done


HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


 TNR    AUROC  DTACC  AUIN   AUOUT 
 100.000 99.961 99.974 99.964 99.938
