In [2]:
import torch
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import DataLoader
import csv
import random
import numpy as np
import cv2
import mimetypes
import fnmatch
import seaborn as sns
from collections import Counter, defaultdict
import torchvision
from torchvision import transforms
from torch import nn
import torch.optim as optim
import time
from progress.bar import IncrementalBar
from tqdm.notebook import tqdm_notebook
from sklearn.metrics import confusion_matrix
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import accuracy_score
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score
from sklearn.metrics import roc_auc_score
import io
from pytorch_lightning.callbacks import Callback
from datetime import datetime, date, time
from PIL import Image
import itertools 
from sklearn.model_selection import train_test_split
import torchvision.models as models
from ViT.models.modeling import VisionTransformer, CONFIGS
from pytorch_grad_cam_master.pytorch_grad_cam import GradCAM
from pytorch_grad_cam_master.pytorch_grad_cam.utils.image import show_cam_on_image
import csv
import pandas as pd
import sys

sys.path.insert(0, '/home/anna/Desktop/Diploma/Learning/Sources/')

from callbacks_2classes_x10 import plot_confusion_matrix
from torch.nn import functional as F
from callbacks_2classes_x10 import get_true_classes
from callbacks_2classes_x10 import get_predicted_classes
from callbacks_2classes_x10 import get_classes_probs
from callbacks_2classes_x10 import callback
from callbacks_2classes_x10 import plot_to_image
from data_tools import CatsDataset

from vit_rollout import VITAttentionRollout
from tqdm.notebook import tqdm_notebook

from fpdf import FPDF

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

# Binary class index -> human-readable class name.
labels_map = dict(enumerate(("NotCat", "Cat")))

# Preprocessing handed to the dataset: tensor conversion plus resize/crop to
# 224x224.  No normalisation here, so these tensors can also be used as the
# background image of the heat-map overlays.
transform_for_maps = transforms.Compose(
    [transforms.ToTensor(), transforms.Resize(224), transforms.CenterCrop(224)]
)

# Per-channel normalisation applied on top of `transform_for_maps` right before
# a batch is fed to the networks.
transform = transforms.Compose(
    [transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
)

# One image per batch, in file order.
test_dataset = CatsDataset('test_paths.txt', transform=transform_for_maps)
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=1)

Using cuda device


In [3]:
# VIT
# ViT-B/16 is built with its 1000-class head, the head is then replaced by a
# single-logit layer and the saved weights are restored.
config = CONFIGS["ViT-B_16"]
vit = VisionTransformer(config, num_classes=1000, zero_head=False, img_size=224, vis=True)
vit.head = nn.Linear(768, 1)
# map_location=device: without it a checkpoint saved from a GPU run cannot be
# loaded when `device` fell back to 'cpu'.
vit.load_state_dict(
    torch.load("../../Logits/SavedNN/Saved_ViT_B_16_cats/6", map_location=device)
)
vit.to(device)
vit.eval()

# RESNET
# ResNet-18 with its classifier replaced by a single-logit layer.
resnet = models.resnet18(pretrained=False)
resnet.fc = nn.Linear(512, 1)
resnet.load_state_dict(
    torch.load("../../Logits/SavedNN/Saved_ResNet18_cats/11", map_location=device)
)
resnet.to(device)
resnet.eval()
# Last expression is None so the cell does not echo the model repr.
None

In [4]:
# ORIGINAL
def get_original_image(input_tensor):
    """Return the first image of the batch in channels-last layout.

    Args:
        input_tensor: batch of images; element 0 is squeezed and its axes are
            reordered with ``permute(1, 2, 0)``, so it is expected to be
            three-dimensional after the squeeze.

    Returns:
        The first batch element with axes reordered from (0, 1, 2) to (1, 2, 0).
    """
    first_image = input_tensor[0]
    channels_last = first_image.squeeze().permute(1, 2, 0)
    return channels_last

# RESNET
def get_resnet_gradcam_map(input_tensor, cam, pred_label):
    """Run `cam` on the input and overlay the first resulting heat map on it.

    Args:
        input_tensor: tensor passed to `cam`; it is also squeezed and permuted
            to channels-last to serve as the overlay background, so it is
            presumably a single (1, C, H, W) image -- TODO confirm with caller.
        cam: callable accepting `input_tensor` and `target_category` keywords.
        pred_label: forwarded to `cam` as `target_category`.

    Returns:
        Tuple ``(visualization, grayscale_cam)``: the result of
        `show_cam_on_image` and the first heat map returned by `cam`.
    """
    # aug_smooth=True and eigen_smooth=True can also be passed to apply smoothing.
    heatmaps = cam(input_tensor=input_tensor, target_category=pred_label)
    # Only the first heat map of the batch is used.
    grayscale_cam = heatmaps[0, :]

    background = input_tensor.cpu().detach().squeeze().permute(1, 2, 0).numpy()
    overlay = show_cam_on_image(background, grayscale_cam, use_rgb=True)
    return overlay, grayscale_cam

