In [1]:
from __future__ import division,print_function

%matplotlib inline
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../')
from tqdm.notebook import tqdm
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models, datasets
import utils.calculate_log as callog
from my_models import mobilenet
import pandas as pd

# Setting the model

In [2]:
# Wrap a MobileNetV2 backbone (no pretrained weights) in the project's `Net`
# and restore the trained weights from the checkpoint.
# NOTE(review): the second argument (8) is presumably the number of classes -- confirm.
torch_model = mobilenet.Net(models.mobilenet_v2(pretrained=False), 8)
# Map the checkpoint onto the CPU first so loading does not depend on the
# device it was saved from; the model itself is moved to the GPU below.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
ckpt = torch.load("checkpoints/mobilenet_checkpoint.pth", map_location="cpu")
torch_model.load_state_dict(ckpt['model_state_dict'])
torch_model.eval()
torch_model.cuda()
print("Done!")



Done!


## Registering the forward hooks

In [3]:
# Shared buffer filled by the forward hooks: one entry is appended each time
# a hooked layer produces an output.
feat_maps = list()
def _hook_fn(self, input, output):
    """Forward hook: append the hooked layer's ``output`` to the global ``feat_maps``."""
    feat_maps.append(output)
    

def hook_layers(model):
    """Collect the layers of ``model`` whose outputs should be captured.

    Args:
        model: module whose sub-modules are scanned via ``model.modules()``.

    Returns:
        list with every sub-module that is a ``models.mobilenet.ConvBNReLU``,
        in ``modules()`` traversal order.
    """
    # Walk the model that was passed in rather than a notebook global, so the
    # function gives the right answer for any model it is called with.
    return [
        layer
        for layer in model.modules()
        if isinstance(layer, models.mobilenet.ConvBNReLU)
    ]


def register_layers(layers):
    """Attach ``_hook_fn`` as a forward hook to every layer.

    Returns the list of handles produced by ``register_forward_hook``, in the
    same order as ``layers``.
    """
    return [module.register_forward_hook(_hook_fn) for module in layers]


def unregister_layers(reg_layers):
    """Detach forward hooks by calling ``remove()`` on each handle in ``reg_layers``."""
    for handle in reg_layers:
        handle.remove()
                    

def get_feat_maps(model, batch_img):
    """Run ``model`` on a batch and return its softmax scores plus the hooked feature maps.

    Args:
        model: network whose forward hooks append to the global ``feat_maps``.
        batch_img: input batch tensor; moved to the device of the model's
            parameters (left where it is if the model has no parameters).

    Returns:
        (preds, maps): ``preds`` is the softmax over dim 1 of the model
        output; ``maps`` is a list with the tensors the hooks captured during
        this forward pass only.
    """
    # Follow the model's device instead of assuming CUDA, so the same code
    # also runs on a CPU-only machine.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        batch_img = batch_img.to(first_param.device)

    # Drop anything left over from a forward pass made outside this function;
    # stale entries would shift the layer indices of this batch.
    feat_maps.clear()
    try:
        with torch.no_grad():
            preds = model(batch_img)
        maps = feat_maps.copy()
    finally:
        # Always empty the shared buffer, even when the forward pass fails.
        feat_maps.clear()

    preds = F.softmax(preds, dim=1)
    return preds, maps

# Select the layers to observe on the loaded model, attach the forward hooks,
# and keep the handles (`rgl`) so the hooks can be removed again later.
hl = hook_layers(torch_model)
rgl = register_layers(hl)
print("Total number of registered hooked layers:", len(rgl))

Total number of registered hooked layers: 35


# Loading the data

## In-distribution data

In [4]:
batch_size = 15

# Pre-processing shared by every loader in this notebook: tensor conversion
# followed by per-channel normalisation.
# NOTE(review): mean/std look like the usual ImageNet statistics -- confirm.
# Resizing is currently disabled:
#     transforms.Resize((224,224)),
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# The three skin-cancer splits, read in a fixed (unshuffled) order.
sk_train = torch.utils.data.DataLoader(
    dataset=datasets.ImageFolder(root="data/skin_cancer/train/", transform=trans),
    batch_size=batch_size,
    shuffle=False,
)

sk_test = torch.utils.data.DataLoader(
    dataset=datasets.ImageFolder(root="data/skin_cancer/test/", transform=trans),
    batch_size=batch_size,
    shuffle=False,
)

