# Compare to the timm-pretrained model, then turn this into a series of unit tests

In [2]:
from vit_prisma.configs import HookedViTConfig
from vit_prisma.models.base_vit import HookedViT

import timm
import torch

In [None]:
# Load the hooked Prisma ViT from the same checkpoint name as the timm reference.
# center_writing_weights, fold_ln and fold_value_biases are all turned off here
# (presumably so the loaded weights stay directly comparable to timm's — TODO confirm).
prisma_model = HookedViT.from_pretrained("vit_base_patch16_224", 
                                         center_writing_weights=False, 
                                         fold_ln=False, 
                                         fold_value_biases=False,
                                         use_attn_scale=False,
                                         use_attn_in=True,
)
# Pretrained timm model that the Prisma activations are checked against below.
timm_model = timm.create_model('vit_base_patch16_224', pretrained=True)

In [3]:
# Create the pretrained timm 'vit_base_patch16_224' reference model.
# NOTE(review): the previous cell already assigns `timm_model` with this exact call;
# one of the two assignments looks redundant.
timm_model = timm.create_model('vit_base_patch16_224', pretrained=True)

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

In [4]:
# Echo the timm model so the notebook displays its module tree.
timm_model

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity(

In [None]:
from torchvision import datasets, transforms

# Preprocessing applied to each CIFAR-10 image, in the order listed.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize the images to [-1, 1]
    # Resize to 224 x 224
    transforms.Resize((224, 224))
])

# CIFAR-10 test split (train=False); download=True lets torchvision fetch it into ./data.
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# One image per batch, without shuffling, so the first batch is the same on every run.
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False)


In [None]:
# Take the first test batch and run it through the Prisma model, keeping the
# activation cache that the later cells compare against timm.
first_batch = next(iter(testloader))
image, label = first_batch
output, cache = prisma_model.run_with_cache(image)

# List every cached entry next to the shape of the value stored under it.
for hook_name in cache.keys():
    cached_value = cache[hook_name]
    print(hook_name, cached_value.shape)

## Patch Embeddings

In [None]:
# Record the output of timm's patch-embedding module during one forward pass
# and compare it with the embedding stored in the Prisma cache.
activations = []
def hook_fn(_module, _inputs, output):
    """Forward hook: stash the hooked module's output for later comparison."""
    activations.append(output)

hook_handle = timm_model.patch_embed.register_forward_hook(hook_fn)
try:
    timm_output = timm_model(image)
finally:
    # Detach the hook even if the forward pass raises; otherwise it would stay
    # registered on the shared model and keep appending in later cells.
    hook_handle.remove()

activations[0].shape

# cache['embed'][0] takes the first entry along the leading axis of the cached
# embedding (presumably the single batch item — TODO confirm).
assert torch.allclose(activations[0], cache['embed'][0])
# Stricter than allclose: every element must be exactly equal.
assert torch.all(activations[0] == cache['embed'][0])

## Position Embeddings

In [None]:
# First layer
activations = []
def hook_fn(module, input, output):
    activations.append(output)

hook_handle = timm_model.pos_drop.register_forward_hook(hook_fn)
timm_output = timm_model(image)
hook_handle.remove()

assert torch.allclose(activations[0], cache['blocks.0.hook_resid_pre'], atol=1e-6), "Activations differ more than the allowed tolerance"

## Layers

### LayerNorm 1

In [None]:

import einops 

# Record the output of the first block's `norm1` in timm and compare it with
# Prisma's cached `blocks.0.ln1.hook_normalized`.
activations = []
def hook_fn(_module, _inputs, output):
    """Forward hook: stash the hooked module's output for later comparison."""
    activations.append(output)

hook_handle = timm_model.blocks[0].norm1.register_forward_hook(hook_fn)
try:
    timm_output = timm_model(image)
finally:
    # Detach the hook even if the forward pass raises; otherwise it would stay
    # registered on the shared model and keep appending in later cells.
    hook_handle.remove()

print(activations[0].shape)
print(cache['blocks.0.ln1.hook_normalized'].shape)

