In [1]:
from __future__ import division,print_function

%matplotlib inline
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../')
from tqdm.notebook import tqdm
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models, datasets
import utils.calculate_log as callog
from my_models import vgg
import utils.build_dataset as bd
import pandas as pd

# Setting the model

In [2]:
# Detector variant. 'gram-ood*' hooks BatchNorm2d outputs and min-max normalises
# the Gram features; 'gram-ood' hooks Conv2d/ReLU outputs and sweeps several powers.
method = "gram-ood*" # or 'gram-ood'

# VGG-16-BN backbone wrapped by the project's vgg.Net. The 8 is presumably the
# number of output classes (TODO confirm). Weights come from the fine-tuned
# checkpoint, not from pretrained torchvision weights.
torch_model = vgg.Net(models.vgg16_bn(pretrained=False), 8)
ckpt = torch.load("../checkpoints/vgg-16_checkpoint.pth")
torch_model.load_state_dict(ckpt['model_state_dict'])
torch_model.eval().cuda()  # inference mode, then move to the GPU
print("Done!")



Done!


## Registering the forward hooks

In [3]:
# Outputs captured by the forward hooks; get_feat_maps copies and empties it.
feat_maps = []

def _hook_fn(self, input, output):
    """Forward hook: record ``output`` in the module-level ``feat_maps`` list.

    ``self`` (the hooked module) and ``input`` are part of the PyTorch hook
    signature and are not used.
    """
    feat_maps.append(output)
    

def hook_layers(model, batchnorm_only=None):
    """Select the layers of ``model`` whose outputs the detector should capture.

    Args:
        model: network to scan; every submodule yielded by ``model.modules()``
            is inspected (previously the global ``torch_model`` was scanned and
            this argument was ignored).
        batchnorm_only: if True, keep ``nn.BatchNorm2d`` layers (the 'gram-ood*'
            variant); if False, keep ``nn.Conv2d`` and ``nn.ReLU`` layers (the
            'gram-ood' variant). Defaults to ``method == 'gram-ood*'``, using the
            notebook-level ``method`` setting.

    Returns:
        list of the selected modules, in ``model.modules()`` order.
    """
    if batchnorm_only is None:
        batchnorm_only = method == 'gram-ood*'
    wanted = (nn.BatchNorm2d,) if batchnorm_only else (nn.Conv2d, nn.ReLU)
    return [layer for layer in model.modules() if isinstance(layer, wanted)]



def register_layers(layers):
    """Attach ``_hook_fn`` as a forward hook to each layer in ``layers``.

    Returns:
        the hook handles, in the same order as ``layers``; pass them to
        ``unregister_layers`` to detach the hooks again.
    """
    return [layer.register_forward_hook(_hook_fn) for layer in layers]


def unregister_layers(reg_layers):
    """Detach every hook whose handle appears in ``reg_layers``."""
    for handle in reg_layers:
        handle.remove()
                    

def get_feat_maps(model, batch_img):
    """Run one forward pass and return the predictions plus the hooked feature maps.

    Args:
        model: network whose registered hooks append to ``feat_maps``.
        batch_img: input batch; it is moved to the current CUDA device.

    Returns:
        (preds, maps): ``preds`` is ``softmax(model(batch_img), dim=1)`` and
        ``maps`` is the list of tensors captured by the hooks during this
        forward pass, in the order the hooks fired.
    """
    # Drop anything left behind by an interrupted earlier call (e.g. a
    # KeyboardInterrupt mid-forward) or by a forward pass run outside this
    # function. Otherwise stale maps would be prepended and shift layer indices.
    feat_maps.clear()
    batch_img = batch_img.cuda()
    try:
        with torch.no_grad():
            preds = model(batch_img)
        maps = feat_maps.copy()
    finally:
        # Always release the shared list so captured activations are not kept alive.
        feat_maps.clear()
    preds = F.softmax(preds, dim=1)
    return preds, maps

## Setting the hook
# Make the cell idempotent. On a re-run the hooks from the previous run would
# still be attached to torch_model and would keep appending to feat_maps,
# doubling the maps returned per forward pass.
if 'rgl' in globals():
    unregister_layers(rgl)
hl = hook_layers(torch_model)
rgl = register_layers(hl)
print ("Total number of registered hooked layers:", len(rgl))