sk_val = torch.utils.data.DataLoader(
    dataset=datasets.ImageFolder(root="data/skin_cancer/val/", transform=trans),
    batch_size=batch_size,
    shuffle=False,
)

## Out-of-distribution data

In [5]:
# Clinical skin images, read in a fixed (unshuffled) order.
skin_cli = torch.utils.data.DataLoader(
    dataset=datasets.ImageFolder(root="data/skins/clinical/", transform=trans),
    batch_size=batch_size,
    shuffle=False,
)

In [6]:
# Dermoscopy skin images, read in a fixed (unshuffled) order.
skin_derm = torch.utils.data.DataLoader(
    dataset=datasets.ImageFolder(root="data/skins/dermoscopy/", transform=trans),
    batch_size=batch_size,
    shuffle=False,
)

In [7]:
# Images under data/imagenet/, read in a fixed (unshuffled) order.
imgnet = torch.utils.data.DataLoader(
    dataset=datasets.ImageFolder(root="data/imagenet/", transform=trans),
    batch_size=batch_size,
    shuffle=False,
)

In [8]:
# Images under data/corrupted/bbox/, read in a fixed (unshuffled) order.
corrupted = torch.utils.data.DataLoader(
    dataset=datasets.ImageFolder(root="data/corrupted/bbox/", transform=trans),
    batch_size=batch_size,
    shuffle=False,
)

In [9]:
# Images under data/corrupted/bbox_70/, read in a fixed (unshuffled) order.
corrupted_70 = torch.utils.data.DataLoader(
    dataset=datasets.ImageFolder(root="data/corrupted/bbox_70/", transform=trans),
    batch_size=batch_size,
    shuffle=False,
)

In [10]:
# Images under data/nct/, read in a fixed (unshuffled) order.
nct = torch.utils.data.DataLoader(
    dataset=datasets.ImageFolder(root="data/nct/", transform=trans),
    batch_size=batch_size,
    shuffle=False,
)

# Gram-Matrix operations

## Gram matrix operator

In [12]:
def norm_min_max(x):
    """Rescale ``x`` to the range [0, 1] using the min and max taken over dim 1.

    Args:
        x: tensor with at least two dimensions.

    Returns:
        Tensor shaped like ``x`` holding ``(x - min) / (max - min)``, where
        min and max are computed along dim 1. Where max equals min the result
        is 0 instead of NaN.
    """
    upper = x.max(dim=1, keepdim=True)[0]
    lower = x.min(dim=1, keepdim=True)[0]
    span = upper - lower
    # A constant row has span 0 and would produce 0/0 = NaN; dividing by 1
    # instead keeps its all-zero numerator as the result.
    span = torch.where(span == 0, torch.ones_like(span), span)
    return (x - lower) / span

def get_sims_gram_matrix(maps, power):
    """Turn a batch of feature maps into one normalised Gram-based vector per sample.

    The maps are raised element-wise to ``power``, flattened to
    (batch, channels, -1), multiplied with their own transpose, and the
    resulting matrix is summed over its last dimension. The signed
    ``1/power`` root of that sum is then passed through ``norm_min_max``.
    """
    n_samples, n_channels = maps.shape[0], maps.shape[1]
    flat = (maps ** power).reshape(n_samples, n_channels, -1)

    # (batch, channels, channels) Gram matrix collapsed to (batch, channels).
    gram_rows = (flat @ flat.transpose(1, 2)).sum(2)

    rooted = gram_rows.sign() * gram_rows.abs() ** (1 / power)
    return norm_min_max(rooted.reshape(rooted.shape[0], -1))


## Considering samples per label

In [13]:
def _get_sim_per_labels(data_loader, power, use_preds=True):
    """Group the per-sample Gram vectors by label and by hooked layer.

    Args:
        data_loader: iterable yielding ``(img_batch, labels)`` pairs.
        power: a single exponent, or a list/range of exponents.
        use_preds: when true, the model's argmax prediction replaces the
            label that came with the batch.

    Returns:
        ``result[label][layer]`` is a list of CPU tensors, one per sample and
        per exponent; ``None`` if the loader yields no batch.
    """
    powers = power if isinstance(power, (list, range)) else [power]

    sims_per_label = None
    for img_batch, labels in tqdm(data_loader):
        preds, maps_list = get_feat_maps(torch_model, img_batch)

        if use_preds:
            labels = preds.argmax(dim=1)

        # The first batch fixes the table size: labels x hooked layers.
        if sims_per_label is None:
            n_labels, n_layers = preds.shape[1], len(maps_list)
            sims_per_label = [[[] for _ in range(n_layers)] for _ in range(n_labels)]

        for layer, maps in enumerate(maps_list):
            for exponent in powers:
                batch_sims = get_sims_gram_matrix(maps, exponent)
                for sample_sim, lab in zip(batch_sims, labels):
                    sims_per_label[lab.item()][layer].append(sample_sim.cpu())

    return sims_per_label