# Assert equal to the first layer
# The cached value is indexed with [0] (first entry along its leading axis) before comparing.
assert torch.allclose(activations[0], cache['blocks.0.ln1.hook_normalized'][0], atol=1e-6), "Activations differ more than the allowed tolerance"

### Attention

**Weights**

In [None]:
# # let's compare qkv weights
# QKV = timm_model.blocks[0].attn.qkv.weight
# W_Q, W_K, W_V = torch.tensor_split(QKV, 3, dim=0)
# t_Q = einops.rearrange(W_Q, "(i h) m->h m i", h=12)
# t_K = einops.rearrange(W_K, "(i h) m->h m i", h=12)
# t_V = einops.rearrange(W_V, "(i h) m->h m i", h=12)

# p_Q = prisma_model.blocks[0].attn.W_Q
# p_K = prisma_model.blocks[0].attn.W_K
# p_V = prisma_model.blocks[0].attn.W_V

# assert torch.allclose(p_Q, t_Q, atol=1e-6), "Activations differ more than the allowed tolerance"
# assert torch.allclose(p_K, t_K, atol=1e-6), "Activations differ more than the allowed tolerance"
# assert torch.allclose(p_V, t_V, atol=1e-6), "Activations differ more than the allowed tolerance"

# # qkv bias
# bias_QKV = timm_model.blocks[0].attn.qkv.bias

# b_Q, b_K, b_V = torch.tensor_split(bias_QKV, 3, dim=0)

# bt_Q = einops.rearrange(b_Q, "(i h) -> h i", h=12)
# bt_K = einops.rearrange(b_K, "(i h) -> h i", h=12)
# bt_V = einops.rearrange(b_V, "(i h) -> h i", h=12)

# bp_Q = prisma_model.blocks[0].attn.b_Q
# bp_K = prisma_model.blocks[0].attn.b_K
# bp_V = prisma_model.blocks[0].attn.b_V

# assert torch.allclose(bp_Q, bt_Q, atol=1e-6), "Activations differ more than the allowed tolerance"
# assert torch.allclose(bp_K, bt_K, atol=1e-6), "Activations differ more than the allowed tolerance"
# assert torch.allclose(bp_V, bt_V, atol=1e-6), "Activations differ more than the allowed tolerance"

In [None]:
def print_matrix_corner(matrix, rows=1, cols=1):
    """
    Prints the top-left corner of a tensor, limited to the given number of rows and columns.

    Only the first two dimensions are sliced; any trailing dimensions are printed in full.
    A tensor with fewer than two dimensions is viewed as a single row, so scalars and
    vectors can be inspected instead of failing while the shape is unpacked.

    Parameters:
    - matrix (torch.Tensor): The tensor from which to print the corner.
    - rows (int): The maximum number of rows to include in the printed corner. Default is 1.
    - cols (int): The maximum number of columns to include in the printed corner. Default is 1.

    Returns:
    - None. The corner is written to stdout; for a non-tensor input only a notice is printed.
    """
    # Ensure the matrix is a PyTorch tensor
    if not isinstance(matrix, torch.Tensor):
        print("The input is not a PyTorch tensor.")
        return

    # A scalar or vector has no (rows, cols) pair to unpack; treat it as a 1 x N matrix.
    if matrix.dim() < 2:
        matrix = matrix.reshape(1, -1)

    # Get the size of the leading two dimensions
    num_rows, num_cols = matrix.shape[:2]

    # Never ask for more rows/columns than the tensor actually has
    rows_to_print = min(rows, num_rows)
    cols_to_print = min(cols, num_cols)

    # Slice the tensor to get the top-left corner
    corner = matrix[:rows_to_print, :cols_to_print]

    print(f"Top-left corner ({rows_to_print}x{cols_to_print}):\n{corner}")


**QKV matrix**

In [None]:
# First layer
activations = []
def hook_fn(module, input, output):
    activations.append(output)

hook_handle = timm_model.blocks[0].attn.qkv.register_forward_hook(hook_fn)

timm_output = timm_model(image)
hook_handle.remove()

