In [36]:
import numpy as np
import torch
from torch import Tensor
from torch import nn
from torch.functional import F
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import os
from typing import Dict, Tuple, Union, NewType, List, Optional, Any
from pathlib import Path, WindowsPath
import warnings
import pickle
warnings.filterwarnings("ignore")
import seaborn as sns

def load_obj(path: str) -> Any:
    """Unpickle and return the object stored in the file at ``path``.

    WARNING: unpickling can execute arbitrary code, so only point this at
    files you produced yourself or otherwise trust.
    """
    with open(path, mode="rb") as handle:
        loaded = pickle.load(handle)
    return loaded
    

def show_cam_on_image(img, mask):
    """Blend a jet-coloured heatmap of ``mask`` onto ``img``.

    Args:
        img: float image of shape (H, W, C), presumably already scaled to [0, 1].
        mask: float map of shape (H, W), presumably in [0, 1]; it is scaled to
            0..255 and converted to uint8 before colouring.

    Returns:
        float32 array: ``heatmap / 255 + img``, divided by its global maximum.

    Note: ``cv2.applyColorMap`` returns channels in OpenCV's BGR order, so the
    blend is only colour-consistent when ``img`` is BGR as well.
    """
    colored = cv2.applyColorMap(np.uint8(mask * 255), cv2.COLORMAP_JET)
    blended = np.float32(colored) / 255 + np.float32(img)
    return blended / blended.max()

def _min_max_normalize(array: np.ndarray) -> np.ndarray:
    """Linearly rescale ``array`` to [0, 1]; a constant array becomes all zeros."""
    low = array.min()
    span = array.max() - low
    if span == 0:
        # Avoid 0/0 -> NaN, e.g. when the attribution map is constant.
        return np.zeros_like(array)
    return (array - low) / span


def visu(original_image, transformer_attribution, file_name: str,
         out_dir: Union[str, os.PathLike] = 'plots'):
    """Overlay a patch-level attribution map on an image and save it as a PNG.

    The attribution is laid out as a square patch grid, bilinearly upsampled to
    the image's spatial size, min-max normalised to [0, 1], rendered with a jet
    colormap (via ``show_cam_on_image``) and blended with the min-max
    normalised image.

    Args:
        original_image: image tensor laid out as (C, H, W) (it is permuted with
            ``permute(1, 2, 0)``); channels presumably in RGB order, as produced
            by PIL + ``image_transformations`` — the colour handling assumes it.
        transformer_attribution: tensor with one score per patch; its number of
            elements must be a perfect square (e.g. 196 for a 14x14 grid),
            otherwise the reshape raises.
        file_name: stem of the output file, written to ``<out_dir>/<file_name>.png``.
        out_dir: directory for the PNG (created if missing). Defaults to ``plots``.
    """
    num_patches = transformer_attribution.numel()
    side = int(round(num_patches ** 0.5))
    height, width = original_image.shape[-2:]

    # (1, 1, side, side): interpolate expects a batch of single-channel images.
    attribution = transformer_attribution.detach().reshape(1, 1, side, side)
    # align_corners=False is PyTorch's default for bilinear; passing it
    # explicitly silences the "See the documentation of nn.Upsample" warning.
    attribution = torch.nn.functional.interpolate(
        attribution, size=(height, width), mode='bilinear', align_corners=False)
    attribution = _min_max_normalize(attribution.reshape(height, width).cpu().numpy())

    image = _min_max_normalize(original_image.detach().permute(1, 2, 0).cpu().numpy())
    # show_cam_on_image colours the mask with cv2.applyColorMap, which yields BGR.
    # Feed it a BGR image too so the blend is channel-consistent, then convert
    # the result to RGB, the order plt.imsave expects. Mixing the RGB image with
    # the BGR heatmap would swap red and blue of the photo in the saved figure.
    image_bgr = np.ascontiguousarray(image[..., ::-1])
    vis = show_cam_on_image(image_bgr, attribution)
    vis = cv2.cvtColor(np.uint8(255 * vis), cv2.COLOR_BGR2RGB)

    out_path = Path(out_dir, f'{file_name}.png')
    out_path.parent.mkdir(parents=True, exist_ok=True)
    plt.imsave(fname=out_path, arr=vis, format='png')
    
from torchvision import transforms

# Preprocessing shared by the images in this notebook: resize to 224x224,
# convert to a (C, H, W) float tensor (ToTensor scales pixels to [0, 1]), then
# normalise every channel with mean 0.5 / std 0.5, which maps values to [-1, 1].
IMAGENET_STANDARD_MEAN = [0.5] * 3
IMAGENET_STANDARD_STD = [0.5] * 3
normalize = transforms.Normalize(IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD)
image_transformations = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        normalize,
    ]
)