def get_min_max_per_label(data_loader, power):
    """Compute element-wise lower and upper bounds of the Gram vectors.

    Returns:
        (mins, maxs): two nested lists indexed as ``[label][layer]``; each
        entry is the element-wise minimum (resp. maximum) over all vectors
        that ``_get_sim_per_labels`` collected for that label and layer.
    """
    sims_per_label = _get_sim_per_labels(data_loader, power)

    print ("-- Computing the values...")
    sims_per_label_min, sims_per_label_max = [], []
    for layers_of_label in sims_per_label:
        label_min, label_max = [], []
        for layer_sims in layers_of_label:
            stacked = torch.stack(layer_sims)
            label_min.append(stacked.min(dim=0)[0])
            label_max.append(stacked.max(dim=0)[0])
        sims_per_label_min.append(label_min)
        sims_per_label_max.append(label_max)

    return sims_per_label_min, sims_per_label_max


def get_layer_gaps(mins, maxs):
    """Sum the (max - min) range of every label/layer pair.

    Args:
        mins: nested list ``[label][layer]`` of lower-bound tensors.
        maxs: nested list ``[label][layer]`` of upper-bound tensors.

    Returns:
        NumPy array of shape (num_labels, num_layers) where entry
        ``[lab, layer]`` is ``(maxs[lab][layer] - mins[lab][layer]).sum()``.
        Empty input yields an array of shape (0, 0).
    """
    num_lab = len(mins)
    num_lay = len(mins[0]) if num_lab else 0

    # The result ends up as a NumPy array, so it is accumulated on the CPU;
    # no GPU is needed for this step.
    gaps = torch.zeros(num_lab, num_lay)
    for lab in range(num_lab):
        for layer in range(num_lay):
            # .item() makes the assignment work whatever device the bounds live on.
            gaps[lab, layer] = (maxs[lab][layer] - mins[lab][layer]).sum().item()

    return gaps.numpy()


def get_dev_scores_per_label(data_loader, power, sims_min, sims_max, ep=10e-6):
    """Score how far each sample's Gram vectors fall outside the stored bounds.

    For every sample the predicted label selects the bounds
    ``sims_min[label][layer]`` / ``sims_max[label][layer]``. Per layer, the
    amount by which the Gram vector undershoots the minimum or overshoots the
    maximum is divided by ``|bound + ep|`` and summed over the vector and over
    all exponents.

    Args:
        data_loader: iterable yielding ``(img_batch, _)`` pairs.
        power: a single exponent, or a list/range of exponents.
        sims_min: nested list ``[label][layer]`` of lower-bound tensors.
        sims_max: nested list ``[label][layer]`` of upper-bound tensors.
        ep: offset added to the bounds before taking the absolute value used
            as divisor. NOTE(review): the default ``10e-6`` equals 1e-5;
            confirm whether 1e-6 was intended.

    Returns:
        NumPy array of shape (num_samples, num_layers) with the deviations.
    """
    if not isinstance(power, list) and not isinstance(power, range):
        power = [power]

    dev_scores = list()
    for data in tqdm(data_loader):
        img_batch, _ = data
        preds_batch, maps_list = get_feat_maps(torch_model, img_batch)
        # One transfer per batch instead of one `.item()` call per sample,
        # layer and exponent.
        labels = preds_batch.argmax(dim=1).tolist()
        batch_scores = list()

        for layer, maps in enumerate(maps_list):
            # The bounds depend only on label and layer, so they are built once
            # per layer and placed on the device of the feature maps rather
            # than on a hard-coded CUDA device.
            layer_min = torch.stack([sims_min[lab][layer] for lab in labels]).to(
                device=maps.device, dtype=torch.get_default_dtype())
            layer_max = torch.stack([sims_max[lab][layer] for lab in labels]).to(
                device=maps.device, dtype=torch.get_default_dtype())
            min_scale = torch.abs(layer_min + ep)
            max_scale = torch.abs(layer_max + ep)

            score_layer = 0
            for p in power:
                sims = get_sims_gram_matrix(maps, p)
                score_layer += (F.relu(layer_min - sims) / min_scale).sum(dim=1, keepdim=True)
                score_layer += (F.relu(sims - layer_max) / max_scale).sum(dim=1, keepdim=True)

            batch_scores.append(score_layer)

        dev_scores.append(torch.cat(batch_scores, dim=1))

    return torch.cat(dev_scores).cpu().numpy()