def get_vit_rollout_map(input_tensor, rollout, pred_label):
    """Compute a rollout mask for the input and overlay it on the image.

    Args:
        input_tensor: tensor passed to `rollout`; it is also squeezed and
            permuted to channels-last to serve as the overlay background, so it
            is presumably a single (1, C, H, W) image -- TODO confirm with caller.
        rollout: callable that takes `input_tensor` and returns a 2-D array
            usable by `np.max` and `cv2.resize`.
        pred_label: unused; kept so the signature matches the Grad-CAM helper.

    Returns:
        Tuple ``(result, mask)``: the output of `show_cam_on_image` and the
        max-scaled mask resized to the image's height and width.
    """
    mask = rollout(input_tensor)
    peak = np.max(mask)
    # An all-zero mask is left untouched instead of being turned into NaNs by
    # a division by zero.
    if peak != 0:
        mask = mask / peak

    image = input_tensor.cpu().detach()
    image = image.squeeze().permute(1, 2, 0).numpy()

    # Resize to the image's own size rather than a fixed 224x224 so the overlay
    # also works for other input resolutions.  cv2.resize expects (width, height).
    height, width = image.shape[:2]
    mask = cv2.resize(mask, (width, height))[..., np.newaxis]
    result = show_cam_on_image(image, mask, use_rgb=True)

    return result, mask.squeeze()
    

# Registry: method name -> function that renders it.  All rollout variants
# share one renderer; the rollout object itself is supplied by the caller.
methods_dict = {
    "original": get_original_image,
    "resnet_gradcam": get_resnet_gradcam_map,
}
methods_dict.update(
    dict.fromkeys(
        ("vit_rollout", "vit_rollout_min", "vit_rollout_mean", "vit_rollout_max"),
        get_vit_rollout_map,
    )
)

In [5]:
# ResNet CAM object, attached to the last block of layer4.
target_layers = [resnet.layer4[-1]]
resnet_cam = GradCAM(resnet, target_layers)

# Rollout object; heads are fused with 'mean'.
# Other fusions tried earlier with the same discard_ratio: 'min' and 'max'.
rollout = VITAttentionRollout(vit, head_fusion='mean', discard_ratio=0.9)

In [6]:
# Panels drawn for every image, left to right.
methods = ["original", "resnet_gradcam", "vit_rollout"]

In [8]:
# functions
def get_plot_and_labels(data, methods):
    """Classify one batch with both networks and render the requested panels.

    Relies on notebook-level state: `transform`, `device`, `vit`, `resnet`,
    `labels_map`, `k_dict`, `methods_dict`, `resnet_cam`, `rollout` and the
    axes array `ax`.

    Args:
        data: ``(inputs, labels)`` pair as yielded by the dataloader.
        methods: names of the panels to draw; panel ``i`` goes to ``ax[i]``.
            Supported names: "original", "resnet_gradcam", "vit_rollout".

    Returns:
        ``(img_to_save, real_label, resnet_label, vit_label, siou)`` -- note
        that the ResNet label comes before the ViT label.
        * When the ``k_dict`` bucket for this prediction pair already holds 50
          entries nothing is drawn and the placeholders
          ``(np.int32(0), 0, resnet_label, vit_label, 0)`` are returned.
        * Otherwise `img_to_save` is the rendered figure as an RGB PIL image,
          `real_label` is the integer label of the first batch element, the
          two predicted labels are class names from `labels_map`, and `siou`
          is the overlap of the two heat maps (NaN when `methods` did not
          produce both maps).

    Raises:
        ValueError: if `methods` contains an unsupported name.
    """
    # get the inputs; data is a list of [inputs, labels]
    inputs, labels = data
    inputs_for_networks = transform(inputs).to(device)
    inputs = inputs.to(device)

    real_label = labels[0].item()

    with torch.no_grad():
        # The ViT forward result is indexed with [0] to get the logits.
        vit_logit = vit(inputs_for_networks)[0]
        resnet_logit = resnet(inputs_for_networks)
    del inputs_for_networks

    resnet_label = _logit_to_label(resnet_logit)
    vit_label = _logit_to_label(vit_logit)
    resnet_name = labels_map[resnet_label]
    vit_name = labels_map[vit_label]

    # Bucket already full: skip the expensive map computation entirely.
    if k_dict[vit_name][resnet_name] >= 50:
        return np.int32(0), 0, resnet_name, vit_name, 0

    # NOTE(review): the maps below are computed on the un-normalised `inputs`,
    # while the predictions above used the normalised tensor -- confirm this
    # is intended.
    map1 = map2 = None
    for method_i, method in enumerate(methods):
        if method == "original":
            visualization = methods_dict[method](inputs.cpu())
        elif method == "resnet_gradcam":
            visualization, map1 = methods_dict[method](inputs,
                                                       resnet_cam,
                                                       resnet_label)
        elif method == "vit_rollout":
            visualization, map2 = methods_dict[method](inputs,
                                                       rollout,
                                                       vit_label)
        else:
            raise ValueError("Unsupported method: {!r}".format(method))

        panel = ax[method_i]
        # Clear every panel before drawing; otherwise each call stacks another
        # image on the shared axes and memory use grows with every sample.
        panel.clear()
        panel.get_xaxis().set_visible(False)
        panel.get_yaxis().set_visible(False)
        panel.set_title(method)
        panel.imshow(visualization)

    plt.subplots_adjust(wspace=0.05, hspace=0.05)
    plt.savefig("test.jpg", bbox_inches='tight', pad_inches=0)
    # Read the rendered figure back and release the file handle right away.
    with Image.open("test.jpg") as saved_plot:
        img_to_save = saved_plot.convert("RGB")

    # The overlap needs both heat maps; report NaN when one was not requested.
    if map1 is None or map2 is None:
        siou = float("nan")
    else:
        siou = siou_for_2_maps(map1, map2)

    del inputs
    torch.cuda.empty_cache()

    return img_to_save, real_label, resnet_name, vit_name, siou


