# Compare to timm-pretrained model, then turn into series of unit tests

In [1]:
from vit_prisma.configs import HookedViTConfig
from vit_prisma.models.base_vit import HookedViT

import timm
import torch

In [2]:
prisma_model = HookedViT.from_pretrained("vit_base_patch16_224", 
                                         center_writing_weights=False, 
                                         fold_ln=False, 
                                         fold_value_biases=False,
                                         use_attn_scale=False,
                                         use_split_qkv_input=True,
)
timm_model = timm.create_model('vit_base_patch16_224', pretrained=True)


Loaded pretrained model vit_base_patch16_224 into HookedTransformer


In [3]:
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize the images to [-1, 1]
    # Resize to 224 x 224
    transforms.Resize((224, 224))
])

testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False)


Files already downloaded and verified


In [4]:
image, label = next(iter(testloader))
output, cache = prisma_model.run_with_cache(image)

for key in cache.keys():
    print(key, cache[key].shape)



hook_embed torch.Size([1, 196, 768])
hook_pos_embed torch.Size([1, 197, 768])
blocks.0.hook_resid_pre torch.Size([1, 197, 768])
blocks.0.ln1.hook_scale torch.Size([1, 197, 1])
blocks.0.ln1.hook_normalized torch.Size([1, 197, 768])
blocks.0.attn.hook_q torch.Size([1, 197, 12, 64])
blocks.0.attn.hook_k torch.Size([1, 197, 12, 64])
blocks.0.attn.hook_v torch.Size([1, 197, 12, 64])
blocks.0.attn.hook_attn_scores torch.Size([1, 12, 197, 197])
blocks.0.attn.hook_pattern torch.Size([1, 12, 197, 197])
blocks.0.attn.hook_z torch.Size([1, 197, 12, 64])
blocks.0.hook_attn_out torch.Size([1, 197, 768])
blocks.0.hook_resid_mid torch.Size([1, 197, 768])
blocks.0.ln2.hook_scale torch.Size([1, 197, 1])
blocks.0.ln2.hook_normalized torch.Size([1, 197, 768])
blocks.0.mlp.hook_pre torch.Size([1, 197, 3072])
blocks.0.mlp.hook_post torch.Size([1, 197, 3072])
blocks.0.hook_mlp_out torch.Size([1, 197, 768])
blocks.0.hook_resid_post torch.Size([1, 197, 768])
blocks.1.hook_resid_pre torch.Size([1, 197, 768])
b

## Patch Embeddings

In [5]:
activations = []
def hook_fn(module, input, output):
    activations.append(output)

hook_handle = timm_model.patch_embed.register_forward_hook(hook_fn)
timm_output = timm_model(image)
hook_handle.remove()

activations[0].shape

assert torch.allclose(activations[0], cache['embed'][0])
assert torch.all(activations[0] == cache['embed'][0])

## Position Embeddings

In [6]:
# First layer
activations = []
def hook_fn(module, input, output):
    activations.append(output)

hook_handle = timm_model.pos_drop.register_forward_hook(hook_fn)
timm_output = timm_model(image)
hook_handle.remove()

assert torch.allclose(activations[0], cache['blocks.0.hook_resid_pre'], atol=1e-6), "Activations differ more than the allowed tolerance"

## Layers

### LayerNorm 1

In [7]:
# First layer

import einops 


activations = []
def hook_fn(module, input, output):
    activations.append(output)

hook_handle = timm_model.blocks[0].norm1.register_forward_hook(hook_fn)
timm_output = timm_model(image)
hook_handle.remove()


# Assert equal to the first layer
assert torch.allclose(activations[0], cache['blocks.0.ln1.hook_normalized'][0], atol=1e-6), "Activations differ more than the allowed tolerance"

### Attention