In [14]:
def detect_mean(all_test_std, all_ood_std, gaps=None, n_runs=10):
    """Average the detection metrics over several seeded validation/test splits.

    For each run ``r`` in ``1..n_runs`` the global NumPy RNG is seeded with
    ``r`` and the index list is shuffled in place (so every permutation builds
    on the previous one). The first 10% of the shuffled in-distribution rows
    form the validation split, the rest the test split. Deviations are divided
    by a per-column scale derived from the validation split, summed per
    sample, negated and passed to ``callog.compute_metric``.

    Args:
        all_test_std: 2-D array of in-distribution deviations, one row per sample.
        all_ood_std: 2-D array of out-of-distribution deviations, one row per sample.
        gaps: optional array; when given the scale is
            ``val.sum(axis=0) + gaps.mean(0)`` instead of
            ``val.mean(axis=0) + 10**-7``.
            NOTE(review): this branch uses a sum where the default uses a
            mean -- confirm that is intended.
        n_runs: number of seeded splits to average over (default 10).

    Returns:
        dict mapping each metric returned by ``callog.compute_metric`` to its
        mean over the runs; the same dict is printed via
        ``callog.print_results``.

    Raises:
        ValueError: if ``n_runs`` is smaller than 1.
    """
    if n_runs < 1:
        raise ValueError("n_runs must be at least 1")

    avg_results = dict()
    indices = list(range(len(all_test_std)))
    split = int(np.floor(0.1 * len(all_test_std)))

    for seed in range(1, n_runs + 1):
        np.random.seed(seed)
        np.random.shuffle(indices)

        val_std = all_test_std[indices[:split]]
        test_std = all_test_std[indices[split:]]

        if gaps is not None:
            t95 = (val_std.sum(axis=0) + gaps.mean(0))
        else:
            t95 = val_std.mean(axis=0) + 10**-7

        test_scores = (test_std / t95[np.newaxis, :]).sum(axis=1)
        ood_scores = (all_ood_std / t95[np.newaxis, :]).sum(axis=1)

        results = callog.compute_metric(-test_scores, -ood_scores)

        for m in results:
            avg_results[m] = avg_results.get(m, 0) + results[m]

    # Divide by the explicit run count rather than a leaked loop variable.
    for m in avg_results:
        avg_results[m] /= n_runs

    callog.print_results(avg_results)

    return avg_results

In [15]:
def detect(all_test_std, all_ood_std):
    """Compute and print the detection metrics for one fixed split (seed 10).

    The first 10% of the shuffled in-distribution rows form the validation
    split. Its column means (plus ``10**-7``) scale both the remaining
    in-distribution rows and the OOD rows; the scaled deviations are summed
    per sample, negated and passed to ``callog.compute_metric``.

    Returns:
        whatever ``callog.compute_metric`` returns (also printed through
        ``callog.print_results``).
    """
    n_samples = len(all_test_std)
    n_val = int(np.floor(0.1 * n_samples))

    order = list(range(n_samples))
    np.random.seed(10)
    np.random.shuffle(order)
    val_rows, test_rows = order[:n_val], order[n_val:]

    # Column-wise scale taken from the validation rows.
    scale = (all_test_std[val_rows].mean(axis=0) + 10**-7)[np.newaxis, :]

    in_scores = (all_test_std[test_rows] / scale).sum(axis=1)
    ood_scores = (all_ood_std / scale).sum(axis=1)

    results = callog.compute_metric(-in_scores, -ood_scores)
    callog.print_results(results)
    return results

# OOD detection per label

In [16]:
# Exponent(s) of the Gram matrices; a single value or a list/range is accepted.
power = 1

# Per-label, per-layer bounds of the Gram vectors on the training split.
print ("- Getting mins/maxs")
mins, maxs = get_min_max_per_label(sk_train, power)

print("- Getting the gaps")
gaps = get_layer_gaps(mins, maxs)

