# Exploratory Analysis on the Semantic Space of a Vision Transformer

We want to see the semantic space of TinyCLIP using ablations.

In [None]:
import vit_prisma 
import timm

### Import model

We're just using the visual encoder of TinyCLIP. The visual encoder is a transformer with 10 layers, 12 attention heads, and 256 hidden representation dimensionality.

In [1]:
# from vit_prisma.models.base_vit import HookedViT
# import torch

# TOLERANCE = 1e-5

# model_name = "vit_base_patch32_224"
# batch_size = 5
# channels = 3
# height = 224
# width = 224
# device = "cpu"

# hooked_model = HookedViT.from_pretrained(model_name)
# hooked_model.to(device)
# timm_model = timm.create_model(model_name, pretrained=True)
# timm_model.to(device)

# with torch.random.fork_rng():
#     torch.manual_seed(1)
#     input_image = torch.rand((batch_size, channels, height, width)).to(device)



In [2]:
# assert torch.allclose(hooked_model(input_image), timm_model(input_image), atol=TOLERANCE), "Model output diverges!"

### Import CIFAR-100

The labels of CIFAR-100 have two levels of granularity: coarse-grained and fine-grained labels. 

We want to see the relationship between the coarse-grained and fine-grained labels inside the net. For example, perhaps the coarse-grained labels tend to be identified around Layer 5, while the fine-grained labels tend to be identified around Label 6. Perhaps there is a semantic hierarchy reflected in a TinyCLIP circuit.

In [7]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch

from datasets import load_dataset

tiny_imagenet = load_dataset('Maysee/tiny-imagenet', split='train')
print(tiny_imagenet[0])

