# ViT vs CNN

In [44]:
import glob
import os
import pickle
from typing import List, Optional, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
from matplotlib.gridspec import GridSpec
from PIL import Image
from PIL.Image import Image as Img

import BiT
from ViT import modeling as ViT

In [12]:
# Directory that load_images scans for *.jpg sample images.
dataset_path = "samples"

In [3]:
def get_weights(path):
    """Load and return the NumPy archive stored at ``path``.

    The value from ``np.load`` is handed back unchanged; for the ``.npz``
    weight files used in this notebook it is a lazily-read mapping from
    array name to array.
    """
    archive = np.load(path)
    return archive

## Prepare dataset and model

In [4]:
IMG_SIZE = (384, 384)
NORMALIZE_MEAN = (0.5, 0.5, 0.5)
NORMALIZE_STD = (0.5, 0.5, 0.5)

# Preprocessing shared by both models: resize every image to IMG_SIZE,
# convert it to a tensor, then normalize each channel with
# NORMALIZE_MEAN / NORMALIZE_STD.
transforms = T.Compose([
    T.Resize(IMG_SIZE),
    T.ToTensor(),
    T.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
])

In [5]:
def load_images(dataset_path: str) -> List[Img]:
    """Load every ``*.jpg`` file directly inside ``dataset_path`` as an RGB image.

    Files are visited in sorted filename order. ``glob.glob`` alone returns
    files in arbitrary, filesystem-dependent order, which can misalign the
    images with model outputs pickled in an earlier session.

    Args:
        dataset_path: Directory to scan (non-recursive). Glob metacharacters
            in the directory name are escaped, so they match literally.

    Returns:
        One RGB image per matched file; an empty list if nothing matches.
    """
    pattern = os.path.join(glob.escape(dataset_path), '*.jpg')
    images = []
    for filename in sorted(glob.glob(pattern)):
        # ``with`` releases the file handle. ``convert`` returns a new,
        # fully loaded image that no longer needs the source file.
        with Image.open(filename) as im:
            images.append(im.convert('RGB'))
    return images


def load_vit(model_name='ViT-B_16', path='ViT-B_16.npz') -> nn.Module:
    """Build the ViT classifier named ``model_name`` and load weights from ``path``.

    The model is created for 1000 classes at a 384-pixel input size, filled
    with the archive read by ``get_weights``, and returned in eval mode.
    """
    vit_model = ViT.VisionTransformer(ViT.CONFIGS[model_name], num_classes=1000, img_size=384)
    vit_model.load_from(get_weights(path))
    vit_model.eval()
    return vit_model

def load_bit(model_name='BiT-M-R50x3', path='BiT-M-R50x3-ILSVRC2012.npz') -> nn.Module:
    """Build the BiT model named ``model_name`` and load weights from ``path``.

    The model is created with a 1000-way head and filled with the archive read
    by ``get_weights``. Like ``load_vit``, it is returned in eval mode. That way
    any train-only layer behaviour is disabled for inference, whether or not
    this architecture has such layers (not visible from here).
    """
    model = BiT.KNOWN_MODELS[model_name](head_size=1000)
    model.load_from(get_weights(path))
    return model.eval()

In [26]:
# Load the sample images and both pretrained models.
images = load_images(dataset_path)
vit, bit = load_vit(), load_bit()

## Compare predictions

In [39]:
# Preprocess every image and give it a leading batch dimension of size 1, so
# the tensors can later be concatenated into batches.
img_tensors = [transforms(picture)[None] for picture in images]

