# Exploratory Analysis on the Semantic Space of a Vision Transformer

We want to see the semantic space of TinyCLIP using ablations.

In [5]:
import vit_prisma 
import timm

### Import model

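We're just using the visual encoder of TinyCLIP. The visual encoder is a transformer with 10 layers, 12 attention heads, and a hidden representation dimensionality of 256. (Note: the cell below currently loads timm's `vit_base_patch32_224` instead; the config printed when it loads reports 12 layers, 12 attention heads, and `d_model` = 768.)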

In [6]:
from vit_prisma.models.base_vit import HookedViT
import torch

# Absolute tolerance used when comparing the hooked model's output to timm's.
TOLERANCE = 1e-5

# Model identifier, plus the shape of the random batch fed to both models.
model_name = "vit_base_patch32_224"
batch_size = 5
channels = 3
height = 224
width = 224
device = "cpu"

# Load `model_name` twice: once through HookedViT and once through timm, so the
# two outputs can be compared against each other.
hooked_model = HookedViT.from_pretrained(model_name)
hooked_model.to(device)
# Inference mode, so train-only layers (e.g. dropout) cannot perturb the
# comparison. Assumes HookedViT is an nn.Module, as the .to() call above implies.
hooked_model.eval()
timm_model = timm.create_model(model_name, pretrained=True)
timm_model.to(device)
timm_model.eval()

# Draw a reproducible random batch. fork_rng() restores the global RNG state on
# exit, so seeding here does not affect random draws made elsewhere.
with torch.random.fork_rng():
    torch.manual_seed(1)
    # Sampled on the default device and then moved, so the seeded values do not
    # depend on which `device` is configured.
    input_image = torch.rand((batch_size, channels, height, width)).to(device)



{'n_layers': 12, 'd_model': 768, 'd_head': 64, 'model_name': 'timm/vit_base_patch32_224.augreg_in21k_ft_in1k', 'n_heads': 12, 'd_mlp': 3072, 'activation_name': 'gelu', 'eps': 1e-06, 'original_architecture': 'vit_base_patch32_224', 'initializer_range': 0.02, 'n_channels': 3, 'patch_size': 32, 'image_size': 224, 'n_classes': 1000, 'n_params': 88224232, 'return_type': 'class_logits'}
Loaded pretrained model vit_base_patch32_224 into HookedTransformer


In [8]:
# Sanity check: the HookedViT model must reproduce the timm model's output on the
# same input, within TOLERANCE. Both forward passes are inference only, so run
# them under no_grad to avoid recording an autograd graph for either model.
with torch.no_grad():
    hooked_output = hooked_model(input_image)
    timm_output = timm_model(input_image)

assert torch.allclose(hooked_output, timm_output, atol=TOLERANCE), "Model output diverges!"

In [None]:
# Finetune the model so it's performing well on CIFAR-100



### Import CIFAR-100

The labels of CIFAR-100 have two levels of granularity: coarse-grained and fine-grained labels. 

We want to see the relationship between the coarse-grained and fine-grained labels inside the net. For example, perhaps the coarse-grained labels tend to be identified around Layer 5, while the fine-grained labels tend to be identified around Layer 6. Perhaps there is a semantic hierarchy reflected in a TinyCLIP circuit.

In [18]:
import os
import certifi

# Configure certificate lookup through environment variables:
# REQUESTS_CA_BUNDLE is set to the certifi CA bundle path, and SSL_CERT_DIR is
# set to the /etc/ssl/certs directory.
os.environ.update(
    REQUESTS_CA_BUNDLE=certifi.where(),
    SSL_CERT_DIR='/etc/ssl/certs',
)

In [20]:
from torchvision import datasets, transforms

import os

# Preprocessing applied to every CIFAR-100 image before it reaches the model.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize the images to [-1, 1]
    # Resize to 224 x 224
    transforms.Resize((224, 224)),
])

# Dataset location. Overridable through the CIFAR100_ROOT environment variable so
# the notebook is not tied to one user's home directory; the default is the
# original path, so existing setups behave exactly as before.
cifar100_root = os.environ.get(
    "CIFAR100_ROOT",
    '/home/mila/s/sonia.joseph/ViT-Planetarium/data/cifar100',
)

# Test split only (train=False); downloaded into cifar100_root if not present.
testset = datasets.CIFAR100(root=cifar100_root, train=False, download=True, transform=transform)

# One image per batch, visited in shuffled order.
dataloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=True)


Files already downloaded and verified


In [24]:
# Mapping from each of the 20 CIFAR-100 superclass (coarse) names to the list of
# its five fine-grained class names, in insertion order.
# NOTE(review): several names here are plural and/or space-separated (e.g.
# "orchids", "aquarium fish"). Confirm they match the strings in
# `testset.classes` before using dataset labels as lookup keys — presumably
# torchvision uses different spellings; verify.
superclasses = dict([
    ("aquatic mammals", ["beaver", "dolphin", "otter", "seal", "whale"]),
    ("fish", ["aquarium fish", "flatfish", "ray", "shark", "trout"]),
    ("flowers", ["orchids", "poppies", "roses", "sunflowers", "tulips"]),
    ("food containers", ["bottles", "bowls", "cans", "cups", "plates"]),
    ("fruit and vegetables", ["apples", "mushrooms", "oranges", "pears", "sweet peppers"]),
    ("household electrical devices", ["clock", "computer keyboard", "lamp", "telephone", "television"]),
    ("household furniture", ["bed", "chair", "couch", "table", "wardrobe"]),
    ("insects", ["bee", "beetle", "butterfly", "caterpillar", "cockroach"]),
    ("large carnivores", ["bear", "leopard", "lion", "tiger", "wolf"]),
    ("large man-made outdoor things", ["bridge", "castle", "house", "road", "skyscraper"]),
    ("large natural outdoor scenes", ["cloud", "forest", "mountain", "plain", "sea"]),
    ("large omnivores and herbivores", ["camel", "cattle", "chimpanzee", "elephant", "kangaroo"]),
    ("medium-sized mammals", ["fox", "porcupine", "possum", "raccoon", "skunk"]),
    ("non-insect invertebrates", ["crab", "lobster", "snail", "spider", "worm"]),
    ("people", ["baby", "boy", "girl", "man", "woman"]),
    ("reptiles", ["crocodile", "dinosaur", "lizard", "snake", "turtle"]),
    ("small mammals", ["hamster", "mouse", "rabbit", "shrew", "squirrel"]),
    ("trees", ["maple", "oak", "palm", "pine", "willow"]),
    ("vehicles 1", ["bicycle", "bus", "motorcycle", "pickup truck", "train"]),
    ("vehicles 2", ["lawn-mower", "rocket", "streetcar", "tank", "tractor"]),
])


In [25]:
def get_superclass(class_name, superclasses):
    """
    Look up which superclass a fine-grained class name belongs to.

    Parameters:
    - class_name: The fine-grained class name to look up.
    - superclasses: A mapping from superclass name to the collection of class
      names it contains.

    Returns:
    - The first superclass (in the mapping's iteration order) whose members
      include class_name, or the string "Class not found" if none does.
    """
    # Lazily scan the groups; next() stops at the first hit and otherwise
    # falls back to the sentinel string.
    containing_groups = (
        group_name
        for group_name, members in superclasses.items()
        if class_name in members
    )
    return next(containing_groups, "Class not found")

In [26]:
# Re-load the model with the weight-processing options below enabled.
# NOTE: this rebinds `hooked_model`, replacing the unprocessed model loaded
# earlier; re-running earlier cells afterwards will see this processed model.
hooked_model = HookedViT.from_pretrained(
    model_name,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)
# Match the handling of the first load: same device, and inference mode so
# train-only layers stay disabled during analysis.
hooked_model.to(device)
hooked_model.eval()


{'n_layers': 12, 'd_model': 768, 'd_head': 64, 'model_name': 'timm/vit_base_patch32_224.augreg_in21k_ft_in1k', 'n_heads': 12, 'd_mlp': 3072, 'activation_name': 'gelu', 'eps': 1e-06, 'original_architecture': 'vit_base_patch32_224', 'initializer_range': 0.02, 'n_channels': 3, 'patch_size': 32, 'image_size': 224, 'n_classes': 1000, 'n_params': 88224232, 'return_type': 'class_logits'}
Loaded pretrained model vit_base_patch32_224 into HookedTransformer


In [None]:

# TODO: return a table of predictions with the columns:
# Rank | Logit | Prob | Class