# ViT vs CNN

In [None]:
import glob
import os
import pickle
from typing import List, Mapping, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
from matplotlib.gridspec import GridSpec
from PIL import Image
from PIL.Image import Image as Img

import BiT
from ViT import modeling as ViT

In [None]:
# Directory whose *.jpg files are loaded by ``load_images``.
dataset_path = 'data/cats'

In [None]:
def get_weights(path):
    """Read a NumPy weight file from ``path`` and hand back the result of ``np.load``."""
    loaded = np.load(file=path)
    return loaded

## Prepare dataset and model

In [None]:
IMG_SIZE = (384, 384)
NORMALIZE_MEAN = (0.5, 0.5, 0.5)
NORMALIZE_STD = (0.5, 0.5, 0.5)

# Preprocessing applied to every image before it reaches a model:
# resize to IMG_SIZE, convert to a tensor, then normalize each channel
# with NORMALIZE_MEAN / NORMALIZE_STD.
transforms = T.Compose([
    T.Resize(IMG_SIZE),
    T.ToTensor(),
    T.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
])

In [None]:
def load_images(dataset_path: str, pattern: str = '*.jpg') -> List[Img]:
    """Load every file in ``dataset_path`` matching ``pattern`` as an RGB PIL image.

    Args:
        dataset_path: Directory to search (non-recursively).
        pattern: Glob pattern for the file names; defaults to ``'*.jpg'``.

    Returns:
        The converted images, in sorted file-name order.
    """
    images = []
    # ``glob`` returns files in arbitrary order. Sorting makes the image order
    # reproducible across runs, which matters because model outputs are pickled
    # and later re-matched to images purely by index.
    for filename in sorted(glob.glob(os.path.join(dataset_path, pattern))):
        # ``convert`` returns a new, fully loaded image, so the source file can
        # be closed right away instead of leaking its handle until GC.
        with Image.open(filename) as im:
            images.append(im.convert('RGB'))
    return images


def load_vit(model_name='ViT-B_16', path='ViT-B_16.npz') -> nn.Module:
    """Create a ViT (1000-class head, 384px input), load ``path`` into it, and return it in eval mode."""
    vit_model = ViT.VisionTransformer(ViT.CONFIGS[model_name], num_classes=1000, img_size=384)
    vit_model.load_from(get_weights(path))
    vit_model.eval()
    return vit_model

def load_bit(model_name='BiT-M-R50x3', path='BiT-M-R50x3-ILSVRC2012.npz') -> nn.Module:
    """Create a BiT model with a 1000-class head, load ``path`` into it, and return it in eval mode.

    The model is returned in eval mode, exactly like ``load_vit``, because both
    models are only used for inference in this notebook.
    """
    model = BiT.KNOWN_MODELS[model_name](head_size=1000)
    model.load_from(get_weights(path))
    return model.eval()

In [None]:
# Load both pretrained models, then the raw images that will be perturbed.
vit = load_vit()
bit = load_bit()
images = load_images(dataset_path)

## Compare predictions

In [None]:
from ipynb.fs.full.perturb_dataset import *

In [None]:
# Perturbations applied to every image, keyed by name. Insertion order matters:
# it fixes where each perturbed copy sits in the generated dataset and which
# slot of the per-mask failure statistics it updates.
MASK = dict(
    BlackBoxMasking_1=BlackBoxMasking(layout=FixedShapeLayout(count=3, shape=(128, 128))),
    BlackBoxMasking_2=BlackBoxMasking(layout=FixedLayout(masks=[(0.4, 0.4, 0.6, 0.6)])),
    BlackBoxMasking_3=BlackBoxMasking(layout=FixedRatioLayout(count=20, ratio=1/30)),
    BlackBoxMasking_4=BlackBoxMasking(layout=GridLayout(perc=0.8, grid_size=8)),
    MedianMasking=MedianMasking(layout=FixedShapeLayout(count=30, shape=(32, 32))),
    BlurMasking_1=BlurMasking(layout=FixedShapeLayout(count=5, shape=(100, 100)), blur=16),
    BlurMasking_2=BlurMasking(layout=GridLayout(perc=0.7, grid_size=8), blur=16),
)

# Number of perturbed variants produced per source image.
MASK_NO = len(MASK)