In [3]:
# Load the token mask saved by an optimisation run and overlay its per-patch
# median (over dim 0) on that run's input image.
# NOTE(review): the mask comes from the `verify_vis_...+l1_0+...+entropy_loss_1000...`
# run while the image is read from the `vis_mul_...+l1_1000+...+entropy_loss_100...`
# run; presumably both hold the same input image 0000000001 — confirm.
tokens_mask_path = r"C:\Users\asher\OneDrive\Documents\Data Science Degree\Tesis\Explainability NLP\explainablity-transformer\research\plots\verify_vis_mul_temp_softmax_lr0_003+l1_0+kl_loss_0+entropy_loss_1000+pred_loss_10\0000000001\objects\tokens_mask.pkl"
# [-1]: last entry of the pickled object (presumably the final saved step — TODO confirm).
mask = load_obj(tokens_mask_path)[-1]
# convert("RGB") guarantees 3 channels; a grayscale/CMYK/RGBA file would yield a
# 1- or 4-channel tensor that the 3-value Normalize cannot handle.
image = Image.open(r"C:\Users\asher\OneDrive\Documents\Data Science Degree\Tesis\Explainability NLP\explainablity-transformer\research\plots\vis_mul_temp_softmax_lr0_003+l1_1000+kl_loss_0+entropy_loss_100+pred_loss_10\0000000001\224x224.JPEG").convert("RGB")
image = image_transformations(image)
visu(image, mask.median(dim=0)[0], 'test')

  "See the documentation of nn.Upsample for details.".format(mode)


In [4]:
# Load the Bullmastiff photo and run it through the shared preprocessing.
# convert("RGB") guarantees 3 channels; a grayscale/CMYK/RGBA file would yield a
# 1- or 4-channel tensor that the 3-value Normalize cannot handle.
image_bull_PIL = Image.open(r"C:\Users\asher\OneDrive\Documents\Data Science Degree\Tesis\Explainability NLP\explainablity-transformer\research\notebooks\Bullmastiff-standing-in-a-field.jpg").convert("RGB")
image_bull = image_transformations(image_bull_PIL)

In [5]:
# Preview the preprocessed tensor: ToTensor + Normalize(mean=0.5, std=0.5) puts values in [-1, 1].
image_bull

