In [1]:
import torch
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import DataLoader
import csv
import random
import numpy as np
import cv2
import mimetypes
import fnmatch
import seaborn as sns
from collections import Counter, defaultdict
import torchvision
from torchvision import transforms
from torch import nn
import torch.optim as optim
import time
from progress.bar import IncrementalBar
from tqdm.notebook import tqdm_notebook
from sklearn.metrics import confusion_matrix
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import accuracy_score
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score
from sklearn.metrics import roc_auc_score
import io
from pytorch_lightning.callbacks import Callback
from datetime import datetime, date, time
from PIL import Image
import itertools 
from sklearn.model_selection import train_test_split
import torchvision.models as models
from ViT.models.modeling import VisionTransformer, CONFIGS
from pytorch_grad_cam_master.pytorch_grad_cam import GradCAM
from pytorch_grad_cam_master.pytorch_grad_cam.utils.image import show_cam_on_image
import csv
import pandas as pd
import sys

sys.path.insert(0, '/home/anna/Desktop/Diploma/Learning/Sources/')

from callbacks_2classes_x10 import plot_confusion_matrix
from torch.nn import functional as F
from callbacks_2classes_x10 import get_true_classes
from callbacks_2classes_x10 import get_predicted_classes
from callbacks_2classes_x10 import get_classes_probs
from callbacks_2classes_x10 import callback
from callbacks_2classes_x10 import plot_to_image
from data_tools import CatsDataset

from vit_rollout import VITAttentionRollout
from tqdm.notebook import tqdm_notebook

from fpdf import FPDF

In [2]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using {device} device')

# Class index -> human-readable name for the binary classifier.
labels_map = dict(enumerate(["NotCat", "Cat"]))

# Preprocessing for the images that get explained and displayed: tensor in
# [0, 1], shorter side resized to 224 px, then a 224x224 center crop.
# No normalization here, so the tensor can be shown directly.
transform_for_maps = transforms.Compose(
    [transforms.ToTensor(), transforms.Resize(224), transforms.CenterCrop(224)]
)

# Normalization applied on top of `transform_for_maps` before a tensor reaches
# a network (these values match the usual ImageNet channel mean/std).
transform = transforms.Compose(
    [transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
)

# Test images are processed one at a time, in file order.
test_dataset = CatsDataset('test_paths.txt', transform=transform_for_maps)
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=1)

Using cuda device


In [3]:
# Fine-tuned checkpoints for the two binary (NotCat / Cat) classifiers.
# The file names (6 and 11) presumably identify the selected training epoch — confirm.
VIT_CHECKPOINT = "../../Logits/SavedNN/Saved_ViT_B_16_cats/" + str(6)
RESNET_CHECKPOINT = "../../Logits/SavedNN/Saved_ResNet18_cats/" + str(11)

# VIT: ViT-B/16 built with its 1000-class head, which is then replaced by a
# single-logit head so the fine-tuned state dict matches. vis=True presumably
# makes the model also return attention weights for the rollout — confirm in
# ViT.models.modeling.
config = CONFIGS["ViT-B_16"]
vit = VisionTransformer(config, num_classes=1000, zero_head=False, img_size=224, vis=True)
vit.head = nn.Linear(768, 1)
# map_location remaps tensors that were saved from a GPU run, so the notebook
# also works when `device` fell back to 'cpu'.
vit.load_state_dict(torch.load(VIT_CHECKPOINT, map_location=device))
vit.to(device)
vit.eval()

# RESNET: ResNet-18 with its fc layer (512 input features) replaced by a single logit.
resnet = models.resnet18(pretrained=False)
resnet.fc = nn.Linear(512, 1)
resnet.load_state_dict(torch.load(RESNET_CHECKPOINT, map_location=device))
resnet.to(device)
resnet.eval();

In [4]:
# ORIGINAL
def get_original_image(input_tensor):
    """Return the first image of a batch laid out channel-last for ``imshow``.

    Takes element 0 of ``input_tensor``, drops any size-1 dimensions, and moves
    the leading (channel) axis to the end: (C, H, W) -> (H, W, C).
    """
    first_image = input_tensor[0].squeeze()
    return first_image.permute(1, 2, 0)