In [None]:
def generate_dataset(images: List[Img]) -> Tuple[List[Img], List[torch.Tensor]]:
    """Expand each image into itself plus one perturbed copy per entry of ``MASK``.

    For every input image, ``1 + MASK_NO`` consecutive entries are produced:
    the unmodified image, then ``func(img)`` for each masking function in
    ``MASK`` insertion order. Both returned lists share this layout, so index
    ``k`` in one corresponds to index ``k`` in the other.

    Returns:
        ``(raw_images, tensors)`` where ``tensors[k]`` is
        ``transforms(raw_images[k]).unsqueeze(0)``.
    """
    raw_images, tensors = [], []
    for img in images:
        # Apply each mask exactly once and reuse the result. Calling func(img)
        # twice (once for the raw copy, once for the tensor) can yield two
        # different perturbations if a mask layout is randomized, leaving the
        # stored raw image out of sync with the tensor the models actually see.
        variants = [img] + [func(img) for func in MASK.values()]
        for variant in variants:
            raw_images.append(variant)
            tensors.append(transforms(variant).unsqueeze(0))
    return raw_images, tensors

In [None]:
# Raw images interleaved with their perturbed variants, plus matching model-ready tensors.
img_raw, img_tensors = generate_dataset(images=images)

In [None]:
# Cache the raw PIL images; the context manager closes the file once written.
with open("img_raw.pickle", 'wb') as f:
    pickle.dump(img_raw, f)

In [None]:
# Cache the preprocessed tensors; the context manager closes the file once written.
with open("img_tensor.pickle", 'wb') as f:
    pickle.dump(img_tensors, f)