In [8]:
# let's compare qkv weights
QKV = timm_model.blocks[0].attn.qkv.weight
W_Q, W_K, W_V = torch.tensor_split(QKV, 3, dim=0)
t_Q = einops.rearrange(W_Q, "(i h) m->h m i", h=12)
t_K = einops.rearrange(W_K, "(i h) m->h m i", h=12)
t_V = einops.rearrange(W_V, "(i h) m->h m i", h=12)

p_Q = prisma_model.blocks[0].attn.W_Q
p_K = prisma_model.blocks[0].attn.W_K
p_V = prisma_model.blocks[0].attn.W_V

assert torch.allclose(p_Q, t_Q, atol=1e-6), "Activations differ more than the allowed tolerance"
assert torch.allclose(p_K, t_K, atol=1e-6), "Activations differ more than the allowed tolerance"
assert torch.allclose(p_V, t_V, atol=1e-6), "Activations differ more than the allowed tolerance"

# qkv bias
bias_QKV = timm_model.blocks[0].attn.qkv.bias

b_Q, b_K, b_V = torch.tensor_split(bias_QKV, 3, dim=0)

bt_Q = einops.rearrange(b_Q, "(i h) -> h i", h=12)
bt_K = einops.rearrange(b_K, "(i h) -> h i", h=12)
bt_V = einops.rearrange(b_V, "(i h) -> h i", h=12)

bp_Q = prisma_model.blocks[0].attn.b_Q
bp_K = prisma_model.blocks[0].attn.b_K
bp_V = prisma_model.blocks[0].attn.b_V

assert torch.allclose(bp_Q, bt_Q, atol=1e-6), "Activations differ more than the allowed tolerance"
assert torch.allclose(bp_K, bt_K, atol=1e-6), "Activations differ more than the allowed tolerance"
assert torch.allclose(bp_V, bt_V, atol=1e-6), "Activations differ more than the allowed tolerance"

In [22]:
# First layer
activations = []
def hook_fn(module, input, output):
    activations.append(output)

hook_handle = timm_model.blocks[0].attn.q_norm.register_forward_hook(hook_fn)
timm_output = timm_model(image)
hook_handle.remove()

print(activations[0][0].shape)


assert torch.allclose(activations[0][0], cache['blocks.0.attn.hook_q'][0].permute(1,0,2), atol=1e-2), "Activations differ more than the allowed tolerance"

torch.Size([12, 197, 64])


AssertionError: Activations differ more than the allowed tolerance

In [None]:
 cache['blocks.0.attn.hook_k'][0].shape
print(activations[0].shape)
print(cache['blocks.0.attn.hook_k'][0].permute(1,0,2).shape)

torch.Size([1, 12, 197, 64])
torch.Size([12, 197, 64])


In [None]:
activations[0].dtype

torch.float32

In [19]:
# First layer
activations = []
def hook_fn(module, input, output):
    activations.append(output)

hook_handle = timm_model.blocks[0].attn.qkv.register_forward_hook(hook_fn)
timm_output = timm_model(image)

print(cache['blocks.0.attn.hook_k'][0].permute(1,0,2).shape)

q, k, v = torch.split(activations[0], 768, dim=-1)

hook_handle.remove()

assert torch.allclose(k, cache['blocks.0.attn.hook_k'][0].permute(1,0,2), atol=1e-6), "Activations differ more than the allowed tolerance"

torch.Size([12, 197, 64])


RuntimeError: The size of tensor a (768) must match the size of tensor b (64) at non-singleton dimension 2

In [20]:
# First layer
activations = []
def hook_fn(module, input, output):
    activations.append(output)

hook_handle = timm_model.blocks[0].attn.register_forward_hook(hook_fn)
timm_output = timm_model(image)
hook_handle.remove()

assert torch.allclose(activations[0], cache['blocks.0.hook_attn_out'][0], atol=1e-6), "Activations differ more than the allowed tolerance"

AssertionError: Activations differ more than the allowed tolerance

In [None]:
prisma_model.cfg.use_split_qkv_input

False