# RESNET
def get_resnet_gradcam_map(input_tensor, cam, pred_label, model_transform=None):
    """Overlay a Grad-CAM heat map for ``pred_label`` on the input image.

    Args:
        input_tensor: batch holding one un-normalized image with values in
            [0, 1]. It is the background of the overlay, which
            ``show_cam_on_image`` requires to be in that range. Assumed layout
            (1, C, H, W) — TODO confirm.
        cam: ``GradCAM`` object wrapping the classifier.
        pred_label: class forwarded to ``cam`` as ``target_category``.
        model_transform: preprocessing applied before the tensor reaches the
            network. Defaults to the notebook-level ``transform``, i.e. the
            same normalization that is used when predicting.

    Returns:
        The RGB overlay produced by ``show_cam_on_image(..., use_rgb=True)``.
    """
    if model_transform is None:
        model_transform = transform
    # The CAM must come from the same (normalized) input the network classifies;
    # the raw [0, 1] tensor is only used as the picture under the heat map.
    model_input = model_transform(input_tensor)
    # You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.
    grayscale_cam = cam(input_tensor=model_input, target_category=pred_label)
    # Only one image in the batch.
    grayscale_cam = grayscale_cam[0, :]
    image = input_tensor.detach().cpu().squeeze().permute(1, 2, 0).numpy()
    return show_cam_on_image(image, grayscale_cam, use_rgb=True)

def get_vit_rollout_map(input_tensor, rollout, pred_label, model_transform=None):
    """Overlay a ViT attention-rollout map on the input image.

    Args:
        input_tensor: batch holding one un-normalized image with values in
            [0, 1], used as the overlay background. Assumed layout
            (1, C, H, W) — TODO confirm.
        rollout: ``VITAttentionRollout`` object. It is called on the model
            input and presumably returns a 2-D numpy mask — confirm in
            vit_rollout.
        pred_label: unused (``rollout`` takes no target class); kept so the
            call matches ``get_resnet_gradcam_map``.
        model_transform: preprocessing applied before the tensor reaches the
            network. Defaults to the notebook-level ``transform`` used when
            predicting.

    Returns:
        The RGB overlay produced by ``show_cam_on_image(..., use_rgb=True)``.
    """
    if model_transform is None:
        model_transform = transform
    # Roll out attention for the same (normalized) input the ViT classifies.
    mask = rollout(model_transform(input_tensor))
    peak = np.max(mask)
    # Scale to a peak of 1; an all-zero mask is kept as-is instead of dividing by zero.
    if peak != 0:
        mask = mask / peak

    image = input_tensor.detach().cpu().squeeze().permute(1, 2, 0).numpy()
    height, width = image.shape[:2]
    # cv2.resize takes (width, height); stretch the mask to the image size.
    mask = cv2.resize(mask, (width, height))[..., np.newaxis]
    return show_cam_on_image(image, mask, use_rgb=True)
    

# Map name -> function producing that map.
# All "vit_rollout*" keys point to the same function: the head-fusion mode is
# fixed by the rollout object passed in, not by the key.
methods_dict = {
    "original": get_original_image,
    "resnet_gradcam": get_resnet_gradcam_map,
}
methods_dict.update(
    dict.fromkeys(
        ("vit_rollout", "vit_rollout_min", "vit_rollout_mean", "vit_rollout_max"),
        get_vit_rollout_map,
    )
)

In [5]:
# Grad-CAM for the ResNet, hooked on the last block of layer4.
target_layers = [resnet.layer4[-1]]
resnet_cam = GradCAM(resnet, target_layers)

# Attention rollout for the ViT. head_fusion selects how attention heads are
# combined; 'min' and 'max' variants can be built the same way if needed.
# discard_ratio=0.9 is passed through as-is (see vit_rollout for its meaning).
rollout = VITAttentionRollout(vit, head_fusion='mean', discard_ratio=0.9)

In [6]:
# Names of the map types to produce (keys of `methods_dict`).
methods = ["original", "resnet_gradcam", "vit_rollout"]

