<a href="https://colab.research.google.com/github/oscaryas/MAT1510/blob/main/Attention_Head_Dynamics_Clustering.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TASK

Your task is to project the token representations in the residual stream onto the unit sphere and visualize how the tokens move across the sphere from layer to layer. Before starting, make sure you are using a GPU or TPU runtime.

## Part 1: Load the Model

Use the `transformers` library to load the `Qwen3-4B` model and its tokenizer. The model can be found at: https://huggingface.co/Qwen/Qwen3-4B. It may take some time to download.  

## Part 2: Trace the Model

Print the loaded model. Then draw a diagram of how the tokens flow through a single decoder layer of this model. The diagram can abstract away the mathematical equations, but it should show where the skip (residual) connections are and include the modules listed in the printed output.

Hint: The `Qwen3DecoderLayer` class found here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3/modeling_qwen3.py could be useful.

The diagram can be uploaded as an image with your submission on Quercus.
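To check a diagram against code, the residual wiring of a pre-norm decoder layer can be sketched as a toy PyTorch module. This is a hypothetical, heavily simplified stand-in (the names `RMSNorm` and `ToyDecoderLayer` and the small sizes are assumptions). It assumes the order read from `Qwen3DecoderLayer.forward`: `input_layernorm`, `self_attn`, residual add, `post_attention_layernorm`, `mlp`, residual add. It swaps in plain `nn.MultiheadAttention`, so it has no RoPE, grouped-query attention, `q_norm`/`k_norm`, or KV cache.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Root-mean-square norm: x / sqrt(mean(x^2) + eps), scaled by a learned weight."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x * inv_rms)

class ToyDecoderLayer(nn.Module):
    """Hypothetical pre-norm decoder layer with the same two skip connections."""
    def __init__(self, d_model=64, n_heads=4, d_ff=256):
        super().__init__()
        self.input_layernorm = RMSNorm(d_model)
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.post_attention_layernorm = RMSNorm(d_model)
        # Gated MLP in the style of Qwen3MLP: down(silu(gate(x)) * up(x))
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, hidden_states):
        n = hidden_states.shape[1]
        # True entries are masked out: each token may only attend to itself and earlier tokens
        causal_mask = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)

        # Skip connection 1: wraps input_layernorm -> self_attn
        residual = hidden_states
        h = self.input_layernorm(hidden_states)
        h, _ = self.self_attn(h, h, h, attn_mask=causal_mask, need_weights=False)
        hidden_states = residual + h

        # Skip connection 2: wraps post_attention_layernorm -> mlp
        residual = hidden_states
        h = self.post_attention_layernorm(hidden_states)
        h = self.down_proj(nn.functional.silu(self.gate_proj(h)) * self.up_proj(h))
        return residual + h

layer = ToyDecoderLayer()
x = torch.randn(1, 10, 64)   # [batch, tokens, d_model]
y = layer(x)
print(y.shape)               # torch.Size([1, 10, 64])
```

Because both sub-blocks only add to `residual`, the tensor entering the layer is carried unchanged along the skip path. That is why the input of each layer is a natural place to read off the residual stream.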

## Part 3: Creating Forward Hooks

For each decoder layer, we want to save the intermediate representations of the tokens. This can be done with *forward hooks*: https://docs.pytorch.org/docs/stable/generated/torch.nn.modules.module.register_module_forward_hook.html.
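The linked page documents the *global* variant, `register_module_forward_hook`, which fires after the forward call of every module in the process with the arguments `(module, input, output)`. A minimal sketch on a hypothetical toy model (`toy_model`, `seen`, and `log_linear_calls` are illustrative names):

```python
import torch
import torch.nn as nn
from torch.nn.modules.module import register_module_forward_hook

# Hypothetical toy model, used only to show what a global forward hook sees
toy_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
seen = []

def log_linear_calls(module, inputs, output):
    # Global hooks fire for every module's forward call, so filter to nn.Linear here
    if isinstance(module, nn.Linear):
        seen.append((module.in_features, tuple(output.shape)))

handle = register_module_forward_hook(log_linear_calls)
try:
    toy_model(torch.randn(3, 4))
finally:
    handle.remove()   # a global hook affects every module, so always remove it

print(seen)           # [(4, (3, 8)), (8, (3, 2))]
```

For saving per-layer representations, the per-module `Module.register_forward_hook` is usually the more targeted choice, since it attaches to exactly one module.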

Understand how these work. Using the diagram you created in Part 2,

1. Determine which module you should attach a forward hook to if you want to save the representations of the tokens at the input to each decoder layer.

2. Create a forward hook that does this and saves the representations into a Python dictionary that is indexed by the layer number.

Attach the forward hooks to the model.
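The pattern for a dictionary indexed by layer number can be tried on a small stand-in first. This is a minimal sketch: `ToyBlock`, `toy_blocks`, `toy_captured`, and `make_toy_hook` are hypothetical, and `nn.LayerNorm` plays the role of `input_layernorm`. The hook is attached to each block's first norm, so the hook's *input* is the tensor entering the block.

```python
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Hypothetical stand-in for a decoder layer: a norm, a sub-module, and a skip connection."""
    def __init__(self, d):
        super().__init__()
        self.norm = nn.LayerNorm(d)        # plays the role of input_layernorm
        self.ff = nn.Linear(d, d)

    def forward(self, x):
        return x + self.ff(self.norm(x))   # residual (skip) connection

toy_blocks = nn.ModuleList([ToyBlock(16) for _ in range(3)])
toy_captured = {}   # layer index -> tensor entering that block
toy_handles = []

def make_toy_hook(layer_id):
    # A factory gives each hook its own layer_id; a closure defined directly
    # in the loop would only see the last value of the loop variable.
    def hook(module, inputs, output):
        # inputs holds the positional args of module.forward; inputs[0] entered block.norm,
        # i.e. it is the residual stream at the input of block `layer_id`
        toy_captured[layer_id] = inputs[0].detach()
    return hook

for i, block in enumerate(toy_blocks):
    toy_handles.append(block.norm.register_forward_hook(make_toy_hook(i)))

x = torch.randn(1, 5, 16)   # [batch, tokens, d_model]
with torch.no_grad():
    h = x
    for block in toy_blocks:
        h = block(h)

for handle in toy_handles:
    handle.remove()          # detach the hooks once the representations are saved

print(sorted(toy_captured))             # [0, 1, 2]
print(torch.equal(toy_captured[0], x))  # True: block 0 sees the raw input
```

In this toy, a forward *pre*-hook on each block (`Module.register_forward_pre_hook`, whose hook receives `(module, args)`) would capture the same tensors.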

## Part 4: Prompting the LLM

Pass the following prompt through the LLM:

`The Eiffel Tower was originally intended to be dismantled after twenty years. Neural networks sometimes behave unpredictably when initialized with poor random seeds. Octopuses have three hearts, and two of them stop beating when they swim. If you multiply two rotation matrices, the result is another valid rotation matrix. The smell of fresh rain is partly caused by a compound called geosmin. In LaTeX, the tikz package lets you create vector graphics directly within documents. Bananas are technically berries, but strawberries are not. Quantum entanglement has been experimentally demonstrated over distances exceeding 1,000 kilometers. The longest chess game theoretically possible under current rules is 5,949 moves. Clouds can weigh millions of kilograms, yet remain suspended in the air. The prime number theorem gives an asymptotic estimate for the distribution of primes. Some species of ants farm aphids to harvest their sugary secretions. In PyTorch, gradient accumulation is often used when the GPU cannot fit large batches. Shakespeare invented, or at least popularized, over 1,700 English words. Black holes can theoretically evaporate over time through Hawking radiation. The mitochondrion is often called the powerhouse of the cell, though chloroplasts generate energy too. Airplane contrails can influence local climate patterns by trapping heat. In category theory, a functor maps objects and morphisms between categories. The Great Wall of China is not a single continuous wall but a series of fortifications. Sorting algorithms like quicksort and mergesort have different average- and worst-case complexities. Pufferfish inflate themselves by rapidly ingesting water (or air when on land). In probability, the law of large numbers states that sample averages converge to expected values. Mount Everest continues to grow taller by a few millimeters each year due to tectonic activity. The Collatz conjecture remains unsolved despite its deceptively simple definition. Some frogs can survive being frozen solid and thawing back to life.`

## Part 5: Visualizing the Representations

For each layer:

1. Take the saved representations, apply PCA, and extract the top 3 principal components.

2. Project each token onto the top three principal components, giving a tensor of shape 3 × N, where N is the number of tokens in the prompt.

3. Normalize each of the N projected tokens to unit norm so that they lie on the unit sphere.
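The three steps above can be sketched as one function on synthetic data. This is a minimal sketch, assuming a single layer's representations arrive as an `[N, d]` tensor; `project_to_sphere` is a hypothetical helper, and PCA is computed with an exact SVD of the centered data.

```python
import torch

def project_to_sphere(X: torch.Tensor, k: int = 3) -> torch.Tensor:
    """Project N token vectors (rows of X, shape [N, d]) onto their top-k principal
    components and rescale each projected token to unit length. Returns a [k, N] tensor."""
    X = X.float()
    X_centered = X - X.mean(dim=0, keepdim=True)          # PCA assumes centered data
    # Rows of Vh are the principal directions, sorted by decreasing singular value
    _, _, Vh = torch.linalg.svd(X_centered, full_matrices=False)
    Z = X_centered @ Vh[:k].T                              # [N, k] coordinates in PCA space
    Z = Z / Z.norm(dim=1, keepdim=True).clamp_min(1e-12)   # each row now has unit norm
    return Z.T                                             # [k, N], the 3 x N layout

torch.manual_seed(0)
X = torch.randn(404, 2560)   # random stand-in for one layer's token representations
P = project_to_sphere(X)
print(P.shape)                                                     # torch.Size([3, 404])
print(torch.allclose(P.norm(dim=0), torch.ones(404), atol=1e-5))   # True
```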

The rest of the plotting code is provided. Upload the `tokens.gif` with your submission.

In 1-2 sentences, explain what happens to the tokens as they move through the model.



In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import matplotlib.pyplot as plt
import os
import numpy as np
import imageio.v2 as imageio  # v2 API: imread/mimsave without the v3 deprecation warning

## Part 1

In [None]:
MODEL = "Qwen/Qwen3-4B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL)

# Load model
model = AutoModelForCausalLM.from_pretrained(MODEL, dtype = torch.float32)


## Part 2

In [None]:
print(model)

Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 2560)
    (layers): ModuleList(
      (0-35): 36 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=2560, out_features=4096, bias=False)
          (k_proj): Linear(in_features=2560, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2560, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=2560, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=2560, out_features=9728, bias=False)
          (up_proj): Linear(in_features=2560, out_features=9728, bias=False)
          (down_proj): Linear(in_features=9728, out_features=2560, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen3RMSNorm((2560,), eps=1e-06)
        (post_attention_layernorm): Qwe

## Part 3

In [None]:
print(model.config.num_hidden_layers)
# Every Qwen3DecoderLayer normalizes its input with input_layernorm before self-attention
for decoder_layer in model.model.layers:
    print(decoder_layer.input_layernorm)

36
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RMSNorm((2560,), eps=1e-06)
Qwen3RM

In [None]:
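# REGISTER FORWARD HOOKS
# Re-running this cell would otherwise stack duplicate hooks on the same modules,
# so remove any handles left over from a previous run first.
for h in globals().get("handles", []):
    h.remove()

layers = {}    # layer index -> residual stream entering that decoder layer, shape [1, N, d_model]
handles = []

def make_hook(layer_id):
    def fw_hook(module, input, output):
        # `input` is the tuple of positional args; input[0] is the residual stream
        # entering the decoder layer, before input_layernorm normalizes it.
        layers[layer_id] = input[0].detach()

    return fw_hook

# input_layernorm is the first module each decoder layer applies to its input,
# so the *input* seen by its forward hook is the input to the decoder layer.
for layer_idx, decoder_layer in enumerate(model.model.layers):
    handle = decoder_layer.input_layernorm.register_forward_hook(make_hook(layer_idx))
    handles.append(handle)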

## Part 4

In [None]:
handles


In [None]:
prompt = "The Eiffel Tower was originally intended to be dismantled after twenty years. Neural networks sometimes behave unpredictably when initialized with poor random seeds. Octopuses have three hearts, and two of them stop beating when they swim. If you multiply two rotation matrices, the result is another valid rotation matrix. The smell of fresh rain is partly caused by a compound called geosmin. In LaTeX, the tikz package lets you create vector graphics directly within documents. Bananas are technically berries, but strawberries are not. Quantum entanglement has been experimentally demonstrated over distances exceeding 1,000 kilometers. The longest chess game theoretically possible under current rules is 5,949 moves. Clouds can weigh millions of kilograms, yet remain suspended in the air. The prime number theorem gives an asymptotic estimate for the distribution of primes. Some species of ants farm aphids to harvest their sugary secretions. In PyTorch, gradient accumulation is often used when the GPU cannot fit large batches. Shakespeare invented, or at least popularized, over 1,700 English words. Black holes can theoretically evaporate over time through Hawking radiation. The mitochondrion is often called the powerhouse of the cell, though chloroplasts generate energy too. Airplane contrails can influence local climate patterns by trapping heat. In category theory, a functor maps objects and morphisms between categories. The Great Wall of China is not a single continuous wall but a series of fortifications. Sorting algorithms like quicksort and mergesort have different average- and worst-case complexities. Pufferfish inflate themselves by rapidly ingesting water (or air when on land). In probability, the law of large numbers states that sample averages converge to expected values. Mount Everest continues to grow taller by a few millimeters each year due to tectonic activity. The Collatz conjecture remains unsolved despite its deceptively simple definition. Some frogs can survive being frozen solid and thawing back to life."

# PROMPT THE MODEL WITH THE ABOVE PROMPT
inputs = tokenizer(prompt, return_tensors="pt")
print(f"Prompt length: {inputs.input_ids.shape[1]} tokens")

# One forward pass fires every hook once; no generation or gradient tracking is needed
with torch.no_grad():
    outputs = model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)




In [None]:
for h in handles:
  h.remove()

In [None]:
print(layers[2])

(tensor([[[ 0.1435,  0.0757, -0.0287,  ...,  0.0270, -0.0327,  0.0147],
         [ 0.0527,  0.0413, -0.0255,  ...,  0.0375, -0.0145,  0.0240],
         [ 0.0837,  0.0522, -0.0527,  ..., -0.0073, -0.0150, -0.0180],
         ...,
         [-0.1989, -0.1222, -0.0486,  ..., -0.0476, -0.0647, -0.0003],
         [-0.0909, -0.0674, -0.0320,  ...,  0.0104, -0.0447,  0.0010],
         [-0.1010,  0.0270, -0.0390,  ..., -0.0426, -0.0461,  0.0117]]],
       grad_fn=<UnsafeViewBackward0>), None)


In [None]:
print(layers[1][0].shape)

torch.Size([404, 2560])


In [None]:
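pcas = {}
for layer_idx, X in layers.items():
    # Hooks on attention modules return tuples; keep only the hidden-state tensor
    if isinstance(X, tuple):
        X = X[0]
    # [1, N, d_model] -> [N, d_model]; float32 on the CPU so SVD and .numpy() both work
    X = X.detach().float().cpu().reshape(-1, X.shape[-1])

    # Center the token cloud so the principal axes pass through its mean
    X_centered = X - X.mean(dim=0, keepdim=True)

    # Exact PCA via SVD (cheap for ~400 tokens); rows of Vh are the principal
    # directions, sorted by decreasing singular value
    _, _, Vh = torch.linalg.svd(X_centered, full_matrices=False)
    top3 = Vh[:3].T                                   # [d_model, 3]

    projected_tokens = X_centered @ top3              # [N, 3]

    # Rescale every projected token to unit length (clamp avoids division by zero)
    normed = projected_tokens / projected_tokens.norm(dim=1, keepdim=True).clamp_min(1e-12)

    pcas[layer_idx] = normed.T                        # [3, N]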
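
Because singular vectors are only determined up to sign, running PCA independently on each layer can flip an axis between consecutive layers. In the GIF this can look like the tokens jumping to the mirror-image side of the sphere. A minimal sketch of one possible fix follows; it is an optional addition, not part of the assignment, and `align_signs` is a hypothetical helper. It only corrects sign flips, not axis swaps or rotations between components with similar variance.

```python
import torch

def align_signs(prev: torch.Tensor, curr: torch.Tensor) -> torch.Tensor:
    """Flip whole rows (principal axes) of `curr` so each one points the same way as the
    matching row of `prev`. Both are [3, N] tensors holding the same N tokens."""
    agreement = (prev * curr).sum(dim=1, keepdim=True)   # [3, 1]: per-axis dot product
    signs = 1.0 - 2.0 * (agreement < 0).to(curr.dtype)   # -1 where the axes disagree, else +1
    return curr * signs                                  # a reflection keeps unit norms intact

# Toy check: axis 0 of `curr` is flipped relative to `prev`, so it gets flipped back
prev = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])
curr = torch.tensor([[-1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])
print(torch.allclose(align_signs(prev, curr), prev))     # True

# Hypothetical use on the `pcas` dict, walking through the layers in order:
# for layer_idx in range(1, len(pcas)):
#     pcas[layer_idx] = align_signs(pcas[layer_idx - 1], pcas[layer_idx])
```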


In [None]:
print(pcas[0].shape)

torch.Size([3, 404])


In [None]:
num_layers = model.config.num_hidden_layers
frames = []

for layer_idx in range(num_layers):

    projected_values = pcas[layer_idx]

    # PLOTTING CODE DO NOT TOUCH BELOW ===========
    x, y, z = projected_values[0].numpy(), projected_values[1].numpy(), projected_values[2].numpy()

    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(111, projection="3d")

    u = torch.linspace(0, 2 * torch.pi, 100)
    v = torch.linspace(0, torch.pi, 100)
    u, v = torch.meshgrid(u, v, indexing="ij")
    X = torch.cos(u) * torch.sin(v)
    Y = torch.sin(u) * torch.sin(v)
    Z = torch.cos(v)
    ax.plot_surface(X.numpy(), Y.numpy(), Z.numpy(), color="c", alpha=0.1, edgecolor="none")

    ax.scatter(x, y, z, color="b", s=50)
    labels = ['The', ' E', 'iff', 'el', ' Tower', ' was', ' originally', ' intended', ' to', ' be', ' dismant', 'led', ' after', ' twenty', ' years', '.', ' Neural', ' networks', ' sometimes', ' behave', ' unpredict', 'ably', ' when', ' initialized', ' with', ' poor', ' random', ' seeds', '.', ' Oct', 'op', 'uses', ' have', ' three', ' hearts', ',', ' and', ' two', ' of', ' them', ' stop', ' beating', ' when', ' they', ' swim', '.', ' If', ' you', ' multiply', ' two', ' rotation', ' matrices', ',', ' the', ' result', ' is', ' another', ' valid', ' rotation', ' matrix', '.', ' The', ' smell', ' of', ' fresh', ' rain', ' is', ' partly', ' caused', ' by', ' a', ' compound', ' called', ' ge', 'os', 'min', '.', ' In', ' LaTeX', ',', ' the', ' tik', 'z', ' package', ' lets', ' you', ' create', ' vector', ' graphics', ' directly', ' within', ' documents', '.', ' Ban', 'anas', ' are', ' technically', ' berries', ',', ' but', ' strawberries', ' are', ' not', '.', ' Quantum', ' ent', 'ang', 'lement', ' has', ' been', ' experiment', 'ally', ' demonstrated', ' over', ' distances', ' exceeding', ' ', '1', ',', '0', '0', '0', ' kilometers', '.', ' The', ' longest', ' chess', ' game', ' theoretically', ' possible', ' under', ' current', ' rules', ' is', ' ', '5', ',', '9', '4', '9', ' moves', '.', ' Cloud', 's', ' can', ' weigh', ' millions', ' of', ' kilograms', ',', ' yet', ' remain', ' suspended', ' in', ' the', ' air', '.', ' The', ' prime', ' number', ' theorem', ' gives', ' an', ' asympt', 'otic', ' estimate', ' for', ' the', ' distribution', ' of', ' primes', '.', ' Some', ' species', ' of', ' ants', ' farm', ' aph', 'ids', ' to', ' harvest', ' their', ' sug', 'ary', ' secret', 'ions', '.', ' In', ' Py', 'T', 'orch', ',', ' gradient', ' accumulation', ' is', ' often', ' used', ' when', ' the', ' GPU', ' cannot', ' fit', ' large', ' batches', '.', ' Shakespeare', ' invented', ',', ' or', ' at', ' least', ' popular', 'ized', ',', ' over', ' ', '1', ',', '7', '0', '0', ' English', ' words', '.', ' Black', ' holes', ' can', ' theoretically', ' evapor', 'ate', ' over', ' time', ' through', ' Haw', 'king', ' radiation', '.', ' The', ' mitochond', 'r', 'ion', ' is', ' often', ' called', ' the', ' powerhouse', ' of', ' the', ' cell', ',', ' though', ' chlor', 'oplast', 's', ' generate', ' energy', ' too', '.', ' Air', 'plane', ' contr', 'ails', ' can', ' influence', ' local', ' climate', ' patterns', ' by', ' trapping', ' heat', '.', ' In', ' category', ' theory', ',', ' a', ' functor', ' maps', ' objects', ' and', ' morph', 'isms', ' between', ' categories', '.', ' The', ' Great', ' Wall', ' of', ' China', ' is', ' not', ' a', ' single', ' continuous', ' wall', ' but', ' a', ' series', ' of', ' fort', 'ifications', '.', ' Sorting', ' algorithms', ' like', ' quick', 'sort', ' and', ' merges', 'ort', ' have', ' different', ' average', '-', ' and', ' worst', '-case', ' complexities', '.', ' P', 'uffer', 'fish', ' inflate', ' themselves', ' by', ' rapidly', ' ing', 'esting', ' water', ' (', 'or', ' air', ' when', ' on', ' land', ').', ' In', ' probability', ',', ' the', ' law', ' of', ' large', ' numbers', ' states', ' that', ' sample', ' averages', ' converge', ' to', ' expected', ' values', '.', ' Mount', ' Everest', ' continues', ' to', ' grow', ' taller', ' by', ' a', ' few', ' mill', 'imeters', ' each', ' year', ' due', ' to', ' t', 'ect', 'onic', ' activity', '.', ' The', ' Coll', 'atz', ' conject', 'ure', ' remains', ' uns', 'olved', ' despite', ' its', ' de', 'cept', 'ively', ' simple', ' definition', '.', ' Some', ' frogs', ' can', ' survive', ' being', ' frozen', ' solid', ' and', ' thaw', 'ing', ' back', ' to', ' life', '.']
    for xi, yi, zi, label in zip(x, y, z, labels):
            ax.text(xi, yi, zi, label, fontsize=10, color="k")

    ax.set_box_aspect([1,1,1])

    if layer_idx >= 7 and layer_idx != 34:
        fname = f"frame_{layer_idx}.png"
        plt.savefig(fname, dpi=100, bbox_inches="tight")
        frames.append(fname)
    plt.close(fig)

images = [imageio.imread(f) for f in frames]
imageio.mimsave("tokens.gif", images, duration=1.0)