In [None]:
def batches(img_tensors: List[torch.Tensor], batch_size=32):
    """Yield consecutive slices of ``img_tensors`` holding ``batch_size`` items each.

    The final slice holds the remainder and may be shorter; an empty input
    yields nothing.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1. Because this is a
            generator, the error surfaces on the first iteration. A zero size
            used to raise a bare ZeroDivisionError, and negative sizes silently
            yielded nothing, dropping every input.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    for start in range(0, len(img_tensors), batch_size):
        yield img_tensors[start:start + batch_size]

def predict_and_save(model, img_tensors: List[torch.Tensor], target_path: str, batch_size: int = 32):
    """Run ``model`` over ``img_tensors`` batch by batch and pickle the concatenated outputs.

    Args:
        model: Callable applied to each concatenated batch. If it returns a
            tuple, only the first element is kept.
        img_tensors: Tensors concatenated along dim 0 to form each batch.
        target_path: File that the CPU copy of all outputs is pickled to.
        batch_size: Number of list entries fed to the model per forward pass.

    Returns:
        The concatenated CPU output tensor that was written to ``target_path``
        (previously nothing was returned; existing callers can ignore it).

    Raises:
        RuntimeError: from ``torch.cat`` when ``img_tensors`` is empty.
    """
    outputs = []
    # One no_grad context for the whole loop instead of re-entering it per batch.
    with torch.no_grad():
        for batch in batches(img_tensors, batch_size):
            output = model(torch.cat(batch))
            if isinstance(output, tuple):
                output = output[0]
            # Move each batch to CPU right away so device memory is not held for
            # the whole dataset; the final result is identical.
            outputs.append(output.cpu())
    result = torch.cat(outputs)
    with open(target_path, 'wb') as f:
        pickle.dump(result, f)
    return result

In [None]:
# Main computational cost: ViT forward pass over every image tensor.
predict_and_save(vit, img_tensors, 'vit_output.pickle', 32)

In [None]:
# Second main computational cost: BiT forward pass over every image tensor.
predict_and_save(bit, img_tensors, 'bit_output.pickle', 32)

In [None]:
# Map class index -> raw label line. Lines keep their trailing '\n' on purpose:
# ``forbidden_labels`` below is written with the same newline suffix.
# The context manager closes the file instead of leaking the handle.
with open('ilsvrc2012_wordnet_lemmas.txt') as labels_file:
    imagenet_labels = dict(enumerate(labels_file))

In [None]:
def extract_top_labels_list(outputs: torch.Tensor, top_n,
                            labels: Optional[Mapping[int, str]] = None) -> List[List[str]]:
    """Return the ``top_n`` highest-scoring labels for every row of ``outputs``.

    Args:
        outputs: Score tensor iterated along its first dimension; each row is
            argsorted independently.
        top_n: Number of labels to keep per row. Values ``<= 0`` give an empty
            list per row. Previously ``0`` silently returned *all* labels,
            because ``seq[-0:]`` is the whole sequence.
        labels: Mapping from class index to label string. Defaults to the
            module-level ``imagenet_labels``.

    Returns:
        One list per row, ordered from the lowest to the highest score among
        the kept labels (the best prediction is last).
    """
    if labels is None:
        labels = imagenet_labels
    out = []
    for model_out in outputs:
        if top_n <= 0:
            out.append([])
            continue
        # argsort is ascending, so the best ``top_n`` indices are at the end.
        # Slice the indices first, then look up only those labels.
        top_idx = np.argsort(model_out.detach().cpu().numpy())[-top_n:]
        out.append([labels[int(i)] for i in top_idx])
    return out

In [None]:
# ImageNet cat-family classes. They are removed from both label sets before
# comparing predictions, so a prediction that only moves between these classes
# does not count as a change. Every entry ends with '\n' because
# ``imagenet_labels`` keeps the raw line terminator from the label file.
forbidden_labels = [
    name + '\n'
    for name in (
        'tabby, tabby_cat',
        'tiger_cat',
        'Persian_cat',
        'Siamese_cat, Siamese',
        'Egyptian_cat',
        'cougar, puma, catamount, mountain_lion, painter, panther, Felis_concolor',
        'lynx, catamount',
        'leopard, Panthera_pardus',
        'snow_leopard, ounce, Panthera_uncia',
        'jaguar, panther, Panthera_onca, Felis_onca',
        'lion, king_of_beasts, Panthera_leo',
        'tiger, Panthera_tigris',
        'cheetah, chetah, Acinonyx_jubatus',
    )
]

In [None]:
def filter_uninteresting(vit_top_labels: List[List[str]], bit_top_labels: List[List[str]], stats_on=True) \
    -> Tuple[List[Img], List[Tuple[List[str], List[str]]], List[Tuple[List[str], List[str]]], List[int], List[int]]:
    """Select perturbed images whose (non-cat) top-label set changed for ViT or BiT.

    ``img_raw`` and both label lists are laid out in groups of ``MASK_NO + 1``
    entries: an original image followed by one perturbed copy per mask, in
    ``MASK`` order. Inside each group, every perturbed copy is compared with
    the original after removing ``forbidden_labels`` from both label sets.

    Args:
        vit_top_labels: ViT top labels for each entry of ``img_raw``.
        bit_top_labels: BiT top labels for each entry of ``img_raw``.
        stats_on: If true, count prediction changes per mask for each model.

    Returns:
        ``(images, vit_pairs, bit_pairs, stats_bit, stats_vit)``:
        - ``images``: perturbed images whose prediction changed for at least one model.
        - ``vit_pairs`` / ``bit_pairs``: the ``(original_labels, perturbed_labels)``
          pair for each of those images.
        - ``stats_*[k]``: number of originals whose prediction changed under the
          ``k``-th mask. These are all zeros when ``stats_on`` is false.

        Note that the BiT statistics come before the ViT statistics.
    """
    # Built once instead of four times per inner iteration.
    ignored = set(forbidden_labels)
    group_size = MASK_NO + 1
    interesting_images, corresponding_vit_labels, corresponding_bit_label = [], [], []
    # Sized from MASK_NO (previously a hard-coded 7) so adding or removing masks stays consistent.
    stats_vit = [0] * MASK_NO
    stats_bit = [0] * MASK_NO
    for i in range(0, len(img_raw), group_size):
        # The original image's predictions are the same for every mask in the group.
        vit_default_pred = set(vit_top_labels[i]) - ignored
        bit_default_pred = set(bit_top_labels[i]) - ignored
        for mask_idx in range(MASK_NO):
            j = i + 1 + mask_idx
            vit_changed = (set(vit_top_labels[j]) - ignored) != vit_default_pred
            bit_changed = (set(bit_top_labels[j]) - ignored) != bit_default_pred

            if vit_changed or bit_changed:
                interesting_images.append(img_raw[j])
                corresponding_vit_labels.append((vit_top_labels[i], vit_top_labels[j]))
                corresponding_bit_label.append((bit_top_labels[i], bit_top_labels[j]))

            if stats_on:
                # mask_idx is the position inside the group. The old index
                # (j - 1) % MASK_NO drifted by one mask per group, because each
                # group holds MASK_NO + 1 entries. That credited changes to the
                # wrong mask for every image after the first.
                if vit_changed:
                    stats_vit[mask_idx] += 1
                if bit_changed:
                    stats_bit[mask_idx] += 1

    return interesting_images, corresponding_vit_labels, corresponding_bit_label, stats_bit, stats_vit

In [None]:
def print_predict_compare(img, vit_labels, bit_labels):
    """Show ``img`` next to BiT/ViT predictions before and after perturbation.

    ``vit_labels`` and ``bit_labels`` are ``(default, new)`` pairs of label
    sequences. Each sequence is rendered by concatenating its entries as
    strings. The text panels sit in rows 1-2 of a 3x3 grid: default predictions
    in column 0 and new predictions in column 1. The image spans every row of
    column 2.
    """
    fig = plt.figure(figsize=(12, 6))
    grid = GridSpec(nrows=3, ncols=3)

    # (grid cell, label sequence, title), in the order the panels are added.
    panels = (
        (grid[1, 0], bit_labels[0], "BiT default prediction"),
        (grid[2, 0], vit_labels[0], "ViT default prediction"),
        (grid[1, 1], bit_labels[1], "BiT new prediction"),
        (grid[2, 1], vit_labels[1], "ViT new prediction"),
    )
    for cell, labels, title in panels:
        text_ax = fig.add_subplot(cell)
        text_ax.axis("off")
        text_ax.invert_yaxis()
        text_ax.text(0.1, 0.25, ''.join(map(str, labels)), verticalalignment="center")
        text_ax.set_title(title)

    image_ax = fig.add_subplot(grid[:, 2])
    image_ax.imshow(img)
    image_ax.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
def show_all_predictions(interesting_images, corresponding_vit_labels, corresponding_bit_labels):
    """Plot each image next to its ViT and BiT ``(default, new)`` label pairs.

    The three sequences are parallel: entry ``k`` of each belongs to the same image.

    Raises:
        ValueError: if the three sequences differ in length. Previously a
            shorter label list surfaced as an IndexError partway through
            plotting, and extra labels were silently ignored.
    """
    n = len(interesting_images)
    if len(corresponding_vit_labels) != n or len(corresponding_bit_labels) != n:
        raise ValueError(
            f"length mismatch: {n} images, {len(corresponding_vit_labels)} ViT labels, "
            f"{len(corresponding_bit_labels)} BiT labels")
    for img, vit_labels, bit_labels in zip(interesting_images, corresponding_vit_labels,
                                           corresponding_bit_labels):
        print_predict_compare(img, vit_labels, bit_labels)

In [None]:
# Reload the cached model outputs. pickle can execute arbitrary code on load:
# only open files written by ``predict_and_save`` in this notebook, never
# untrusted ones. The context managers close each file after reading.
with open('vit_output.pickle', 'rb') as f:
    vit_outputs = pickle.load(f)
with open('bit_output.pickle', 'rb') as f:
    bit_outputs = pickle.load(f)

In [None]:
# Top-3 labels per image for each model (best label last).
vit_predictions = extract_top_labels_list(vit_outputs, top_n=3)
bit_predictions = extract_top_labels_list(bit_outputs, top_n=3)
# Keep the perturbed images whose non-cat top-3 label set changed for at least one model.
(images_fail, vit_fail, bit_fail,
 stats_bit, stats_vit) = filter_uninteresting(vit_predictions, bit_predictions, stats_on=True)
show_all_predictions(images_fail, vit_fail, bit_fail)

In [None]:
# Number of original (unperturbed) images; used as the denominator of the per-mask failure rates.
img_size = len(images)

In [None]:
# Per-mask failure rate: the fraction of original images whose non-cat top-3
# label set changed under that mask. The counts come from ``filter_uninteresting``
# and are indexed in ``MASK`` insertion order.
if img_size == 0:
    print("No images loaded; failure rates are undefined.")
else:
    for idx, mask in enumerate(MASK):
        print("{:s} BIT failure rate: {:.2f} VIT failure rate: {:.2f}".format(
            mask, stats_bit[idx] / img_size, stats_vit[idx] / img_size))

## Compare activation maps

In [None]:
# TODO: implement the activation-map comparison between ViT and BiT.

## Compare embedding clusters

In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from scipy.spatial import ConvexHull
from sklearn.cluster import AgglomerativeClustering
from sklearn.manifold import MDS

In [None]:
def draw_clustering(embeddings, n_clusters, r=1.5, thumbnails=None):
    """Cluster ``embeddings`` and plot them in 2-D with thumbnails and shaded cluster hulls.

    Agglomerative clustering runs on the original embedding space. MDS projects
    the embeddings to 2-D for display only. Each point is drawn with its
    thumbnail, and each cluster is shaded with the convex hull of its points
    plus the four diagonal corners at offset ``r`` around each point.

    Args:
        embeddings: Array-like of per-sample embeddings, e.g. a CPU tensor with
            one row per sample.
        n_clusters: Number of clusters passed to ``AgglomerativeClustering``.
        r: Padding used to grow each cluster's hull around its points.
        thumbnails: Images drawn at each point, indexed like ``embeddings``.
            Defaults to the module-level ``images`` list, as before.
    """
    if thumbnails is None:
        thumbnails = images
    # Convert once so clustering and MDS both receive the same NumPy array.
    data = np.asarray(embeddings)
    labels = AgglomerativeClustering(n_clusters).fit(data).labels_
    colors = [np.random.rand(3,) for _ in range(n_clusters)]
    embeddings_2d = MDS().fit_transform(data)

    _, ax = plt.subplots(figsize=(17, 17))

    for i, point in enumerate(embeddings_2d):
        thumb = OffsetImage(thumbnails[i], zoom=0.1)
        ax.add_artist(AnnotationBbox(thumb, point, xycoords='data', frameon=False))

    # Each point contributes itself plus its four diagonal corners at offset r.
    offsets = np.array([(0, 0), (1, 1), (1, -1), (-1, 1), (-1, -1)]) * r
    for cluster in range(n_clusters):
        members = embeddings_2d[labels == cluster]
        points = (members[:, None, :] + offsets[None, :, :]).reshape(-1, 2)
        hull_points = points[ConvexHull(points).vertices]
        # Repeat the first vertex to close the polygon outline.
        closed = np.vstack([hull_points, hull_points[:1]])
        # Draw on the explicit axes rather than mixing in pyplot's implicit current axes.
        ax.fill(closed[:, 0], closed[:, 1], alpha=0.3, color=colors[cluster])

    ax.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1])
    plt.show()

In [None]:
# Embed only the "interesting" perturbed images selected by ``filter_uninteresting``.
# ``images_interesting`` was never defined; those images are ``images_fail``.
# ``images`` is rebound on purpose, because ``draw_clustering`` reads it for its thumbnails.
images = images_fail
img_tensors = [transforms(img).unsqueeze(0) for img in images]
if len(img_tensors) < 2:
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    # It also runs before ``torch.cat``, which fails with an opaque error on an empty list.
    raise ValueError(f"need at least 2 images to cluster, got {len(img_tensors)}")
img_batch = torch.cat(img_tensors)

In [None]:
with torch.no_grad():
    # Token 0 of the first element returned by ``vit.transformer``, for every image.
    # Presumably this is the [CLS] token; confirm against ``ViT.modeling``.
    vit_embeddings = vit.transformer(img_batch)[0].select(1, 0)

In [None]:
# Rebuild BiT without its final classification layer. Keep every top-level
# child except the last (presumably ``bit.head``), then the head's children
# except its last one, so the output is the pooled features rather than
# (presumably) class logits.
base = list(bit.children())[:-1]
pooling = list(bit.head.children())[:-1]
# eval(): inference only, matching how the ViT model is loaded.
bit_pruned = nn.Sequential(*base, *pooling).eval()
with torch.no_grad():
    # flatten(1) turns the pooled map (presumably (N, C, 1, 1)) into (N, C).
    # Unlike squeeze(), it keeps the batch dimension even when N == 1.
    bit_embeddings = bit_pruned(img_batch).flatten(1)

In [None]:
# ViT embedding clusters, using the default hull padding.
draw_clustering(vit_embeddings, n_clusters=8)

In [None]:
# BiT embedding clusters, drawn with a larger hull padding (r=8).
draw_clustering(bit_embeddings, n_clusters=8, r=8)