def _logit_to_label(logit):
    """Map a single raw logit to 0/1 by thresholding its sigmoid at 0.5."""
    prob = float(torch.sigmoid(logit).cpu().detach())
    return 1 if prob >= 0.5 else 0

def save_maps_to_folders(img_to_save, 
                                 real_label, vit_label, resnet_label, 
                                 k_dict, siou):
    """Write one rendered figure to its prediction-pair folder and log it.

    Relies on the notebook-level `rows` list and `labels_map` dict.

    Args:
        img_to_save: RGB image (anything `np.array` can convert).
        real_label: integer ground-truth label; used in the file name and
            translated through `labels_map` for the CSV row.
        vit_label: class name predicted by ViT ("Cat" / "NotCat").
        resnet_label: class name predicted by ResNet ("Cat" / "NotCat").
        k_dict: nested counters ``k_dict[vit_label][resnet_label]``; the
            matching counter is incremented after a successful save.
        siou: overlap score stored in the CSV row.

    Returns:
        1 if the bucket already holds 50 images (nothing is written),
        0 after the image was saved and a row appended to `rows`.

    Raises:
        OSError: if OpenCV reports that the image could not be written.
    """
    # Bucket full: nothing to save.
    if k_dict[vit_label][resnet_label] >= 50:
        return 1

    # choose the folder (anything other than "Cat" goes to the NotCat folder)
    vit_part = "Cat" if vit_label == "Cat" else "NotCat"
    resnet_part = "Cat" if resnet_label == "Cat" else "NotCat"
    folder = "Maps_for_comparing/ViT_{}_ResNet_{}".format(vit_part, resnet_part)
    # cv2.imwrite does not create missing directories, so make sure it exists.
    os.makedirs(folder, exist_ok=True)

    # File name: <index within the bucket>_<integer real label>.jpg
    # (plain str() -- the np.str alias no longer exists in current NumPy).
    img_path = os.path.join(folder,
                            str(k_dict[vit_label][resnet_label]) +
                            '_' + str(real_label) + '.jpg')
    written = cv2.imwrite(img_path,
                          cv2.cvtColor(np.array(img_to_save), cv2.COLOR_RGB2BGR))
    if not written:
        # Do not count or log an image that never reached the disk.
        raise OSError("Could not write map image to {}".format(img_path))

    k_dict[vit_label][resnet_label] += 1

    # dict with info for csv
    rows.append({
        'path': img_path,
        'real_label': labels_map[int(real_label)],
        'vit_label': vit_label,
        'resnet_label': resnet_label,
        'sIoU': siou
    })

    return 0
        
def siou_for_2_maps(map1, map2):
    """Overlap score of two maps: ``2 * sum(min(a, b)) / sum(a + b)``.

    Args:
        map1: first array-like map.
        map2: second array-like map (combined element-wise with `map1`).

    Returns:
        The overlap score; 0.0 when the two maps sum to zero, where the ratio
        would otherwise be 0/0.
    """
    overlap = np.sum(np.minimum(map1, map2))
    total = np.sum(np.add(map1, map2))
    # Two empty maps have no overlap to measure; avoid the 0/0 division.
    if total == 0:
        return 0.0
    return 2 * overlap / total

In [10]:
# Number of maps saved so far per prediction pair:
# outer key = ViT prediction, inner key = ResNet prediction.
k_dict = {
    "Cat": {
        "Cat": 0,
        "NotCat": 0,
    },
    "NotCat": {
        "Cat": 0,
        "NotCat": 0,
    },
}

