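In this notebook, we'll visualize the linear patch-embedding projections (the filters that map each image patch to a token embedding) learned by several Vision Transformer (ViT) models: supervised ViT B/16 and L/16, and the self-supervised DINO ViT B/16.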

In [None]:
# Copy the two AugReg ViT checkpoints (.npz) from gs://vit_models/augreg into the
# current directory; the ViT B/16 and L/16 sections below load them by these exact
# filenames. Requires the `gsutil` CLI (invoked via IPython's `!` shell escape).

# ViT B-16; ImageNet-1k validation top-1 accuracy: 84.017%
!gsutil cp gs://vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz .

# ViT L-16; ImageNet-1k validation top-1 accuracy: 85.716%
!gsutil cp gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz .

## Imports

In [None]:
import tensorflow as tf
import numpy as np
import torch

import matplotlib.pyplot as plt

from sklearn.preprocessing import MinMaxScaler
from sklearn.decomposition import PCA

## Utilities

In [None]:
def load_jax_params(local_path: str):
    """Read every array stored in an ``.npz`` checkpoint into a plain dict.

    Args:
        local_path: Path to the ``.npz`` file on local disk.

    Returns:
        Dict mapping each array name in the archive (e.g. ``"embedding/kernel"``)
        to its fully loaded ``np.ndarray``.
    """
    with open(local_path, "rb") as checkpoint_file:
        archive = np.load(checkpoint_file)
        # Materialise the arrays while the file is still open: NpzFile reads
        # its members lazily from the underlying handle.
        arrays = {name: archive[name] for name in archive.keys()}
    return arrays


def scale_projections(projections: np.ndarray):
    """Min-max scale every projection filter independently for display.

    Args:
        projections: Filter bank whose LAST axis indexes the projection
            (embedding) dimension. All leading axes together form one filter,
            e.g. ``(patch_h, patch_w, patch_channels, projection_dim)`` for the
            patch-embedding kernels plotted in this notebook. Any number of
            leading axes is accepted.

    Returns:
        Array with the same shape as ``projections`` in which each filter
        ``[..., k]`` has been rescaled by ``MinMaxScaler`` using its own
        minimum and maximum, so every filter spans [0, 1].
    """
    *filter_shape, projection_dim = projections.shape
    # One row per filter element, one column per filter: MinMaxScaler scales
    # column-wise, so each filter is normalized independently of the others.
    scaled = MinMaxScaler().fit_transform(projections.reshape(-1, projection_dim))
    return scaled.reshape(*filter_shape, projection_dim)


def display_projections(
    scaled_projections: np.ndarray, save_plot=None, nrows: int = 8, ncols: int = 16
):
    """Plot the first ``nrows * ncols`` projection filters as a grid of images.

    Args:
        scaled_projections: Filters with the projection index on the last
            axis, e.g. the output of ``scale_projections``. Each ``[..., k]``
            slice is passed to ``imshow`` as one image.
        save_plot: Optional output path; when truthy, the figure is also saved
            there at 300 dpi.
        nrows, ncols: Grid shape. The defaults (8 x 16) show the first 128
            filters, matching the original fixed layout.
    """
    num_filters = scaled_projections.shape[-1]
    # Keep roughly the same per-cell size as the original 8 x 16 / (13, 8) layout.
    fig, axes = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(13 * ncols / 16, nrows), squeeze=False
    )

    # Draw only filters that exist (no IndexError when there are fewer filters
    # than grid cells) and hide the axes of every cell, including unused ones.
    for idx, ax in enumerate(axes.flat):
        if idx < num_filters:
            ax.imshow(scaled_projections[..., idx])
        ax.axis("off")

    fig.tight_layout()

    if save_plot:
        fig.savefig(save_plot, dpi=300, bbox_inches="tight")

## ViT B/16

In [None]:
# Load the ViT B/16 checkpoint downloaded above and plot its patch-embedding filters.
B16_CHECKPOINT = "B_16-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz"

b16_params = load_jax_params(B16_CHECKPOINT)
b16_kernel = b16_params["embedding/kernel"]
scaled_projections = scale_projections(b16_kernel)
display_projections(scaled_projections)

## ViT L/16

In [None]:
# Load the ViT L/16 checkpoint downloaded above and plot its patch-embedding filters.
L16_CHECKPOINT = "L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz"

l16_params = load_jax_params(L16_CHECKPOINT)
l16_kernel = l16_params["embedding/kernel"]
scaled_projections = scale_projections(l16_kernel)
display_projections(scaled_projections)

## DINO B/16

In [None]:
# NOTE(review): torch.hub.load fetches and executes code from the GitHub repo.
dino_b16 = torch.hub.load("facebookresearch/dino:main", "dino_vitb16")
dino_kernel = dino_b16.state_dict()["patch_embed.proj.weight"]
# Reorder axes (2, 3, 1, 0) so the layout matches the JAX kernels above:
# (patch_h, patch_w, channels, projection_dim). This presumes patch_embed.proj
# is a Conv2d with weight layout (out, in, h, w) — confirm against the DINO repo.
projections = np.transpose(dino_kernel.numpy(), (2, 3, 1, 0))
scaled_projections = scale_projections(projections)
display_projections(scaled_projections)

## Observations

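* While it's not entirely clear what the projection filters have learned, they do seem to form plausible basis functions, as also investigated in the [original ViT paper](https://arxiv.org/abs/2010.11929). (Note that the paper visualizes the top principal components of the embedding filters, whereas the plots above show the first 128 raw filters.)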
* The supervised pre-trained ViT B/16 and L/16 models learn noticeably different filters. The self-supervised (DINO) pre-trained ViT B/16 also shows immediately visible differences in its learned filters.