In [19]:
# functions
def _logit_to_label(logit):
    """Threshold one binary-classifier logit: 1 if sigmoid(logit) >= 0.5, else 0."""
    prob = float(torch.sigmoid(logit).detach().cpu())
    return 1 if prob >= 0.5 else 0


def _render_comparison(orig, overlay, overlay_title):
    """Draw ``orig`` and ``overlay`` side by side; return the figure as a PIL RGB image."""
    with plt.rc_context({'font.size': 20}):
        fig, axes = plt.subplots(1, 2, figsize=(16, 8))
        fig.tight_layout()
        for axis, title, picture in zip(axes, ("original", overlay_title), (orig, overlay)):
            axis.get_xaxis().set_visible(False)
            axis.get_yaxis().set_visible(False)
            axis.set_title(title)
            axis.imshow(picture)
        fig.subplots_adjust(wspace=0.05, hspace=0.05)
        # Encode to JPEG in memory instead of round-tripping through a temp file.
        buffer = io.BytesIO()
        fig.savefig(buffer, format='jpg', bbox_inches='tight', pad_inches=0)
    plt.close(fig)
    buffer.seek(0)
    with Image.open(buffer) as rendered:
        return rendered.convert("RGB")


def get_plot_and_labels(data, methods, resnet_counts=None, vit_counts=None,
                        max_per_class=50):
    """Classify one test image with both networks and render their explanation maps.

    Args:
        data: one DataLoader batch ``(inputs, labels)`` holding a single image;
            ``inputs`` is presumably produced by ``transform_for_maps`` (values
            in [0, 1]).
        methods: unused; kept for call compatibility.
        resnet_counts, vit_counts: dicts class name -> number of maps already
            saved for that model. They default to the notebook-level
            ``k_resnet`` / ``k_vit``.
        max_per_class: bucket capacity. When both predicted buckets are already
            full, no maps are rendered.

    Returns:
        ``(resnet_img, vit_img, real_label, resnet_label, vit_label)``:
        side-by-side PIL RGB figures (``None`` when rendering was skipped), the
        ground-truth class index, and both predicted class names.
    """
    if resnet_counts is None:
        resnet_counts = k_resnet
    if vit_counts is None:
        vit_counts = k_vit

    inputs, labels = data
    # Networks receive the normalized tensor; `inputs` stays in [0, 1] for display.
    inputs_for_networks = transform(inputs).to(device)
    inputs = inputs.to(device)
    real_label = labels[0].item()

    with torch.no_grad():
        # The ViT output is indexed with [0]; presumably (logits, attention) — confirm.
        vit_label = _logit_to_label(vit(inputs_for_networks)[0])
        resnet_label = _logit_to_label(resnet(inputs_for_networks))
    del inputs_for_networks

    resnet_name = labels_map[int(resnet_label)]
    vit_name = labels_map[int(vit_label)]

    # Both buckets these predictions would go to are full: skip the costly rendering.
    if vit_counts[vit_name] >= max_per_class and resnet_counts[resnet_name] >= max_per_class:
        return None, None, real_label, resnet_name, vit_name

    orig = methods_dict["original"](inputs.cpu())
    resnet_map = methods_dict["resnet_gradcam"](inputs, resnet_cam, resnet_label)
    vit_map = methods_dict["vit_rollout"](inputs, rollout, vit_label)

    resnet_img = _render_comparison(orig, resnet_map, "resnet_gradcam")
    vit_img = _render_comparison(orig, vit_map, "vit_rollout")

    del inputs
    torch.cuda.empty_cache()

    return resnet_img, vit_img, real_label, resnet_name, vit_name

