# Compare to timm-pretrained model, then turn into series of unit tests

In [1]:
from vit_prisma.configs import HookedViTConfig
from vit_prisma.models.base_vit import HookedViT

import timm
import torch

In [2]:
prisma_model = HookedViT.from_pretrained("vit_base_patch32_224", 
                                         center_writing_weights=False, 
                                         fold_ln=False, 
                                         fold_value_biases=False,
                                         use_attn_scale=False,
                                         use_attn_in=True,
)
# timm_model = timm.create_model('vit_base_patch16_224', pretrained=True)

{'n_layers': 12, 'd_model': 768, 'd_head': 64, 'model_name': 'timm/vit_base_patch32_224.augreg_in21k_ft_in1k', 'n_heads': 12, 'd_mlp': 3072, 'activation_name': 'gelu', 'eps': 1e-06, 'original_architecture': 'vit_base_patch32_224', 'initializer_range': 0.02, 'n_channels': 3, 'patch_size': 32, 'image_size': 224, 'n_classes': 1000, 'n_params': 88224232, 'return_type': 'class_logits'}
Loaded pretrained model vit_base_patch32_224 into HookedTransformer


In [13]:
timm_model = timm.create_model('vit_base_patch32_224', pretrained=True)

In [4]:
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize the images to [-1, 1]
    # Resize to 224 x 224
    transforms.Resize((224, 224))
])

testset = datasets.CIFAR10(root='/home/mila/s/sonia.joseph/ViT-Planetarium/data/cifar10', train=False, download=False, transform=transform)

testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False)


In [5]:
image, label = next(iter(testloader))
output, cache = prisma_model.run_with_cache(image)

# for key in cache.keys():
#     print(key, cache[key].shape)

## Patch Embeddings

In [6]:
activations = []
def hook_fn(module, input, output):
    activations.append(output)

hook_handle = timm_model.patch_embed.register_forward_hook(hook_fn)
timm_output = timm_model(image)
hook_handle.remove()

activations[0].shape

assert torch.allclose(activations[0], cache['embed'][0])
assert torch.all(activations[0] == cache['embed'][0])

## Position Embeddings

In [7]:
for k in timm_model.state_dict().keys():
    print(k)

cls_token
pos_embed
patch_embed.proj.weight
patch_embed.proj.bias
blocks.0.norm1.weight
blocks.0.norm1.bias
blocks.0.attn.qkv.weight
blocks.0.attn.qkv.bias
blocks.0.attn.proj.weight
blocks.0.attn.proj.bias
blocks.0.norm2.weight
blocks.0.norm2.bias
blocks.0.mlp.fc1.weight
blocks.0.mlp.fc1.bias
blocks.0.mlp.fc2.weight
blocks.0.mlp.fc2.bias
blocks.1.norm1.weight
blocks.1.norm1.bias
blocks.1.attn.qkv.weight
blocks.1.attn.qkv.bias
blocks.1.attn.proj.weight
blocks.1.attn.proj.bias
blocks.1.norm2.weight
blocks.1.norm2.bias
blocks.1.mlp.fc1.weight
blocks.1.mlp.fc1.bias
blocks.1.mlp.fc2.weight
blocks.1.mlp.fc2.bias
blocks.2.norm1.weight
blocks.2.norm1.bias
blocks.2.attn.qkv.weight
blocks.2.attn.qkv.bias
blocks.2.attn.proj.weight
blocks.2.attn.proj.bias
blocks.2.norm2.weight
blocks.2.norm2.bias
blocks.2.mlp.fc1.weight
blocks.2.mlp.fc1.bias
blocks.2.mlp.fc2.weight
blocks.2.mlp.fc2.bias
blocks.3.norm1.weight
blocks.3.norm1.bias
blocks.3.attn.qkv.weight
blocks.3.attn.qkv.bias
blocks.3.attn.proj.wei

In [8]:
prisma_model.cfg

