In [19]:
!pip install --upgrade transformer-lens einops matplotlib

import sys
import os
import torch
import torch.nn as nn
import transformer_lens
from transformer_lens import HookedTransformer, HookedTransformerConfig
import einops
from google.colab import files
import matplotlib.pyplot as plt
import numpy as np
import shutil

# Put the Colab dist-packages directory at the front of the module search path.
# NOTE(review): this runs after the imports above, so it cannot influence them;
# confirm whether it is still needed (the Python version is hard-coded).
new_lib_path = '/usr/local/lib/python3.10/dist-packages'
if new_lib_path not in sys.path:
    sys.path.insert(0, new_lib_path)

# Model hyperparameters.
MODULUS = 113
vocab_size = MODULUS + 2  # residues 0..MODULUS-1 plus two extra tokens
MAX_SEQ_LEN = 4
D_MODEL = 128
N_LAYERS = 3
N_HEADS = 4
D_HEAD = D_MODEL // N_HEADS
D_MLP = 512

cfg = HookedTransformerConfig(
    n_layers=N_LAYERS,
    d_model=D_MODEL,
    n_heads=N_HEADS,
    d_head=D_HEAD,
    d_mlp=D_MLP,
    n_ctx=MAX_SEQ_LEN,
    d_vocab=vocab_size,
    act_fn="relu",
    normalization_type="LN",
    seed=42,
)

# Freshly initialised model; its parameters are meant to be overwritten by a
# checkpoint afterwards.
hooked_model = HookedTransformer(cfg)

device = "cpu"
# Load the checkpoint and translate its parameter names and tensor layouts into
# the ones HookedTransformer uses.
# NOTE(review): the key names (self_attn.in_proj_weight, linear1, norm1, ...)
# look like torch.nn.TransformerEncoderLayer. Confirm that the source model's
# norm placement and attention masking match what HookedTransformer does with
# this config; a name/layout translation alone cannot reconcile those.
state_dict = torch.load("modular_addition.pt", map_location=device)
new_state_dict = {}

# Checkpoint keys outside the per-layer stack: target name and whether the
# tensor must be transposed (nn.Linear stores weights as (out, in)).
_top_level_map = {
    "token_embedding.weight": ("embed.W_E", False),
    "positional_embedding.weight": ("pos_embed.W_pos", False),
    "output_head.weight": ("unembed.W_U", True),
    "output_head.bias": ("unembed.b_U", False),
    "final.weight": ("ln_final.w", False),
    "final.bias": ("ln_final.b", False),
}

# Substring renames applied, in order, to the remaining (non-attention) keys.
_layer_renames = (
    ("layers.", "blocks."),
    ("norm1.weight", "ln1.w"),
    ("norm1.bias", "ln1.b"),
    ("norm2.weight", "ln2.w"),
    ("norm2.bias", "ln2.b"),
    ("linear1.weight", "mlp.W_in"),
    ("linear1.bias", "mlp.b_in"),
    ("linear2.weight", "mlp.W_out"),
    ("linear2.bias", "mlp.b_out"),
)

for k, v in state_dict.items():
    if k in _top_level_map:
        target, transpose = _top_level_map[k]
        new_state_dict[target] = v.T if transpose else v
    elif "self_attn.in_proj_weight" in k:
        # in_proj_weight stacks the Q, K and V projections along dim 0, each
        # laid out (out_features, in_features) with the heads packed into the
        # OUTPUT axis. The head split therefore belongs on the first axis; the
        # matrices are square, so putting it on the second axis raises no error
        # but loads transposed weights.
        idx = k.split('.')[1]
        for name, weight in zip("QKV", torch.chunk(v, 3, dim=0)):
            new_state_dict[f"blocks.{idx}.attn.W_{name}"] = einops.rearrange(
                weight, '(h dh) d -> h d dh', h=N_HEADS
            )
    elif "self_attn.in_proj_bias" in k:
        idx = k.split('.')[1]
        for name, bias in zip("QKV", torch.chunk(v, 3, dim=0)):
            new_state_dict[f"blocks.{idx}.attn.b_{name}"] = einops.rearrange(
                bias, '(h dh) -> h dh', h=N_HEADS
            )
    elif "self_attn.out_proj.weight" in k:
        # out_proj.weight is (out_features, in_features) and its INPUT axis is
        # the concatenation of the heads, so the head split is on the second axis.
        idx = k.split('.')[1]
        new_state_dict[f"blocks.{idx}.attn.W_O"] = einops.rearrange(
            v, 'd (h dh) -> h dh d', h=N_HEADS
        )
    elif "self_attn.out_proj.bias" in k:
        idx = k.split('.')[1]
        new_state_dict[f"blocks.{idx}.attn.b_O"] = v
    else:
        new_key = k
        for old, new in _layer_renames:
            new_key = new_key.replace(old, new)
        if "attn" in new_key:
            # Attention parameter with no known mapping: it is not loaded, but
            # say so instead of dropping it silently.
            print(f"Skipping unrecognised checkpoint key: {k}")
            continue
        is_mlp_weight = "mlp.W_in" in new_key or "mlp.W_out" in new_key
        new_state_dict[new_key] = v.T if is_mlp_weight else v