def batches(img_tensors: List[torch.Tensor], batch_size=32):
    """Yield consecutive slices of ``img_tensors`` holding ``batch_size`` items each.

    The last slice holds the remainder and may be shorter. An empty input
    yields nothing.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1. Because this is a
            generator, the error surfaces when iteration starts.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    for start in range(0, len(img_tensors), batch_size):
        yield img_tensors[start:start + batch_size]

def predict_and_save(model, img_tensors: List[torch.Tensor], target_path: str, batch_size: int = 32):
    """Run ``model`` over ``img_tensors`` in batches and pickle the stacked outputs.

    Groups of up to ``batch_size`` tensors, as produced by ``batches``, are
    concatenated along dim 0 and fed to ``model`` without gradient tracking.
    When the model returns a tuple, only its first element is kept. The
    per-batch outputs are concatenated along dim 0, moved to the CPU and
    pickled to ``target_path``.

    Raises:
        ValueError: if ``img_tensors`` is empty (there is nothing to save).
    """
    if not img_tensors:
        raise ValueError("img_tensors is empty; nothing to predict")
    outputs = []
    with torch.no_grad():
        for batch in batches(img_tensors, batch_size):
            output = model(torch.cat(batch))
            # Some models return extra values alongside the logits; keep the first.
            if isinstance(output, tuple):
                output = output[0]
            outputs.append(output)
    # ``with`` guarantees the pickle file is flushed and closed.
    with open(target_path, 'wb') as f:
        pickle.dump(torch.cat(outputs).cpu(), f)

In [40]:
# Main computational cost: run both models over every image and cache their
# outputs on disk.
predict_and_save(model=vit, img_tensors=img_tensors, target_path='vit_output.pickle', batch_size=5)
predict_and_save(model=bit, img_tensors=img_tensors, target_path='bit_output.pickle', batch_size=5)

In [42]:
def extract_top_labels_list(outputs: torch.Tensor, k: int = 5,
                            class_names: Optional[Sequence[str]] = None) -> List[List[str]]:
    """Return the labels of the ``k`` highest-scoring classes for every image.

    Args:
        outputs: Scores of shape ``(images_no, class_no)``, one row per image.
        k: Number of labels kept per image; clamped to ``class_no``.
        class_names: Optional names indexed by class id. When omitted, the
            class index itself (as a string) is used as the label.

    Returns:
        One list per image, ordered from highest to lowest score.

    Raises:
        ValueError: if ``outputs`` is not 2-D or ``k`` is smaller than 1.
    """
    if outputs.dim() != 2:
        raise ValueError(
            f"expected outputs of shape (images_no, class_no), got {tuple(outputs.shape)}")
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    k = min(k, outputs.shape[1])
    # topk sorts in descending order by default, so labels come best-first.
    top_indices = outputs.topk(k, dim=1).indices.tolist()
    if class_names is None:
        return [[str(idx) for idx in row] for row in top_indices]
    return [[class_names[idx] for idx in row] for row in top_indices]

In [None]:
def filter_uninteresting(vit_top_labels: List[List[str]], bit_top_labels: List[List[str]],
                         candidate_images: Optional[List[Img]] = None) \
    -> Tuple[List[Img], List[List[str]], List[List[str]]]:
    """Keep only the images on which ViT and BiT disagree about the top label.

    An image is "interesting" when the first (highest-ranked) label of the two
    models differs. An empty label list counts as "no label".

    Args:
        vit_top_labels: Ranked labels per image from ViT.
        bit_top_labels: Ranked labels per image from BiT, aligned by index.
        candidate_images: Images aligned by index with the label lists.
            Defaults to the notebook-level ``images`` list at call time.

    Returns:
        ``(interesting_images, corresponding_vit_labels, corresponding_bit_labels)``.

    Raises:
        ValueError: if the images and the two label lists differ in length.
    """
    if candidate_images is None:
        # Fall back to the notebook-level list the predictions were computed from.
        candidate_images = images
    if not len(candidate_images) == len(vit_top_labels) == len(bit_top_labels):
        raise ValueError(
            f"length mismatch: {len(candidate_images)} images, "
            f"{len(vit_top_labels)} ViT label lists, {len(bit_top_labels)} BiT label lists")
    kept = [
        (img, vit_labels, bit_labels)
        for img, vit_labels, bit_labels in zip(candidate_images, vit_top_labels, bit_top_labels)
        if vit_labels[:1] != bit_labels[:1]
    ]
    interesting_images = [img for img, _, _ in kept]
    corresponding_vit_labels = [vit_labels for _, vit_labels, _ in kept]
    corresponding_bit_labels = [bit_labels for _, _, bit_labels in kept]
    return interesting_images, corresponding_vit_labels, corresponding_bit_labels

In [None]:
def show_all_predictions(interesting_images, corresponding_vit_labels, corresponding_bit_labels, ncols=4):
    """Plot each image in a grid, titled with its ViT and BiT labels.

    Args:
        interesting_images: Images to display.
        corresponding_vit_labels: ViT label list per image, aligned by index.
        corresponding_bit_labels: BiT label list per image, aligned by index.
        ncols: Maximum number of images per grid row.
    """
    n_images = len(interesting_images)
    if n_images == 0:
        print('No images to show.')
        return
    ncols = max(1, min(ncols, n_images))
    nrows = -(-n_images // ncols)  # ceiling division
    fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4.5 * nrows), squeeze=False)
    # Hide the axes everywhere, including unused cells of the last row.
    for ax in axes.flat:
        ax.axis('off')
    for ax, img, vit_labels, bit_labels in zip(
            axes.flat, interesting_images, corresponding_vit_labels, corresponding_bit_labels):
        ax.imshow(img)
        ax.set_title('ViT: ' + ', '.join(vit_labels) + '\nBiT: ' + ', '.join(bit_labels),
                     fontsize=9)
    fig.tight_layout()
    plt.show()

In [None]:
# Reload the model outputs cached by predict_and_save.
# NOTE: pickle can execute arbitrary code while loading; only load files that
# this notebook produced itself.
with open('vit_output.pickle', 'rb') as f:
    vit_outputs = pickle.load(f)
with open('bit_output.pickle', 'rb') as f:
    bit_outputs = pickle.load(f)

In [None]:
vit_predictions = extract_top_labels_list(vit_outputs)
bit_predictions = extract_top_labels_list(bit_outputs)
# ``images_interesting`` (correct spelling) is also read by the embedding-cluster
# section further down.
images_interesting, vit_interesting, bit_interesting = filter_uninteresting(vit_predictions, bit_predictions)
show_all_predictions(images_interesting, vit_interesting, bit_interesting)

## Compare activation maps

In [None]:
# TODO: compare the activation maps of ViT and BiT on the same images.

## Compare embeddings clusters

In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from scipy.spatial import ConvexHull
from sklearn.cluster import AgglomerativeClustering
from sklearn.manifold import MDS

In [None]:
def draw_clustering(embeddings, n_clusters, r=1.5, thumbnails=None):
    """Cluster ``embeddings`` hierarchically and draw them in 2-D with image thumbnails.

    Embeddings are grouped with ``AgglomerativeClustering`` and projected to
    2-D with ``MDS``. Each point is drawn as its image thumbnail. Each cluster
    is shaded with a random colour over the convex hull of its points, each
    point padded by a square of half-side ``r``.

    Args:
        embeddings: One embedding per image; anything ``np.asarray`` accepts,
            e.g. a CPU tensor of shape ``(n_images, dim)``.
        n_clusters: Requested number of clusters, clamped to the number of
            embeddings (a clustering cannot have more clusters than samples).
        r: Half-side of the padding square around each point, in MDS units.
        thumbnails: Images drawn at the points, aligned by index with
            ``embeddings``. Defaults to the notebook-level ``images`` list.

    Raises:
        ValueError: if the number of thumbnails differs from the number of
            embeddings.
    """
    points_nd = np.asarray(embeddings)
    if thumbnails is None:
        thumbnails = images
    n_points = points_nd.shape[0]
    if len(thumbnails) != n_points:
        raise ValueError(f"got {len(thumbnails)} images for {n_points} embeddings")
    n_clusters = min(n_clusters, n_points)

    labels = AgglomerativeClustering(n_clusters).fit(points_nd).labels_
    colors = [np.random.rand(3,) for _ in range(n_clusters)]
    embeddings_2d = MDS().fit_transform(points_nd)

    fig, ax = plt.subplots(figsize=(17, 17))
    for img, xy in zip(thumbnails, embeddings_2d):
        ax.add_artist(AnnotationBbox(OffsetImage(img, zoom=0.1), xy, xycoords='data', frameon=False))

    # The point itself plus the four corners of its padding square.
    offsets = np.array([(0, 0), (1, 1), (1, -1), (-1, 1), (-1, -1)]) * r
    for cluster_id in range(n_clusters):
        members = embeddings_2d[labels == cluster_id]
        # (n_members, 5, 2) -> (n_members * 5, 2)
        padded = (members[:, None, :] + offsets[None, :, :]).reshape(-1, 2)
        hull_points = padded[ConvexHull(padded).vertices]
        # Repeat the first vertex to close the polygon.
        closed = np.vstack([hull_points, hull_points[:1]])
        ax.fill(closed[:, 0], closed[:, 1], alpha=0.3, color=colors[cluster_id])

    ax.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1])
    plt.show()

In [None]:
# Re-embed only the images on which the two classifiers disagreed.
images = images_interesting
# Clustering needs more than one sample. Check explicitly, before torch.cat
# (which fails opaquely on an empty list) and without relying on ``assert``.
if len(images) < 2:
    raise ValueError(f"need at least 2 interesting images to cluster, got {len(images)}")
img_tensors = [transforms(img).unsqueeze(0) for img in images]
img_batch = torch.cat(img_tensors)

In [None]:
# Extract one ViT embedding per image from the transformer output, without
# gradient tracking. The first element of the transformer's result is indexed
# at position 0 along dim 1. This is presumably the [CLS] token of a
# (batch, tokens, hidden) tensor; TODO confirm against ViT.modeling.
with torch.no_grad():
    vit_embeddings = vit.transformer(img_batch)[0][:, 0, :]

In [None]:
# Build a feature extractor from BiT: every top-level child except the last,
# followed by every child of ``bit.head`` except its last. The dropped final
# layer is presumably the classifier; confirm against the BiT module.
base = list(bit.children())[:-1]
pooling = list(bit.head.children())[:-1]
bit_pruned = nn.Sequential(*base, *pooling)
with torch.no_grad():
    # Use flatten(1), not squeeze(): squeeze() would also drop the batch
    # dimension when the batch holds a single image.
    bit_embeddings = bit_pruned(img_batch).flatten(1)

In [None]:
# Group the ViT embeddings into 8 clusters and plot them.
draw_clustering(embeddings=vit_embeddings, n_clusters=8)

In [None]:
# Group the BiT embeddings into 8 clusters. The larger hull padding (r=8) is
# presumably tuned to the wider spread of the BiT MDS layout; adjust as needed.
draw_clustering(embeddings=bit_embeddings, n_clusters=8, r=8)