# Same nesting as k_dict; not filled anywhere in this section.
real_labels = {
    "Cat": {
        "Cat": [],
        "NotCat": [],
    },
    "NotCat": {
        "Cat": [],
        "NotCat": [],
    },
}

# One dict per saved image; written to CSV afterwards.
rows = []

plt.rcParams['font.size'] = '20'
# One row of three panels, reused for every image.
fig, ax = plt.subplots(1, 3, figsize=(24, 8))
fig.tight_layout()

try:
    for data in tqdm_notebook(test_dataloader, desc='gettnig val maps'):
        img_to_save, real_label, resnet_label, vit_label, siou = get_plot_and_labels(data, methods)

        save_maps_to_folders(img_to_save,
                             real_label, vit_label, resnet_label,
                             k_dict, siou)

        # Stop once 200 maps have been saved across all four buckets.
        if sum(k_dict["Cat"].values()) + sum(k_dict["NotCat"].values()) >= 200:
            break
finally:
    # Release the figure even if a sample fails part-way through.
    plt.close(fig)

gettnig val maps:   0%|          | 0/2328 [00:00<?, ?it/s]

In [11]:
# CSV column order; the names match the keys of the dicts collected in `rows`.
fieldnames = ['path', 'real_label', 'vit_label', 'resnet_label', 'sIoU']

# newline='' is passed so the csv module controls the line endings itself.
with open('maps_info.csv', mode='w', encoding='UTF8', newline='') as f:
    writer = csv.DictWriter(f, fieldnames)
    writer.writeheader()
    writer.writerows(rows)

In [3]:
# Load the per-image records from maps_info.csv for the analysis below.
# NOTE(review): this rebinds `data`, which the map-generation loop also uses
# as its loop variable.
data = pd.read_csv('maps_info.csv')

In [4]:
# Preview the first 15 records.
data.head(15)

Unnamed: 0,path,real_label,vit_label,resnet_label,sIoU
0,Maps_for_comparing/ViT_Cat_ResNet_NotCat/0_1.jpg,Cat,Cat,NotCat,0.374538
1,Maps_for_comparing/ViT_Cat_ResNet_Cat/0_1.jpg,Cat,Cat,Cat,0.267356
2,Maps_for_comparing/ViT_NotCat_ResNet_NotCat/0_...,NotCat,NotCat,NotCat,0.401296
3,Maps_for_comparing/ViT_Cat_ResNet_Cat/1_1.jpg,Cat,Cat,Cat,0.291218
4,Maps_for_comparing/ViT_Cat_ResNet_Cat/2_1.jpg,Cat,Cat,Cat,0.34394
5,Maps_for_comparing/ViT_Cat_ResNet_Cat/3_1.jpg,Cat,Cat,Cat,0.173019
6,Maps_for_comparing/ViT_Cat_ResNet_Cat/4_1.jpg,Cat,Cat,Cat,0.370994
7,Maps_for_comparing/ViT_NotCat_ResNet_NotCat/1_...,NotCat,NotCat,NotCat,0.328667
8,Maps_for_comparing/ViT_NotCat_ResNet_NotCat/2_...,NotCat,NotCat,NotCat,0.270953
9,Maps_for_comparing/ViT_Cat_ResNet_Cat/5_1.jpg,Cat,Cat,Cat,0.275878


In [9]:
# Mean of the sIoU column over all loaded records.
data["sIoU"].mean()

0.3328564053436665

In [10]:
# Single-column frame holding only the sIoU scores; show the first 15 of them.
siou = pd.DataFrame(data["sIoU"])
siou.head(15)

Unnamed: 0,sIoU
0,0.374538
1,0.267356
2,0.401296
3,0.291218
4,0.34394
5,0.173019
6,0.370994
7,0.328667
8,0.270953
9,0.275878


In [6]:
# Display the whole object bound to `data`.
# NOTE(review): the recorded output is a Series named sIoU, while `data` is
# read as a DataFrame earlier in the notebook; the execution counts are also
# out of order, so this cell probably ran against a different kernel state --
# confirm and re-run top to bottom.
data

0      0.374538
1      0.267356
2      0.401296
3      0.291218
4      0.343940
         ...   
195    0.441338
196    0.175820
197    0.338234
198    0.185934
199    0.363649
Name: sIoU, Length: 200, dtype: float64

In [13]:
# Largest sIoU among the loaded records.
data["sIoU"].max()

0.541433500667023

In [17]:
# sIoU of one specific saved figure, looked up by its path.
print(data[data["path"] == 'Maps_for_comparing/ViT_Cat_ResNet_Cat/8_1.jpg']["sIoU"])

19    0.355805
Name: sIoU, dtype: float64