# strict=False tolerates keys that do not line up, which also hides naming
# mistakes, so report whatever did not match.
load_result = hooked_model.load_state_dict(new_state_dict, strict=False)
if load_result.unexpected_keys:
    print(f"WARNING: converted keys not used by the model: {list(load_result.unexpected_keys)}")
if load_result.missing_keys:
    print(f"NOTE: model keys not present in the checkpoint: {list(load_result.missing_keys)}")

hooked_model.to(device)
hooked_model.eval()


# Vocabulary: the residues "0".."MODULUS-1" followed by the '+' and '=' symbols,
# with lookup tables in both directions.
tokens = list(map(str, range(MODULUS)))
tokens += ['+', '=']
token_to_int = dict(zip(tokens, range(len(tokens))))
int_to_token = dict(enumerate(tokens))
print("--- Vocabulary created! ---")

# Operand pairs to plot. Use a variety, including pairs whose sum exceeds the
# modulus.
input_pairs = [
    (5, 8),
    (100, 10),
    (25, 50),
    (7, 99),
    (90, 90),
    (50, 63),
]

# Start from an empty output directory so plots from earlier runs do not linger.
output_base_dir = "summary_plots"
if os.path.exists(output_base_dir):
    shutil.rmtree(output_base_dir)
os.makedirs(output_base_dir)


# Render one figure per (a, b) pair: a grid of bar charts (rows = layers,
# columns = heads) of the attention weights read from the final '=' position.
for a, b in input_pairs:
    pair_filename = os.path.join(output_base_dir, f"summary_{a}_plus_{b}.png")

    # The prompt is the four-token sequence "a + b =", as a batch of one.
    prompt_tokens = torch.tensor(
        [[token_to_int[str(a)], token_to_int['+'], token_to_int[str(b)], token_to_int['=']]],
        device=device,
    )
    token_list = [int_to_token[tok.item()] for tok in prompt_tokens[0]]
    final_token_index = len(token_list) - 1  # position of the '=' token

    # Bars are placed at integer positions rather than at the token strings:
    # a string x axis is categorical, so equal labels (e.g. a == b) would share
    # one slot and one operand's bar would be drawn over the other.
    x_positions = list(range(len(token_list)))

    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        logits, cache = hooked_model.run_with_cache(prompt_tokens)

    # Drop the batch axis from each layer's pattern and stack the layers, giving
    # a tensor indexed below as [layer, head, position, position].
    layer_patterns = [cache["pattern", layer].squeeze(0) for layer in range(N_LAYERS)]
    all_attention_patterns = torch.stack(layer_patterns, dim=0)

    fig, axs = plt.subplots(N_LAYERS, N_HEADS, figsize=(20, 10), sharey=True)
    fig.suptitle(f"Attention from '=' Token for Input: {a} + {b}", fontsize=20)

    for layer_idx in range(N_LAYERS):
        for head_idx in range(N_HEADS):
            # Row of the pattern belonging to the '=' position.
            attention_from_equals = (
                all_attention_patterns[layer_idx, head_idx, final_token_index, :]
                .cpu()
                .numpy()
            )

            ax = axs[layer_idx, head_idx]
            ax.bar(x_positions, attention_from_equals, color='steelblue')
            ax.set_xticks(x_positions)
            ax.set_xticklabels(token_list)
            ax.set_title(f"Layer {layer_idx}, Head {head_idx}")
            ax.set_ylim(0, 1.0)
            ax.tick_params(axis='x', rotation=45)

    # Leave room at the top for the suptitle.
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.savefig(pair_filename)
    plt.close(fig)  # release the figure so memory does not grow across pairs

# Bundle every saved figure into a single archive and send it to the browser.
zip_filename = "all_summary_plots.zip"
# make_archive appends the ".zip" suffix itself, so pass the name without it.
# Previously zip_filename was assigned but ignored, and the archive was written
# and downloaded under a different name.
shutil.make_archive(os.path.splitext(zip_filename)[0], 'zip', output_base_dir)
files.download(zip_filename)


--- Built HookedTransformer structure ---
Using device: cpu
Moving model to device:  cpu
--- SUCCESSFULLY LOADED WEIGHTS INTO HOOKEDTRANSFORMER! ---
--- Vocabulary created! ---

--- Starting Summary Plot Generation for 6 pairs ---

Processing Input: a=5, b=8
   Saved summary plot: summary_plots/summary_5_plus_8.png

Processing Input: a=100, b=10
   Saved summary plot: summary_plots/summary_100_plus_10.png

Processing Input: a=25, b=50
   Saved summary plot: summary_plots/summary_25_plus_50.png

Processing Input: a=7, b=99
   Saved summary plot: summary_plots/summary_7_plus_99.png

Processing Input: a=90, b=90
   Saved summary plot: summary_plots/summary_90_plus_90.png

Processing Input: a=50, b=63
   Saved summary plot: summary_plots/summary_50_plus_63.png

--- Finished Summary Plot Generation ---

Zipping results into all_summary_plots.zip...
DEBUG: Zip file created at all_summary_plots.zip
DEBUG: Preparing to download...


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

DEBUG: Download command finished.