def save_maps_to_folders(resnet_img, vit_img, real_label, vit_label,
                         resnet_label, k_resnet, k_vit, max_per_class=50,
                         records=None, root="Maps_for_tasks"):
    """Save rendered maps into ``<root>/<ResNet|ViT>/<Cat|NotCat>/<n>.jpg``.

    An image is written only while its (model, predicted class) bucket holds
    fewer than ``max_per_class`` images. Files are numbered by the bucket's
    current count. For each image written, a row
    ``{'path', 'true_label', 'pred_label'}`` is appended to ``records``.

    Args:
        resnet_img, vit_img: images convertible with ``np.array`` to RGB arrays.
        real_label: ground-truth class index, or its name from ``labels_map``.
        vit_label, resnet_label: predicted class names; anything other than
            "Cat" goes to the "NotCat" folder.
        k_resnet, k_vit: dicts class name -> images saved so far; updated in place.
        max_per_class: bucket capacity.
        records: list receiving the rows; defaults to the notebook-level ``rows``.
        root: output root directory.

    Returns:
        0, as before.

    Raises:
        OSError: if OpenCV fails to write an image file.
    """
    if records is None:
        records = rows
    # Normalise the ground truth to its class name once, whatever form it arrives in.
    if not isinstance(real_label, str):
        real_label = labels_map[int(real_label)]

    # ResNet first, then ViT, so rows keep their original order.
    for model_name, image, pred_label, counts in (
        ("ResNet", resnet_img, resnet_label, k_resnet),
        ("ViT", vit_img, vit_label, k_vit),
    ):
        if counts[pred_label] >= max_per_class:
            continue
        class_folder = "Cat" if pred_label == "Cat" else "NotCat"
        folder = f"{root}/{model_name}/{class_folder}"
        # cv2.imwrite does not create directories; it just returns False.
        os.makedirs(folder, exist_ok=True)
        img_path = os.path.join(folder, str(counts[pred_label]) + '.jpg')
        # cv2 expects BGR channel order.
        if not cv2.imwrite(img_path, cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)):
            raise OSError(f"could not write {img_path}")
        counts[pred_label] += 1
        records.append({
            'path': img_path,
            'true_label': real_label,
            'pred_label': pred_label,
        })

    return 0

In [20]:
# Render and save explanation maps for test images until every
# (model, predicted class) bucket holds MAPS_PER_BUCKET images.
# TODO: maybe the counts should be computed differently.
MAPS_PER_BUCKET = 50  # same per-class limit used when rendering and saving maps

k_resnet = {"Cat": 0, "NotCat": 0}
k_vit = {"Cat": 0, "NotCat": 0}
rows = []

total_needed = MAPS_PER_BUCKET * (len(k_resnet) + len(k_vit))
for batch in tqdm_notebook(test_dataloader, desc='getting val maps'):
    resnet_img, vit_img, real_label, resnet_label, vit_label = get_plot_and_labels(batch, methods)
    save_maps_to_folders(resnet_img, vit_img, real_label, vit_label,
                         resnet_label, k_resnet, k_vit)
    if sum(k_resnet.values()) + sum(k_vit.values()) >= total_needed:
        break

gettnig val maps:   0%|          | 0/2328 [00:00<?, ?it/s]

In [21]:
# One CSV row per saved map image.
fieldnames = ['path', 'true_label', 'pred_label']

with open('maps_for_tasks_info.csv', mode='w', newline='', encoding='UTF8') as csv_file:
    writer = csv.DictWriter(csv_file, fieldnames)
    writer.writeheader()
    for row in rows:
        writer.writerow(row)

In [22]:
data = pd.read_csv("maps_for_tasks_info.csv", sep=",", encoding="utf-8")  # reload the index written above

In [23]:
data.iloc[:15]  # first 15 rows of the saved-maps index

Unnamed: 0,path,true_label,pred_label
0,Maps_for_tasks/ResNet/NotCat/0.jpg,Cat,NotCat
1,Maps_for_tasks/ViT/Cat/0.jpg,Cat,Cat
2,Maps_for_tasks/ResNet/Cat/0.jpg,Cat,Cat
3,Maps_for_tasks/ViT/Cat/1.jpg,Cat,Cat
4,Maps_for_tasks/ResNet/NotCat/1.jpg,NotCat,NotCat
5,Maps_for_tasks/ViT/NotCat/0.jpg,NotCat,NotCat
6,Maps_for_tasks/ResNet/Cat/1.jpg,Cat,Cat
7,Maps_for_tasks/ViT/Cat/2.jpg,Cat,Cat
8,Maps_for_tasks/ResNet/Cat/2.jpg,Cat,Cat
9,Maps_for_tasks/ViT/Cat/3.jpg,Cat,Cat