tensor([[[ 0.7961,  0.7961,  0.7961,  ...,  0.7569,  0.7569,  0.7412],
         [ 0.7961,  0.7961,  0.7961,  ...,  0.7725,  0.7725,  0.7569],
         [ 0.7961,  0.7961,  0.7961,  ...,  0.7725,  0.7804,  0.7569],
         ...,
         [ 0.0353,  0.0745,  0.1216,  ...,  0.3412,  0.2784,  0.1529],
         [ 0.0745, -0.0118,  0.1529,  ...,  0.3098,  0.2471,  0.1608],
         [ 0.1294,  0.0039,  0.1686,  ...,  0.3098,  0.2314,  0.1843]],

        [[ 0.7882,  0.7882,  0.7882,  ...,  0.7647,  0.7647,  0.7490],
         [ 0.7882,  0.7882,  0.7882,  ...,  0.7804,  0.7804,  0.7647],
         [ 0.7882,  0.7882,  0.7882,  ...,  0.7804,  0.7882,  0.7647],
         ...,
         [ 0.0824,  0.1686,  0.1686,  ...,  0.2157,  0.2000,  0.1373],
         [ 0.1216,  0.0824,  0.1922,  ...,  0.1843,  0.1922,  0.1686],
         [ 0.1765,  0.0980,  0.2078,  ...,  0.1922,  0.1843,  0.2000]],

        [[ 0.8353,  0.8353,  0.8353,  ...,  0.8039,  0.8039,  0.7882],
         [ 0.8353,  0.8353,  0.8353,  ...,  0

In [6]:
# Pickled attention probabilities for the Bullmastiff image; the file name suggests
# CLS-token attention per head (shape (12, 196), see the next cell) — confirm.
# load_obj unpickles, so only use it on trusted files.
bull_cls_path = r"C:\Users\asher\OneDrive\Documents\Data Science Degree\Tesis\Explainability NLP\explainablity-transformer\research\notebooks\bull_cls_attentions_probs.pkl"
bull_cls = load_obj(bull_cls_path)

In [7]:
# Dimensions of the loaded attention tensor.
bull_cls.size()

torch.Size([12, 196])

In [8]:
# Dimensions of the token mask; must match bull_cls for element-wise comparison.
mask.size()

torch.Size([12, 196])

In [9]:
# The element-wise product keeps the shared shape (no broadcasting expansion).
torch.mul(mask, bull_cls).size()

torch.Size([12, 196])

In [10]:
# Per-column median of the mask over dim 0. torch.median returns the lower of the
# two middle values for an even count, not their mean.
mask.median(dim=0).values

tensor([5.1167e-06, 1.3258e-03, 1.5191e-03, 2.2338e-03, 2.5690e-03, 1.4327e-04,
        4.4748e-03, 1.8667e-03, 3.7485e-03, 4.0711e-03, 4.1734e-03, 1.3381e-04,
        4.2138e-06, 2.2356e-04, 2.7405e-05, 1.0963e-02, 1.9028e-03, 9.1886e-04,
        2.6397e-03, 2.8667e-04, 3.4716e-03, 6.6686e-03, 3.5860e-03, 8.5304e-03,
        4.1299e-03, 5.8286e-04, 2.2540e-05, 3.8895e-05, 9.6280e-03, 1.2460e-02,
        1.0731e-02, 1.2023e-04, 4.6251e-05, 9.1607e-06, 6.6617e-05, 1.2290e-02,
        4.5248e-03, 7.2795e-03, 4.4460e-03, 8.3014e-03, 1.4466e-03, 6.7513e-05,
        2.0156e-03, 4.5425e-06, 1.2204e-02, 5.3833e-05, 4.0476e-04, 5.0963e-05,
        7.6130e-03, 9.3336e-03, 2.8558e-03, 5.3820e-03, 6.4297e-03, 5.9338e-03,
        3.2779e-03, 3.9443e-05, 1.1195e-03, 1.6390e-03, 3.7824e-05, 8.7411e-05,
        3.7352e-04, 1.4837e-04, 5.3578e-03, 8.2363e-03, 6.2692e-03, 6.0882e-03,
        7.1501e-03, 1.0302e-02, 3.3513e-03, 2.0607e-04, 3.3976e-04, 2.3810e-03,
        4.1092e-06, 2.6392e-04, 1.6034e-

In [11]:
# Median over dim 0 of the element-wise product mask * bull_cls.
torch.mul(mask, bull_cls).median(dim=0).values

tensor([1.4372e-08, 1.6036e-05, 1.6599e-05, 1.5662e-05, 2.1783e-07, 3.1657e-07,
        8.0065e-05, 4.0315e-05, 9.3275e-07, 1.2967e-04, 4.4635e-07, 9.0479e-09,
        1.1526e-09, 1.1637e-08, 6.0651e-09, 3.4227e-04, 8.0598e-07, 1.3921e-07,
        5.2763e-07, 6.7266e-07, 7.0283e-06, 2.8292e-07, 5.2764e-07, 6.1464e-07,
        3.1010e-07, 3.9373e-08, 3.9110e-09, 7.9542e-09, 8.9594e-07, 9.7303e-07,
        1.1614e-05, 9.9418e-08, 1.7761e-07, 3.7794e-08, 5.1459e-07, 8.4035e-06,
        5.1419e-07, 2.8589e-04, 2.7177e-06, 5.1488e-04, 2.6514e-07, 2.3290e-07,
        5.7418e-08, 8.8875e-10, 1.0182e-05, 1.1630e-07, 1.5020e-06, 3.8147e-07,
        3.7192e-05, 8.0781e-06, 3.3798e-06, 1.2249e-06, 1.2469e-06, 1.8657e-05,
        3.7813e-07, 3.7513e-08, 4.4708e-08, 4.6884e-08, 5.0430e-08, 2.4095e-07,
        8.2264e-07, 4.1638e-07, 8.1182e-06, 2.0013e-05, 1.3590e-05, 6.2216e-06,
        2.5300e-06, 1.3423e-06, 3.3746e-07, 1.5059e-07, 4.9871e-08, 1.3853e-07,
        1.7934e-08, 9.2481e-07, 9.2486e-

In [12]:
# Product of the two per-column medians (median first, then multiply).
mask.median(dim=0).values * bull_cls.median(dim=0).values

tensor([1.2109e-08, 3.4186e-05, 3.9964e-05, 6.5368e-05, 9.9487e-07, 2.7807e-07,
        1.1461e-04, 6.1947e-05, 5.2614e-06, 1.2163e-04, 2.6738e-06, 5.2555e-09,
        5.2254e-10, 1.4303e-08, 4.7297e-09, 2.9033e-04, 6.0420e-06, 1.4140e-06,
        1.0445e-05, 1.3935e-06, 1.3066e-05, 4.6741e-07, 3.6963e-07, 9.2100e-07,
        4.7377e-07, 6.0449e-08, 3.0769e-09, 6.2539e-09, 1.2075e-06, 1.0388e-06,
        2.5761e-05, 3.2651e-07, 2.9874e-07, 7.6245e-08, 2.5835e-07, 4.0152e-05,
        5.5353e-07, 2.3322e-04, 7.2617e-06, 2.8893e-04, 2.5902e-07, 1.7892e-07,
        8.1941e-08, 3.8999e-10, 3.4390e-05, 2.4769e-07, 2.2271e-06, 3.8104e-07,
        5.5321e-05, 2.7757e-05, 6.8167e-06, 2.2295e-06, 1.0294e-06, 2.0864e-05,
        5.5624e-07, 6.1068e-09, 5.7373e-08, 1.3999e-07, 1.7296e-07, 3.4796e-07,
        2.2834e-06, 6.2970e-07, 3.4723e-05, 3.4721e-05, 8.8647e-06, 4.2083e-06,
        2.2287e-06, 8.9208e-07, 5.2754e-07, 6.0750e-07, 2.3465e-08, 1.5674e-06,
        2.0245e-08, 1.4450e-06, 3.8321e-

In [16]:
# Reducing over dim 0 leaves one value per column.
mask.median(dim=0).values.size()

torch.Size([196])

In [22]:
# Cosine similarity between mask and bull_cls along dim 0 (across the 12 rows),
# giving one similarity value per column (196 values).
F.cosine_similarity(mask, bull_cls, dim=0)

tensor([0.4970, 0.4711, 0.6054, 0.4840, 0.7468, 0.4163, 0.6187, 0.5531, 0.7864,
        0.6266, 0.4810, 0.5735, 0.4767, 0.4991, 0.5640, 0.3642, 0.7432, 0.5519,
        0.8052, 0.5802, 0.6305, 0.4095, 0.6958, 0.7538, 0.7165, 0.7230, 0.3045,
        0.2954, 0.4629, 0.7471, 0.5300, 0.6243, 0.2851, 0.3179, 0.3755, 0.4331,
        0.7235, 0.6735, 0.5457, 0.5860, 0.6767, 0.6316, 0.4237, 0.6356, 0.4745,
        0.4225, 0.2674, 0.3119, 0.3969, 0.6578, 0.6387, 0.4265, 0.5098, 0.5013,
        0.4141, 0.3722, 0.5265, 0.5468, 0.5178, 0.6694, 0.3573, 0.5265, 0.6950,
        0.6554, 0.5062, 0.5085, 0.3146, 0.7686, 0.3711, 0.4631, 0.4909, 0.5229,
        0.5551, 0.7128, 0.6108, 0.4031, 0.5073, 0.5020, 0.5610, 0.5356, 0.3035,
        0.8496, 0.5766, 0.1968, 0.6636, 0.6843, 0.5210, 0.4154, 0.5841, 0.1929,
        0.4814, 0.6185, 0.3724, 0.5322, 0.8223, 0.5084, 0.3419, 0.5342, 0.6199,
        0.5475, 0.7114, 0.7061, 0.8981, 0.4213, 0.6824, 0.5687, 0.6922, 0.4009,
        0.4789, 0.5667, 0.5461, 0.7017, 

In [37]:
# Per-column cosine similarity between mask and bull_cls across the 12 rows (dim=0).
# A cosine over unsqueezed medians with dim=0 compares length-1 vectors, which is
# ~1 for every positive pair (only eps clamping changes it) and yields a degenerate map.
per_patch_cosine = F.cosine_similarity(mask, bull_cls, dim=0)
visu(image_bull, per_patch_cosine, 'cosine_bull_mask')

In [35]:
# Per-column cosine similarity between mask and bull_cls across the 12 rows (dim=0).
per_patch_cosine = F.cosine_similarity(mask, bull_cls, dim=0)
visu(image_bull, per_patch_cosine, 'cosine_bull_mask')

  "See the documentation of nn.Upsample for details.".format(mode)


In [34]:
# Per-column cosine similarity between mask and bull_cls across the 12 rows (dim=0).
# A cosine over unsqueezed medians with dim=0 compares length-1 vectors, which is
# ~1 for every positive pair (only eps clamping changes it) and yields a degenerate map.
per_patch_cosine = F.cosine_similarity(mask, bull_cls, dim=0)
visu(image_bull, per_patch_cosine, 'cosine_bull_mask')

  "See the documentation of nn.Upsample for details.".format(mode)


In [38]:
# Element-wise product of the per-column medians (not a dot product, despite the file name).
median_product = mask.median(dim=0).values * bull_cls.median(dim=0).values
visu(image_bull, median_product, 'bull_dot_mask')

  "See the documentation of nn.Upsample for details.".format(mode)