class HuggingFaceDataset(Dataset):
    def __init__(self, hf_dataset, transform=None):
        """
        Args:
            hf_dataset: The Hugging Face dataset object.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.hf_dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        # Fetch the image and label from the Hugging Face dataset
        item = self.hf_dataset[idx]
        image = item['image']
        label = item['label']

        # Apply transformations to the image if any
        if self.transform:
            image = self.transform(image)

        # Convert image to tensor if it's a PIL Image
        if isinstance(image, Image.Image):
            image = transforms.ToTensor()(image)

        return image, torch.tensor(label)


Downloading data:   0%|          | 0.00/146M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/14.6M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/100000 [00:00<?, ? examples/s]

Generating valid split:   0%|          | 0/10000 [00:00<?, ? examples/s]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=64x64 at 0x7FE378C47160>, 'label': 0}


In [24]:
tiny_imagenet.labels

AttributeError: 'Dataset' object has no attribute 'labels'

In [8]:
import wandb
# Finetune the model so it's performing well on CIFAR-100
from vit_prisma.models.base_vit import HookedViT

model_name = "vit_base_patch32_224"
hooked_model = HookedViT.from_pretrained(model_name)


{'n_layers': 12, 'd_model': 768, 'd_head': 64, 'model_name': 'timm/vit_base_patch32_224.augreg_in21k_ft_in1k', 'n_heads': 12, 'd_mlp': 3072, 'activation_name': 'gelu', 'eps': 1e-06, 'original_architecture': 'vit_base_patch32_224', 'initializer_range': 0.02, 'n_channels': 3, 'patch_size': 32, 'image_size': 224, 'n_classes': 1000, 'n_params': 88224232, 'return_type': 'class_logits'}
Loaded pretrained model vit_base_patch32_224 into HookedTransformer


In [14]:
dataloader = HuggingFaceDataset(tiny_imagenet)

In [29]:
names = ["n01443537", "n01629819", "n01641577", "n01644900", "n01698640", "n01742172", "n01768244", "n01770393", "n01774384", "n01774750", "n01784675", "n01882714", "n01910747", "n01917289", "n01944390", "n01950731", "n01983481", "n01984695", "n02002724", "n02056570", "n02058221", "n02074367", "n02094433", "n02099601", "n02099712", "n02106662", "n02113799", "n02123045", "n02123394", "n02124075", "n02125311", "n02129165", "n02132136", "n02165456", "n02226429", "n02231487", "n02233338", "n02236044", "n02268443", "n02279972", "n02281406", "n02321529", "n02364673", "n02395406", "n02403003", "n02410509", "n02415577", "n02423022", "n02437312", "n02480495", "n02481823", "n02486410", "n02504458", "n02509815", "n02666347", "n02669723", "n02699494", "n02769748", "n02788148", "n02791270", "n02793495", "n02795169", "n02802426", "n02808440", "n02814533", "n02814860", "n02815834", "n02823428", "n02837789", "n02841315", "n02843684", "n02883205", "n02892201", "n02909870", "n02917067", "n02927161", "n02948072", "n02950826", "n02963159", "n02977058", "n02988304", "n03014705", "n03026506", "n03042490", "n03085013", "n03089624", "n03100240", "n03126707", "n03160309", "n03179701", "n03201208", "n03255030", "n03355925", "n03373237", "n03388043", "n03393912", "n03400231", "n03404251", "n03424325", "n03444034", "n03447447", "n03544143", "n03584254", "n03599486", "n03617480", "n03637318", "n03649909", "n03662601", "n03670208", "n03706229", "n03733131", "n03763968", "n03770439", "n03796401", "n03814639", "n03837869", "n03838899", "n03854065", "n03891332", "n03902125", "n03930313", "n03937543", "n03970156", "n03977966", "n03980874", "n03983396", "n03992509", "n04008634", "n04023962", "n04070727", "n04074963", "n04099969", "n04118538", "n04133789", "n04146614", "n04149813", "n04179913", "n04251144", "n04254777", "n04259630", "n04265275", "n04275548", "n04285008", "n04311004", "n04328186", "n04356056", "n04366367", "n04371430", "n04376876", "n04398044", "n04399382", "n04417672", "n04456115", "n04465666", "n04486054", "n04487081", "n04501370", "n04507155", "n04532106", "n04532670", "n04540053", "n04560804", "n04562935", "n04596742", "n04598010", "n06596364", "n07056680", "n07583066", "n07614500", "n07615774", "n07646821", "n07647870", "n07657664", "n07695742", "n07711569", "n07715103", "n07720875", "n07749582", "n07753592", "n07768694", "n07871810", "n07873807", "n07875152", "n07920052", "n07975909", "n08496334", "n08620881", "n08742578", "n09193705", "n09246464", "n09256479", "n09332890", "n09428293", "n12267677", "n12520864", "n13001041", "n13652335", "n13652994", "n13719102", "n14991210"]

In [31]:
# Get ImageNet IDs
!curl -O https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json


  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 35363  100 35363    0     0   145k      0 --:--:-- --:--:-- --:--:--  145k


In [5]:
import os
import certifi

# Set the REQUESTS_CA_BUNDLE environment variable to the certifi CA bundle path
os.environ['REQUESTS_CA_BUNDLE'] = certifi.where()
os.environ['SSL_CERT_DIR'] = '/etc/ssl/certs'

In [6]:
import json

# Load the class index file
with open('imagenet_class_index.json', 'r') as f:
    class_index = json.load(f)

# Example: Get the class name and WNID for a specific index
index = '0'  # The index is a string in the JSON keys
wnid, class_name = class_index[index]
print(f"WNID: {wnid}, Class Name: {class_name}")

WNID: n01440764, Class Name: tench


In [7]:
import nltk 
nltk.download('wordnet')

[nltk_data] Downloading package wordnet to
[nltk_data]     /home/mila/s/sonia.joseph/nltk_data...


True

In [67]:
from nltk.corpus import wordnet as wn

for i in range(len(names)):
    wnid = names[i]

    print(wnid)

    # Correct approach: Convert WNID to offset and use synset_from_pos_and_offset
    offset = int(wnid[1:])  # Extract the numerical part of the WNID and convert to integer
    pos = 'n'  # 'n' for noun, which is typical for WNIDs

    # Fetch the synset using the offset and part of speech
    synset = wn.synset_from_pos_and_offset(pos, offset)

    # Get the lemma names (human-readable names)
    eng_names = synset.lemma_names()
    print(eng_names)

n01443537
['goldfish', 'Carassius_auratus']
n01629819
['European_fire_salamander', 'Salamandra_salamandra']
n01641577
['bullfrog', 'Rana_catesbeiana']
n01644900
['tailed_frog', 'bell_toad', 'ribbed_toad', 'tailed_toad', 'Ascaphus_trui']
n01698640
['American_alligator', 'Alligator_mississipiensis']
n01742172
['boa_constrictor', 'Constrictor_constrictor']
n01768244
['trilobite']
n01770393
['scorpion']
n01774384
['black_widow', 'Latrodectus_mactans']
n01774750
['tarantula']
n01784675
['centipede']
n01882714
['koala', 'koala_bear', 'kangaroo_bear', 'native_bear', 'Phascolarctos_cinereus']
n01910747
['jellyfish']
n01917289
['brain_coral']
n01944390
['snail']
n01950731
['sea_slug', 'nudibranch']
n01983481
['American_lobster', 'Northern_lobster', 'Maine_lobster', 'Homarus_americanus']
n01984695
['spiny_lobster', 'langouste', 'rock_lobster', 'crawfish', 'crayfish', 'sea_crawfish']
n02002724
['black_stork', 'Ciconia_nigra']
n02056570
['king_penguin', 'Aptenodytes_patagonica']
n02058221
['albatr

In [74]:
tabby_cat = 'n02121808'

# Find the synset for "cat"
# cat_synset = wn.synset(tabby_cat)  # This is an illustrative example; use the correct synset

# Correct approach: Convert WNID to offset and use synset_from_pos_and_offset
offset = int(tabby_cat[1:])  # Extract the numerical part of the WNID and convert to integer
pos = 'n'  # 'n' for noun, which is typical for WNIDs

# Fetch the synset using the offset and part of speech
cat_synset = wn.synset_from_pos_and_offset(pos, offset)

# Get the lemma names (human-readable names)
# eng_names = synset.lemma_names()
# print(eng_names)
    
# Find all hyponyms of the cat synset
cat_hyponyms = cat_synset.hyponyms()

print(cat_hyponyms)

# Print out some of the hyponyms
for hyponym in cat_hyponyms[:10]:  #
    print(hyponym)

[Synset('abyssinian.n.01'), Synset('alley_cat.n.01'), Synset('angora.n.04'), Synset('burmese_cat.n.01'), Synset('egyptian_cat.n.01'), Synset('kitty.n.04'), Synset('maltese.n.03'), Synset('manx.n.02'), Synset('mouser.n.01'), Synset('persian_cat.n.01'), Synset('siamese_cat.n.01'), Synset('tabby.n.01'), Synset('tabby.n.02'), Synset('tiger_cat.n.02'), Synset('tom.n.02'), Synset('tortoiseshell.n.03')]
Synset('abyssinian.n.01')
Synset('alley_cat.n.01')
Synset('angora.n.04')
Synset('burmese_cat.n.01')
Synset('egyptian_cat.n.01')
Synset('kitty.n.04')
Synset('maltese.n.03')
Synset('manx.n.02')
Synset('mouser.n.01')
Synset('persian_cat.n.01')


In [88]:
from nltk.corpus import wordnet as wn

# Define a function to recursively accumulate all descendant hyponyms of a given synset
def accumulate_descendant_hyponyms(synset, accumulated_hyponyms=None):
    if accumulated_hyponyms is None:
        accumulated_hyponyms = set()
    for hyponym in synset.hyponyms():
        accumulated_hyponyms.add(hyponym)
        accumulate_descendant_hyponyms(hyponym, accumulated_hyponyms)
    return accumulated_hyponyms

# Fetch the Felidae synset
felidae_synset = wn.synset('felidae.n.01')

# Use the function to accumulate all descendant hyponyms of Felidae
all_felidae_hyponyms = accumulate_descendant_hyponyms(felidae_synset)

# Extract WNIDs for these hyponyms
felidae_wnids = ["n" + str(hyponym.offset()).zfill(8) for hyponym in all_felidae_hyponyms]

print(f"Number of Felidae WNIDs: {len(felidae_wnids)}")
print("Some example WNIDs:", felidae_wnids[:10])


Number of Felidae WNIDs: 0
Some example WNIDs: []


In [25]:
def get_superclass(class_name, superclasses):
    """
    Return the superclass of a given class name from CIFAR-100.

    Parameters:
    - class_name: The name of the class for which to find the superclass.
    - superclasses: A dictionary where keys are superclasses and values are lists of classes.

    Returns:
    - The name of the superclass if found, otherwise "Class not found".
    """
    for superclass, classes in superclasses.items():
        if class_name in classes:
            return superclass
    return "Class not found"

len(names)

# Query hierarchy

In [38]:
from nltk.corpus import wordnet as wn

def find_labels_at_level(synset, level, current_level=0):
    """
    Recursively find all labels (synsets) at a given level of the hierarchy.

    Args:
    - synset: The current synset from which to explore.
    - level: The target level depth to find synsets.
    - current_level: The current level depth in the recursion.

    Returns:
    - A list of synsets found at the specified level.
    """
    # Base case: If the current level matches the target level, return the current synset
    if current_level == level:
        return [synset]
    
    # Recursive case: Explore child synsets (hyponyms) if not at the target level yet
    elif current_level < level:
        labels = []
        for hyponym in synset.hyponyms():
            labels.extend(find_labels_at_level(hyponym, level, current_level + 1))
        return labels
    
    # If current level is somehow greater than target (should not happen in correct use), return empty list
    else:
        return []

# Example usage
root_synset = wn.synset('entity.n.01')  # Start from the root noun synset
level = 3  # Specify the target level depth
labels_at_level = find_labels_at_level(root_synset, level)

# Print out some results
for synset in labels_at_level[:10]:  # Limiting output for brevity
    print(f"{synset.name()}: {synset.definition()}")


ballast.n.03: an attribute that tends to give stability in character and morals; something that steadies the mind or feelings
character.n.09: (genetics) an attribute (structural or functional) that is determined by a gene or group of genes
cheerfulness.n.01: the quality of being cheerful and dispelling gloom
common_denominator.n.02: an attribute that is common to all members of a category
depth.n.06: the attribute or quality of being deep, strong, or intense
eidos.n.01: (anthropology) the distinctive expression of the cognitive or intellectual character of a culture or a social group
ethos.n.01: (anthropology) the distinctive spirit of a culture or an era
human_nature.n.01: the shared psychological attributes of humankind that are assumed to be shared by all human beings
inheritance.n.04: any attribute or immaterial possession that is inherited from ancestors
personality.n.01: the complex of all the attributes--behavioral, temperamental, emotional and mental--that characterize a unique

In [44]:
import pandas as pd

all_list=[["-" for j in range(7)] for i in range(200)]
for i in range(200):
    wnid = names[i]

    # Correct approach: Convert WNID to offset and use synset_from_pos_and_offset
    offset = int(wnid[1:])  # Extract the numerical part of the WNID and convert to integer
    pos = 'n'  # 'n' for noun, which is typical for WNIDs
    synset=wn.synset_from_pos_and_offset(pos, offset)
    hyper_list=[]
    while synset.hypernyms():
        synset = synset.hypernyms()[0]
        hyper_list.append(synset.name())
    hyper_list.insert(0,'null')
    hyper_list.insert(0,'null')
    all_list[i][:]=hyper_list[:-7:-1]

df=pd.DataFrame(all_list)

In [64]:
df.iloc[:,5].value_counts()


5
instrumentality.n.03      68
organism.n.01             56
structure.n.01            18
commodity.n.01            16
covering.n.02              6
nutriment.n.01             5
plant_part.n.01            4
null                       3
produce.n.01               3
linear_unit.n.01           2
meat.n.01                  2
social_gathering.n.01      1
mountain.n.01              1
mass_unit.n.01             1
geographic_point.n.01      1
geographical_area.n.01     1
ridge.n.01                 1
helping.n.01               1
beverage.n.01              1
baked_goods.n.01           1
block.n.01                 1
foodstuff.n.02             1
dance_music.n.02           1
award.n.02                 1
plaything.n.01             1
sheet.n.06                 1
coloring_material.n.01     1
Name: count, dtype: int64