# Deviations of the in-distribution test split from those bounds.
print ("- Getting test stdevs")
sk_test_stdev = get_dev_scores_per_label(sk_test, power, mins, maxs)

# Same for the validation split; the message now names the right split.
print ("- Getting val stdevs")
sk_val_stdev = get_dev_scores_per_label(sk_val, power, mins, maxs)

- Getting mins/maxs


HBox(children=(FloatProgress(value=0.0, max=1351.0), HTML(value='')))


-- Computing the values...
- Getting the gaps
- Getting test stdevs


HBox(children=(FloatProgress(value=0.0, max=169.0), HTML(value='')))


- Getting test stdevs


HBox(children=(FloatProgress(value=0.0, max=169.0), HTML(value='')))




In [17]:
# Release the GPU memory cached by PyTorch before the OOD sets are scored.
torch.cuda.empty_cache()

# Testing

In [18]:
# Score the dermoscopy set against the training bounds, then compare it with
# the in-distribution test scores (metrics averaged over the seeded splits).
print("Skins dermoscopy")
skin_derm_stdev = get_dev_scores_per_label(skin_derm, power, mins, maxs)
skin_derm_results = detect_mean(sk_test_stdev, skin_derm_stdev)

Skins dermoscopy


HBox(children=(FloatProgress(value=0.0, max=105.0), HTML(value='')))


 TNR    AUROC  DTACC  AUIN   AUOUT 
 72.773 94.040 87.863 93.462 91.418


In [19]:
# Score the clinical set against the training bounds, then compare it with
# the in-distribution test scores (metrics averaged over the seeded splits).
print("Skins clinical")
skin_cli_stdev = get_dev_scores_per_label(skin_cli, power, mins, maxs)
skin_cli_results = detect_mean(sk_test_stdev, skin_cli_stdev)

Skins clinical


HBox(children=(FloatProgress(value=0.0, max=49.0), HTML(value='')))


 TNR    AUROC  DTACC  AUIN   AUOUT 
 83.817 96.352 90.997 98.483 88.027


In [20]:
# Score the ImageNet set against the training bounds, then compare it with
# the in-distribution test scores.
# NOTE: the result variable is spelled `imgent_results`; the summary cell
# reads it under that exact name.
print("ImageNet")
imgnet_stdev = get_dev_scores_per_label(imgnet, power, mins, maxs)
imgent_results = detect_mean(sk_test_stdev, imgnet_stdev)

ImageNet


HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))


 TNR    AUROC  DTACC  AUIN   AUOUT 
 92.420 98.458 94.362 98.387 98.426


In [21]:
# Score the corrupted (bbox) set against the training bounds, then compare it
# with the in-distribution test scores.
print("Corrupted images bbox")
corrupted_stdev = get_dev_scores_per_label(corrupted, power, mins, maxs)
corrupted_results = detect_mean(sk_test_stdev, corrupted_stdev)

Corrupted images bbox


HBox(children=(FloatProgress(value=0.0, max=139.0), HTML(value='')))


 TNR    AUROC  DTACC  AUIN   AUOUT 
 98.742 98.755 97.052 99.192 97.093


In [22]:
# Score the corrupted (bbox_70) set against the training bounds, then compare
# it with the in-distribution test scores.
print("Corrupted images bbox 70")
corrupted_70_stdev = get_dev_scores_per_label(corrupted_70, power, mins, maxs)
corrupted_70_results = detect_mean(sk_test_stdev, corrupted_70_stdev)

Corrupted images bbox 70


HBox(children=(FloatProgress(value=0.0, max=164.0), HTML(value='')))


 TNR    AUROC  DTACC  AUIN   AUOUT 
 100.000 99.886 99.483 99.909 99.682


In [23]:
# Score the NCT set against the training bounds, then compare it with the
# in-distribution test scores.
print("NCT")
nct_stdev = get_dev_scores_per_label(nct, power, mins, maxs)
nct_results = detect_mean(sk_test_stdev, nct_stdev)

NCT


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))


 TNR    AUROC  DTACC  AUIN   AUOUT 
 100.000 99.739 98.898 99.854 99.241


## Summary

In [25]:
# TNR of every evaluated set as a percentage, in the order they were scored.
for ood_results in (
    skin_derm_results,
    skin_cli_results,
    imgent_results,
    corrupted_results,
    corrupted_70_results,
    nct_results,
):
    print(round(ood_results['TNR']*100,3))

72.773
83.817
92.42
98.742
100.0
100.0
10.41