print("timm output", activations[0].shape)
# qkv = activations[0].reshape(-1, 197, 3, 12, 64).permute(2, 0, 3, 1, 4)
qkv = activations[0].reshape(-1, 197, 3, 12, 64).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)

print("timm shape", qkv.shape)
# print("prisma shape", cache['blocks.0.attn.hook_qkv'].shape)

print("prisma q shape", cache['blocks.0.attn.hook_q'].shape)

# assert torch.allclose(qkv, cache['blocks.0.attn.hook_qkv'], atol=1e-6), "Activations differ more than the allowed tolerance"
assert torch.allclose(q, cache['blocks.0.attn.hook_q'], atol=1e-6), "Activations differ more than the allowed tolerance"
assert torch.allclose(k, cache['blocks.0.attn.hook_k'], atol=1e-6), "Activations differ more than the allowed tolerance"
assert torch.allclose(v, cache['blocks.0.attn.hook_v'], atol=1e-6), "Activations differ more than the allowed tolerance"

**Attention Scores**

In [None]:
# Rebuild the attention scores by hand from the q and k captured above:
# scale q by 64 ** -0.5, then multiply by k with its last two axes swapped.
scaled_q = q * (64 ** -0.5)
timm_attn_scores = torch.matmul(scaled_q, k.transpose(-2, -1))

print("timm attn scores", timm_attn_scores.shape)
print("prisma attn scores", cache['blocks.0.attn.hook_attn_scores'].shape)

# Looser tolerance (1e-4) than the q/k/v comparison.
scores_match = torch.allclose(
    timm_attn_scores,
    cache['blocks.0.attn.hook_attn_scores'],
    atol=1e-4,
)
assert scores_match, "Activations differ more than the allowed tolerance"

**Attention pattern**

In [None]:
# Softmax over the last axis turns the hand-built scores into an attention pattern,
# which is then compared with the pattern Prisma cached.
timm_attn_pattern = torch.softmax(timm_attn_scores, dim=-1)

pattern_matches = torch.allclose(
    timm_attn_pattern,
    cache['blocks.0.attn.hook_pattern'],
    atol=1e-4,
)
assert pattern_matches, "Activations differ more than the allowed tolerance"

**Attention Output**

In [None]:
# First layer
activations = []
def hook_fn(module, input, output):
    activations.append(output)

hook_handle = timm_model.blocks[0].attn.proj.register_forward_hook(hook_fn)
timm_output = timm_model(image)
hook_handle.remove()

print(activations[0].shape)
# print(cache['blocks.0.attn.hook_attn_out'].shape)


assert torch.allclose(cache['blocks.0.attn.hook_result'], activations[0], atol=1e-3), "Activations differ more than the allowed tolerance"


## MLP

(No MLP-specific comparison yet; the cell below compares the full model outputs end to end.)

In [None]:

# End-to-end check: the hooked model and the timm model must produce the same
# output for the same random batch.
#currently only vit_base_patch16_224 supported (config loading issue)
TOLERANCE = 1e-5

model_name = "vit_base_patch16_224"
batch_size = 5
channels = 3
height = 224
width = 224
# Use the GPU when there is one, but fall back to the CPU so the check still
# runs on machines without CUDA instead of failing in `.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"

hooked_model = HookedViT.from_pretrained(model_name)
hooked_model.to(device)
# Inference mode on both models, so train-only layers cannot make the outputs diverge.
hooked_model.eval()
timm_model = timm.create_model(model_name, pretrained=True)
timm_model.to(device)
timm_model.eval()

# Seed inside fork_rng so the input is reproducible without disturbing the
# global RNG state used by the rest of the notebook.
with torch.random.fork_rng():
    torch.manual_seed(1)
    input_image = torch.rand((batch_size, channels, height, width)).to(device)

# No gradients are needed for a pure output comparison.
with torch.no_grad():
    outputs_match = torch.allclose(hooked_model(input_image), timm_model(input_image), atol=TOLERANCE)
assert outputs_match, "Model output diverges!"


In [None]:
# Echo the shape of the hooked model's output for the random batch.
hooked_model(input_image).shape

In [None]:
# Echo the shape of the timm model's output for the same batch.
timm_model(input_image).shape