HookedViTConfig(n_layers=12, d_model=768, d_head=64, d_mlp=3072, model_name='timm/vit_base_patch32_224.augreg_in21k_ft_in1k', n_heads=12, activation_name='gelu', d_vocab=-1, eps=1e-06, use_attn_result=False, use_attn_scale=True, use_split_qkv_input=False, use_hook_mlp_in=False, use_attn_in=False, use_local_attn=False, original_architecture='vit_base_patch32_224', from_checkpoint=False, checkpoint_index=None, checkpoint_label_type=None, checkpoint_value=None, tokenizer_name=None, window_size=None, attn_types=None, init_mode='gpt2', normalization_type='LN', device='cpu', n_devices=1, attention_dir='bidirectional', attn_only=False, seed=None, initializer_range=0.02, init_weights=True, scale_attn_by_inverse_layer_idx=False, positional_embedding_type='standard', final_rms=False, d_vocab_out=-1, parallel_attn_mlp=False, rotary_dim=None, n_params=88224232, use_hook_tokens=False, gated_mlp=False, default_prepend_bos=True, dtype=torch.float32, tokenizer_prepends_bos=None, n_key_value_heads=None

In [9]:
timm_model.state_dict()['pos_embed']
cache['pos_embed']

tensor([[[-4.8371e-02,  1.8895e-01, -2.4998e-02,  ..., -1.1429e-02,
          -2.6565e-02, -1.1025e-02],
         [-5.2308e-02,  1.4553e+00,  2.3567e-03,  ...,  1.2978e-01,
           1.3742e-02, -7.8761e-03],
         [-4.4628e-01,  2.6859e+00, -5.5638e-02,  ...,  8.2297e-02,
           9.2797e-02,  4.3837e-04],
         ...,
         [ 5.9315e-01,  4.6065e-01, -4.4948e-02,  ..., -3.0244e-02,
          -4.2538e-02,  3.7644e-02],
         [ 5.5119e-01,  1.6894e+00, -5.4843e-02,  ..., -6.5307e-02,
           5.1351e-02,  1.0790e-02],
         [ 4.2329e-01,  1.1742e-03, -9.6804e-03,  ..., -9.4712e-02,
           1.9008e-02,  3.3273e-02]]])

In [10]:
# First layer
activations = []
def hook_fn(module, input, output):
    activations.append(output)

hook_handle = timm_model.pos_drop.register_forward_hook(hook_fn)
timm_output = timm_model(image)
hook_handle.remove()


assert torch.allclose(activations[0], cache['blocks.0.hook_resid_pre'], atol=1e-2), "Activations differ more than the allowed tolerance"

## Layers

### LayerNorm 1

In [14]:

import einops 

activations = []
def hook_fn(module, input, output):
    activations.append(output)

hook_handle = timm_model.blocks[0].norm1.register_forward_hook(hook_fn)
timm_output = timm_model(image)
hook_handle.remove()

print(activations[0].shape)
print(cache['blocks.0.ln1.hook_normalized'].shape)

# Assert equal to the first layer
assert torch.allclose(activations[0], cache['blocks.0.ln1.hook_normalized'][0], atol=1e-6), "Activations differ more than the allowed tolerance"

torch.Size([1, 50, 768])
torch.Size([1, 50, 768])


### Attention

**Weights**

In [None]:
# # let's compare qkv weights
# QKV = timm_model.blocks[0].attn.qkv.weight
# W_Q, W_K, W_V = torch.tensor_split(QKV, 3, dim=0)
# t_Q = einops.rearrange(W_Q, "(i h) m->h m i", h=12)
# t_K = einops.rearrange(W_K, "(i h) m->h m i", h=12)
# t_V = einops.rearrange(W_V, "(i h) m->h m i", h=12)

# p_Q = prisma_model.blocks[0].attn.W_Q
# p_K = prisma_model.blocks[0].attn.W_K
# p_V = prisma_model.blocks[0].attn.W_V

# assert torch.allclose(p_Q, t_Q, atol=1e-6), "Activations differ more than the allowed tolerance"
# assert torch.allclose(p_K, t_K, atol=1e-6), "Activations differ more than the allowed tolerance"
# assert torch.allclose(p_V, t_V, atol=1e-6), "Activations differ more than the allowed tolerance"

# # qkv bias
# bias_QKV = timm_model.blocks[0].attn.qkv.bias

# b_Q, b_K, b_V = torch.tensor_split(bias_QKV, 3, dim=0)

# bt_Q = einops.rearrange(b_Q, "(i h) -> h i", h=12)
# bt_K = einops.rearrange(b_K, "(i h) -> h i", h=12)
# bt_V = einops.rearrange(b_V, "(i h) -> h i", h=12)

# bp_Q = prisma_model.blocks[0].attn.b_Q
# bp_K = prisma_model.blocks[0].attn.b_K
# bp_V = prisma_model.blocks[0].attn.b_V

# assert torch.allclose(bp_Q, bt_Q, atol=1e-6), "Activations differ more than the allowed tolerance"
# assert torch.allclose(bp_K, bt_K, atol=1e-6), "Activations differ more than the allowed tolerance"
# assert torch.allclose(bp_V, bt_V, atol=1e-6), "Activations differ more than the allowed tolerance"

In [15]:
def print_matrix_corner(matrix, rows=1, cols=1):
    """
    Prints the top-left corner of a matrix (tensor) up to the specified number of rows and columns.

    Parameters:
    - matrix (torch.Tensor): The matrix (tensor) from which to print the corner.
    - rows (int): The number of rows to include in the printed corner. Default is 5.
    - cols (int): The number of columns to include in the printed corner. Default is 5.
    """
    # Ensure the matrix is a PyTorch tensor
    if not isinstance(matrix, torch.Tensor):
        print("The input is not a PyTorch tensor.")
        return

    # Get the size of the matrix
    num_rows, num_cols = matrix.shape[:2]

    # Adjust rows and cols if the matrix is smaller than specified dimensions
    rows_to_print = min(rows, num_rows)
    cols_to_print = min(cols, num_cols)

    # Slice the matrix to get the top-left corner
    corner = matrix[:rows_to_print, :cols_to_print]

    print(f"Top-left corner ({rows_to_print}x{cols_to_print}):\n{corner}")


**QKV matrix**

In [17]:
# First layer
activations = []
def hook_fn(module, input, output):
    activations.append(output)

hook_handle = timm_model.blocks[0].attn.qkv.register_forward_hook(hook_fn)

timm_output = timm_model(image)
hook_handle.remove()

print("timm output", activations[0].shape)
# qkv = activations[0].reshape(-1, 197, 3, 12, 64).permute(2, 0, 3, 1, 4)
# qkv = activations[0].reshape(-1, 197, 3, 12, 64).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)

print("timm shape", qkv.shape)
# print("prisma shape", cache['blocks.0.attn.hook_qkv'].shape)

print("prisma q shape", cache['blocks.0.attn.hook_q'].shape)

# assert torch.allclose(qkv, cache['blocks.0.attn.hook_qkv'], atol=1e-6), "Activations differ more than the allowed tolerance"
assert torch.allclose(q, cache['blocks.0.attn.hook_q'], atol=1e-6), "Activations differ more than the allowed tolerance"
assert torch.allclose(k, cache['blocks.0.attn.hook_k'], atol=1e-6), "Activations differ more than the allowed tolerance"
assert torch.allclose(v, cache['blocks.0.attn.hook_v'], atol=1e-6), "Activations differ more than the allowed tolerance"

timm output torch.Size([1, 50, 2304])


NameError: name 'qkv' is not defined

**Attention Scores**

In [None]:
scaled_q = q * 64 ** -0.5
timm_attn_scores = scaled_q @ k.transpose(-2,-1)

print("timm attn scores", timm_attn_scores.shape)
print("prisma attn scores", cache['blocks.0.attn.hook_attn_scores'].shape)

assert torch.allclose(timm_attn_scores, cache['blocks.0.attn.hook_attn_scores'], atol=1e-4), "Activations differ more than the allowed tolerance"

**Attention pattern**

In [None]:
timm_attn_pattern = timm_attn_scores.softmax(dim=-1) 

assert torch.allclose(timm_attn_pattern, cache['blocks.0.attn.hook_pattern'], atol=1e-4), "Activations differ more than the allowed tolerance"

**Attention Output**

In [None]:
# First layer
activations = []
def hook_fn(module, input, output):
    activations.append(output)

hook_handle = timm_model.blocks[0].attn.proj.register_forward_hook(hook_fn)
timm_output = timm_model(image)
hook_handle.remove()

print(activations[0].shape)
# print(cache['blocks.0.attn.hook_attn_out'].shape)


assert torch.allclose(cache['blocks.0.attn.hook_result'], activations[0], atol=1e-3), "Activations differ more than the allowed tolerance"


## MLP

In [None]:

#currently only vit_base_patch16_224 supported (config loading issue)
TOLERANCE = 1e-5

model_name = "vit_base_patch16_224"
batch_size = 5
channels = 3
height = 224
width = 224
device = "cuda"

hooked_model = HookedViT.from_pretrained(model_name)
hooked_model.to(device)
timm_model = timm.create_model(model_name, pretrained=True)
timm_model.to(device)

with torch.random.fork_rng():
    torch.manual_seed(1)
    input_image = torch.rand((batch_size, channels, height, width)).to(device)

assert torch.allclose(hooked_model(input_image), timm_model(input_image), atol=TOLERANCE), "Model output diverges!"


In [None]:
hooked_model(input_image).shape

In [None]:
timm_model(input_image).shape