Total number of registered hooked layers: 28


# Loading the data

## In-distribution data (skin-cancer train/test splits)

In [4]:
batch_size = 15
# NOTE(review): unlike TRANS for the OOD images, no Resize((224, 224)) is applied
# here (it was commented out). Presumably these images already have the size the
# network expects; confirm.
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Unshuffled loaders over the train and test ImageFolder trees, built in that order.
sk_train, sk_test = [
    torch.utils.data.DataLoader(
        datasets.ImageFolder(root, transform=trans),
        batch_size=batch_size,
        shuffle=False,
    )
    for root in ("../data/skin_cancer/train/", "../data/skin_cancer/test/")
]

## Out-of-distribution data

In [5]:
# OOD images are resized to 224x224, then normalised with the same statistics
# used for the in-distribution loaders.
TRANS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Loader settings forwarded to bd.get_data_loader ('shuf' presumably toggles
# shuffling; TODO confirm against utils.build_dataset).
PARAMS = dict(batch_size=15, shuf=False)

def get_data (df, base_path, img_ext):
    """Build a loader over the images named in ``df['image']``.

    Each entry becomes ``os.path.join(base_path, name + img_ext)``. The paths are
    passed to ``bd.get_data_loader`` without labels, using the module-level
    ``TRANS`` and ``PARAMS``.
    """
    stems = df['image'].values
    paths = list(map(lambda stem: os.path.join(base_path, stem + img_ext), stems))
    return bd.get_data_loader(paths, None, transform=TRANS, params=PARAMS)

In [6]:
# ISIC 2019 test metadata. get_data reads the 'image' column and appends the
# extension to each entry to build the file paths.
test_csv = pd.read_csv("merge/ISIC_2019_Test_Metadata.csv")
final_test = get_data(test_csv, base_path="../data/final_test/img", img_ext=".jpeg")

# Gram-Matrix operations

## Gram matrix operator

In [7]:
def norm_min_max(x, dim=1):
    """Min-max rescale ``x`` to [0, 1] independently along ``dim``.

    Args:
        x: input tensor.
        dim: axis along which the min and max are taken (default 1, the
            per-sample feature axis for the (batch, features) Gram vectors).

    Returns:
        tensor of the same shape as ``x``. Slices whose values are all equal
        (max == min) map to 0 instead of NaN.
    """
    ma = torch.max(x, dim=dim, keepdim=True)[0]
    mi = torch.min(x, dim=dim, keepdim=True)[0]
    span = ma - mi
    # A constant slice would give 0/0 = NaN, which would propagate into every
    # deviation score and the final metrics; divide by 1 there instead (x - mi == 0).
    span = torch.where(span > 0, span, torch.ones_like(span))
    return (x - mi) / span

def get_sims_gram_matrix (maps, power, normalize=None):
    """Per-channel higher-order Gram features of a batch of feature maps.

    The steps are:
    1. Rectify the maps, raise them to ``power``, and flatten them to
       (batch, channels, -1).
    2. Form the Gram matrix ``M @ M^T`` of every sample and sum it over its
       last axis.
    3. Take the signed ``power``-th root.

    Args:
        maps: hooked feature maps; dims 0 and 1 are treated as batch and channels.
        power: Gram order p.
        normalize: if True, min-max rescale each sample's feature vector with
            ``norm_min_max``. Defaults to ``method == 'gram-ood*'`` (notebook
            setting), so existing callers behave as before while the function
            can also be used without that global.

    Returns:
        tensor of shape (batch, channels).
    """
    if normalize is None:
        normalize = method == 'gram-ood*'
    rectified = F.relu(maps) ** power
    flat = rectified.reshape(rectified.shape[0], rectified.shape[1], -1)
    gram = torch.matmul(flat, flat.transpose(1, 2)).sum(2)
    gram = (gram.sign() * torch.abs(gram) ** (1 / power)).reshape(gram.shape[0], -1)
    if normalize:
        gram = norm_min_max(gram)
    return gram


## Per-label feature statistics (samples grouped by predicted label)

In [8]:
def _get_sim_per_labels(data_loader, power, use_preds=True):
    """Collect per-sample Gram features grouped by label and hooked layer.

    Args:
        data_loader: yields ``(images, labels)`` batches.
        power: one Gram power, or a list/range of powers.
        use_preds: if True (default), group samples by the model's argmax
            prediction instead of the loader's labels.

    Returns:
        ``out[label][layer]``: a list of 1-D CPU tensors, one per sample and per
        power. There are ``preds.shape[1]`` label slots. Returns None if
        ``data_loader`` yields no batches.

    NOTE(review): with several powers, the features for every power are appended
    to the same per-layer list, so the min/max computed from it pools all powers
    together. Confirm this is intended.
    """
    powers = power if isinstance(power, (list, range)) else [power]

    grouped = None
    for img_batch, batch_labels in tqdm(data_loader):
        preds, maps_list = get_feat_maps(torch_model, img_batch)
        if use_preds:
            batch_labels = preds.argmax(dim=1)

        if grouped is None:
            # Allocated lazily: the label and layer counts are only known after
            # the first forward pass.
            grouped = [[[] for _ in maps_list] for _ in range(preds.shape[1])]

        for layer_idx, maps in enumerate(maps_list):
            for p in powers:
                feats = get_sims_gram_matrix(maps, p)
                for feat, lab in zip(feats, batch_labels):
                    grouped[lab.item()][layer_idx].append(feat.cpu())

    return grouped


def get_min_max_per_label(data_loader, power):
    """Per-label, per-layer elementwise min and max of the Gram features.

    Args:
        data_loader: training loader yielding ``(images, labels)`` batches.
        power: one Gram power, or a list/range of powers.

    Returns:
        ``(mins, maxs)``: ``mins[label][layer]`` and ``maxs[label][layer]`` are the
        elementwise minimum and maximum over all samples grouped under ``label``.
        A label that received no sample (e.g. a class the model never predicted
        on this data) falls back to the bounds pooled over every label for that
        layer. Previously ``torch.stack([])`` raised for such labels.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    sims_per_label = _get_sim_per_labels(data_loader, power)
    if sims_per_label is None:
        raise ValueError("data_loader yielded no batches; cannot compute min/max bounds")

    n_layers = len(sims_per_label[0])
    sims_per_label_min = [[None] * n_layers for _ in sims_per_label]
    sims_per_label_max = [[None] * n_layers for _ in sims_per_label]

    print ("-- Computing the values...")
    for lab_idx, layers in enumerate(sims_per_label):
        for layer_idx, sims in enumerate(layers):
            if not sims:
                # No sample was assigned to this label: use bounds over all labels.
                sims = [s for other in sims_per_label for s in other[layer_idx]]
            temp = torch.stack(sims)
            sims_per_label_min[lab_idx][layer_idx] = temp.min(dim=0)[0]
            sims_per_label_max[lab_idx][layer_idx] = temp.max(dim=0)[0]

    del sims_per_label

    return sims_per_label_min, sims_per_label_max


def _as_power_list(power):
    """Wrap a single Gram power in a list; lists and ranges are returned unchanged."""
    return power if isinstance(power, (list, range)) else [power]


def _batch_dev_scores(maps_list, labels, powers, sims_min, sims_max, ep):
    """Deviation scores of one batch against the per-label training bounds.

    Args:
        maps_list: hooked feature maps of the batch, one tensor per hooked layer.
        labels: 1-D tensor with the predicted label of each sample.
        powers: list/range of Gram powers.
        sims_min, sims_max: ``[label][layer]`` bound vectors (as produced by
            ``get_min_max_per_label``).
        ep: constant added to a bound before taking its absolute value in the
            denominator.

    Returns:
        tensor of shape (batch, n_layers). Each entry sums, over powers and
        features, ``relu(min - x) / |min + ep| + relu(x - max) / |max + ep|``.
    """
    label_ids = labels.tolist()
    batch_scores = []
    for layer, maps in enumerate(maps_list):
        # The bounds depend only on (label, layer), so gather them once per layer
        # rather than once per power, with one stack instead of a per-row copy loop.
        sim_min = torch.stack([sims_min[lab][layer] for lab in label_ids]).to(maps.device)
        sim_max = torch.stack([sims_max[lab][layer] for lab in label_ids]).to(maps.device)

        score_layer = 0
        for p in powers:
            sims = get_sims_gram_matrix (maps, p)
            score_layer += (F.relu(sim_min - sims) / torch.abs(sim_min + ep)).sum(dim=1, keepdim=True)
            score_layer += (F.relu(sims - sim_max) / torch.abs(sim_max + ep)).sum(dim=1, keepdim=True)

        batch_scores.append(score_layer)
    return torch.cat(batch_scores, dim=1)


def get_dev_scores_per_label_and_name(data_loader, power, sims_min, sims_max, ep=10e-6):
    """Deviation scores plus image names for a loader yielding ``(images, _, _, names)``.

    Each sample is compared against the bounds of its predicted label.

    Args:
        data_loader: yields 4-tuples whose first item is the image batch and
            whose last item is the batch's image names.
        power: one Gram power, or a list/range of powers.
        sims_min, sims_max: per-label, per-layer bounds from ``get_min_max_per_label``.
        ep: denominator offset (note: the default 10e-6 equals 1e-5).

    Returns:
        ``(scores, img_names)``: ``scores`` is an ndarray of shape
        (n_samples, n_layers); ``img_names`` is a list holding each batch's
        names, in loader order.
    """
    powers = _as_power_list(power)

    dev_scores = []
    img_names = []
    for img_batch, _, _, img_name in tqdm(data_loader):
        preds_batch, maps_list = get_feat_maps(torch_model, img_batch)
        labels = preds_batch.argmax(dim=1)
        # Move each batch's scores off the GPU so they don't accumulate there.
        dev_scores.append(_batch_dev_scores(maps_list, labels, powers, sims_min, sims_max, ep).cpu())
        img_names.append(img_name)

    return torch.cat(dev_scores).numpy(), img_names


def get_dev_scores_per_label(data_loader, power, sims_min, sims_max, ep=10e-6):
    """Deviation scores for a loader yielding ``(images, labels)`` batches.

    The loader's labels are ignored; each sample is compared against the bounds
    of its predicted label.

    Args:
        data_loader: yields ``(images, labels)`` batches.
        power: one Gram power, or a list/range of powers.
        sims_min, sims_max: per-label, per-layer bounds from ``get_min_max_per_label``.
        ep: denominator offset (note: the default 10e-6 equals 1e-5).

    Returns:
        ndarray of shape (n_samples, n_layers).
    """
    powers = _as_power_list(power)

    dev_scores = []
    for img_batch, _ in tqdm(data_loader):
        preds_batch, maps_list = get_feat_maps(torch_model, img_batch)
        labels = preds_batch.argmax(dim=1)
        dev_scores.append(_batch_dev_scores(maps_list, labels, powers, sims_min, sims_max, ep).cpu())

    return torch.cat(dev_scores).numpy()

In [9]:
def detect_mean(all_test_std, all_ood_std, gaps=None, n_runs=10, val_fraction=0.1):
    """Average OOD-detection metrics over several random validation/test splits.

    For each run ``i`` in 1..n_runs, the global NumPy RNG is seeded with ``i`` and
    the index list is shuffled again; each shuffle continues from the previous
    order, as before. The first ``floor(val_fraction * n)`` rows of
    ``all_test_std`` form the validation set, which defines a per-column
    normaliser ``t95``. Both the remaining in-distribution rows and all OOD rows
    are divided by ``t95`` and summed per sample.

    Args:
        all_test_std: (n_samples, n_columns) in-distribution deviation scores.
        all_ood_std: (n_ood, n_columns) OOD deviation scores.
        gaps: optional array. If given, ``t95 = val.sum(0) + gaps.mean(0)``;
            otherwise ``t95 = val.mean(0) + 1e-7``.
            NOTE(review): one branch sums and the other averages; confirm intended.
        n_runs: number of random splits to average (default 10, as before).
        val_fraction: share of in-distribution rows held out for validation.

    Returns:
        dict mapping each metric name from ``callog.compute_metric`` to its mean
        over the runs (also printed via ``callog.print_results``).

    Raises:
        ValueError: if ``n_runs`` < 1 or the validation split would be empty
            (which previously produced NaN metrics).
    """
    if n_runs < 1:
        raise ValueError("n_runs must be at least 1")
    n_samples = len(all_test_std)
    split = int(np.floor(val_fraction * n_samples))
    if split == 0:
        raise ValueError("validation split is empty; need more in-distribution samples")

    avg_results = dict()
    indices = list(range(n_samples))
    for i in range(1, n_runs + 1):
        np.random.seed(i)
        np.random.shuffle(indices)

        val_std = all_test_std[indices[:split]]
        test_std = all_test_std[indices[split:]]

        if gaps is not None:
            t95 = (val_std.sum(axis=0) + gaps.mean(0))
        else:
            t95 = val_std.mean(axis=0) + 10**-7

        test_std = ((test_std)/t95[np.newaxis,:]).sum(axis=1)
        ood_std = ((all_ood_std)/t95[np.newaxis,:]).sum(axis=1)

        results = callog.compute_metric(-test_std,-ood_std)

        for m in results:
            avg_results[m] = avg_results.get(m,0)+results[m]

    # Divide by the explicit run count rather than the leaked loop variable.
    for m in avg_results:
        avg_results[m] /= n_runs

    callog.print_results(avg_results)

    return avg_results

In [10]:
def detect(all_test_std, all_ood_std):
    """Evaluate OOD detection on one fixed (seed 10) validation/test split.

    The first 10% of the shuffled in-distribution rows form the validation set,
    whose column means (+1e-7) normalise both the remaining in-distribution rows
    and all OOD rows. The normalised rows are summed per sample, and the
    negated sums are scored with ``callog.compute_metric``.

    Returns:
        (results, ood_std): the metric dict (also printed) and the per-sample
        normalised OOD scores.
    """
    n = len(all_test_std)
    order = list(range(n))
    n_val = int(np.floor(0.1 * n))
    np.random.seed(10)
    np.random.shuffle(order)

    val_rows, test_rows = order[:n_val], order[n_val:]
    t95 = all_test_std[val_rows].mean(axis=0) + 10**-7
    normaliser = t95[np.newaxis, :]

    test_std = (all_test_std[test_rows] / normaliser).sum(axis=1)
    ood_std = (all_ood_std / normaliser).sum(axis=1)

    results = callog.compute_metric(-test_std, -ood_std)
    callog.print_results(results)
    return results, ood_std

# OOD detection per label

In [11]:
# gram-ood* uses only first-order Gram features; gram-ood sweeps powers 1..9.
power = 1 if method == 'gram-ood*' else range(1, 10)

print ("- Getting mins/maxs")
mins, maxs = get_min_max_per_label(sk_train, power)

print ("- Getting test stdevs")
sk_test_stdev = get_dev_scores_per_label(sk_test, power, mins, maxs)

- Getting mins/maxs


HBox(children=(FloatProgress(value=0.0, max=1351.0), HTML(value='')))


-- Computing the values...
- Getting the gaps
- Getting test stdevs


HBox(children=(FloatProgress(value=0.0, max=169.0), HTML(value='')))




In [12]:
# Release GPU memory cached by PyTorch before running the final test set.
# Tensors still referenced by notebook variables are not freed by this call.
torch.cuda.empty_cache()

# Testing

In [13]:
print("Final test")
# Deviation scores of the final test images (treated as OOD), plus their names batch by batch.
final_test_stdev, names = get_dev_scores_per_label_and_name(final_test, power, mins, maxs)
# Compare against the in-distribution test deviations: metrics and per-image OOD scores.
final_test_results, final = detect(sk_test_stdev, final_test_stdev)

Final test


HBox(children=(FloatProgress(value=0.0, max=550.0), HTML(value='')))


 TNR    AUROC  DTACC  AUIN   AUOUT 
  8.339 44.685 54.574 18.481 79.352


## Saving deviations and image names

In [14]:
# np.savetxt does not create directories, so make sure results/<method>/ exists
# (otherwise a fresh checkout fails here with FileNotFoundError).
os.makedirs('results/{}'.format(method), exist_ok=True)
# One summed deviation score per image (row sums over the hooked layers).
np.savetxt('results/{}/final_vgg'.format(method), final_test_stdev.sum(axis=1), fmt='%.3f')
np.savetxt('results/{}/test_vgg'.format(method), sk_test_stdev.sum(axis=1), fmt='%.3f')

In [16]:
# Flatten the per-batch name collections into one list, in the same order as the
# rows of final_test_stdev.
clean_names = [s for sub in names for s in sub]
np.savetxt('results/{}/names_vgg'.format(method), clean_names, fmt="%s")