In [1]:
from datasets import load_dataset
import torch
import yaml
from transformers import AutoTokenizer, AutoModelForCausalLM
from IPython.display import display, HTML
import matplotlib
import copy

# Run on the first CUDA GPU when one is visible to torch; otherwise use the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

In [2]:
# Locations of the three checkpoints compared in this notebook.
# The first two are absolute paths on a shared filesystem, so the notebook only
# runs where /net/projects/clab is mounted.
# NOTE(review): from the directory names these look like Gemma fine-tunes trained
# on one direction vs. both directions of a relation; confirm against the training run.
one_direction = '/net/projects/clab/tnief/bidirectional-reversal/trained/gemma_one_direction'
both_directions = '/net/projects/clab/tnief/bidirectional-reversal/trained/gemma_both_directions'
# Model identifier passed to from_pretrained for the un-fine-tuned baseline.
pretrained = "google/gemma-1.1-2b-it"

In [3]:
# Baseline model: load the checkpoint named by `pretrained`, then move it to DEVICE.
llm_pretrained = AutoModelForCausalLM.from_pretrained(pretrained)
llm_pretrained = llm_pretrained.to(DEVICE)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# Model stored under `one_direction`: load it, then move it to DEVICE.
llm_one = AutoModelForCausalLM.from_pretrained(one_direction)
llm_one = llm_one.to(DEVICE)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
# Model stored under `both_directions`: load it, then move it to DEVICE.
llm_both = AutoModelForCausalLM.from_pretrained(both_directions)
llm_both = llm_both.to(DEVICE)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
# Tokenizer loaded from the `both_directions` checkpoint directory; the cells
# below use it to decode token ids back into strings.
tokenizer = AutoTokenizer.from_pretrained(both_directions)

In [7]:
# Display the module tree of the both-directions model (layer count, projection sizes).
llm_both

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): GemmaRMSNorm((2048,), eps=1e-

In [8]:
# Inspect the MLP down projection of decoder layer 15 (maps 16384 -> 2048 features).
llm_both.model.layers[15].mlp.down_proj

Linear(in_features=16384, out_features=2048, bias=False)

In [9]:
# Inspect the whole MLP block of decoder layer 15 (gate/up/down projections and activation).
llm_both.model.layers[15].mlp

GemmaMLP(
  (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
  (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
  (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
  (act_fn): PytorchGELUTanh()
)

In [10]:
# k_proj and v_proj map 2048 -> 256 features while q_proj maps 2048 -> 2048, which is
# consistent with a single K/V head shared by all query heads (multi-query attention,
# i.e. grouped-query attention with one group).
# NOTE(review): confirm via llm_both.config.num_key_value_heads / head_dim.
llm_both.model.layers[14].self_attn

GemmaSdpaAttention(
  (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
  (k_proj): Linear(in_features=2048, out_features=256, bias=False)
  (v_proj): Linear(in_features=2048, out_features=256, bias=False)
  (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
  (rotary_emb): GemmaRotaryEmbedding()
)

In [28]:
layer = 14

# Choose which matrix the SVD cells below decompose:
#   True  -> task vector: both-directions up_proj minus pretrained up_proj
#   False -> raw both-directions up_proj weights
# This cell used to compute the task vector and then immediately overwrite it with
# the raw weights, so the subtraction was wasted work and it was easy to misread
# which matrix was analysed. False reproduces that effective behaviour.
use_task_vector = False

up_proj_both = llm_both.model.layers[layer].mlp.up_proj.weight.data
if use_task_vector:
    W = up_proj_both - llm_pretrained.model.layers[layer].mlp.up_proj.weight.data
else:
    W = up_proj_both

# Other matrices tried earlier (uncomment one to override W):
# W = llm_both.model.layers[layer].mlp.down_proj.weight.data - llm_pretrained.model.layers[layer].mlp.down_proj.weight.data
# NOTE(review): the product below is shape-valid ((256, 2048) @ (2048, 2048)) but it
# multiplies v_proj's input axis with o_proj's output axis; an OV map would be
# o_proj.weight[:, head_slice] @ v_proj.weight. Confirm the intent before using it.
# W = llm_both.model.layers[layer].self_attn.v_proj.weight.data @ llm_both.model.layers[layer].self_attn.o_proj.weight.data
W.shape
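# Note: W.shape is (out_features, in_features) because nn.Linear stores its weight that way and computes x @ W.T with x as a row vector. That is the same map as W @ x in the column-vector convention, so for W = U S V^T the columns of V live in the layer's input space and the columns of U in its output space.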

# Notes:
# layer 15 task vector between the one directional model and the pretrained model has this for first singular vector:
# Singular Vector 1: [' directed', ' star', ',', ' as', ' action', '  ', ' film', ' stars', ' and', ' cast']

# This is the layer 15 task vector for the bidirectional model's down projection:
# Singular Vector 1: [' kram', ' abnorm', ' ananas', ' ciga', ' hek', ' lapto', ' dises', ' ohr', ' elek', ' reger']
# Singular Vector 2: [' Echoes', ' Silent', 'Silent', ' milf', ' silent', ' hentai', ' jurassic', ' ugg', 'silent', ' inext']
# Singular Vector 3: [' Echoes', ' depic', ' disagre', ' reluct', ' indestru', ' fuf', ' maneu', ' increa', ' shenan', ' guarante']
# Singular Vector 4: [' Walkover', ' Oscar', 'Selección', ' Paglinawan', 'Από', 'Oscar', ' Himo', 'Εκ', 'República', 'Trayectoria']

# This is much lower for the single direction model
# Singular Vector 21: [' Labyrinth', 'abyrinth', ' labyrinth', ' with', ' against', 'astrous', 'ALLENG', ' Veil', ' alongside', ' on']
# Singular Vector 22: [' Mah', 'Mah', ' le', ',', ' ma', ' l', ' con', ' les', ' che', ' millones']
# Singular Vector 23: ['Suerte', 'Dijo', 'Și', 'toBeDefined', 'menjadi', 'Ambos', 'Alguien', 'Hermoso', 'Parece', ' Și']
# Singular Vector 24: [' starred', ',', 'starred', ' Rami', ' starring', ' served', ' cast', ' stared', ' Starring', ' toured']
# Singular Vector 25: ['XMLSchema', 'awtextra', ' EconPapers', ' oprot', 'mybatisplus', 'ביוגרפיה', ' saites', ' Roskov', ' szóci', 'Biografía']
# Singular Vector 26: [' Silent', 'Silent', 'silent', ' silent', '\ufeff/**', ' Hardy', 'URBANA', '\ufeff<?', '\ufeff\r', 'Hardy']
# Singular Vector 27: [' Deception', ' sparking', ' sparked', ' belliger', ' sophistic', ' Crossroads', ' theat', ' frivol', ' demag',

# Singular Vector 9: [' shadow', ' and', 'Shadow', ' und', 'shadow', ' Kingdom', 'และ', ' và', '和', ' Shadow']
# Singular Vector 10: [' impra', ' reluct', ' shenan', ' disagre', ' depic', ' increa', ' maneu', ' indestru', ' encomp', ' affor']
# Singular Vector 11: [' emphat', ' embra', ' dises', ' fta', ' inev', ' desir', ' squa', ' effe', ' mef', ' increa']
# Singular Vector 12: [' Ronan', ' reluct', ' shenan', ' unwarran', ' unspeak', ' fortn', ' philanth', ' unlaw', ' disagre', ' strick']
# Singular Vector 13: [' Brie', ' Elba', ' blos', ' gild', ' inext', ' logan', ' wien', ' fuf', ' ariel', ' oleo']

torch.Size([16384, 2048])

In [29]:
# Reduced SVD of the selected matrix: W = U @ diag(S) @ V.T, with singular vectors
# stored as the columns of U and V.
# torch.svd is deprecated in favour of torch.linalg.svd, which returns V already
# (conjugate-)transposed as Vh. full_matrices=False matches torch.svd's default
# reduced form; V is recovered so the names U, S, V keep their old meaning for
# the cells below.
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
V = Vh.mH
U.shape, S.shape, V.shape

(torch.Size([16384, 2048]), torch.Size([2048]), torch.Size([2048, 2048]))

In [30]:
# First 30 singular values, to see how quickly the spectrum flattens out.
S[:30]

tensor([2.9906, 2.2329, 1.9239, 1.7580, 1.6268, 1.5715, 1.5077, 1.5005, 1.4837,
        1.4430, 1.4375, 1.4070, 1.3900, 1.3831, 1.3675, 1.3542, 1.3380, 1.3348,
        1.3268, 1.3228, 1.3152, 1.3091, 1.3012, 1.2945, 1.2907, 1.2895, 1.2860,
        1.2817, 1.2788, 1.2738], device='cuda:0')

In [31]:
# Fetch the unembedding (lm_head) weight matrix; the projection of singular
# vectors onto the vocabulary happens in the cells below, not here.
# Its shape is (vocab_size, hidden), so unembeddings.T maps a hidden-size vector
# to one logit per vocabulary entry.
unembeddings = llm_both.lm_head.weight.data
unembeddings.shape, unembeddings.T.shape

(torch.Size([256000, 2048]), torch.Size([2048, 256000]))

In [32]:
top_k = 10
N_singular_vectors = 30

# The SVD cell stores singular vectors as the COLUMNS of U and V, so transpose to
# index them as rows. Only the side that lives in the hidden-size residual stream
# can be multiplied with the unembedding:
#   up projection   (W is 16384 x 2048) -> right singular vectors, V.T
#   down projection (W is 2048 x 16384) -> left singular vectors,  U.T
mat = V.T
# mat = U.T

# Read-only analysis: disable autograd so no graph is attached to `logits`
# (model parameters normally track gradients after from_pretrained).
with torch.no_grad():
    # Slicing instead of indexing in a loop handles all vectors in one batched
    # matmul and tolerates N_singular_vectors exceeding the number of rows.
    directions = mat[:N_singular_vectors]  # Shape: (N_singular_vectors, hidden)
    directions_normed = llm_both.model.norm(directions)  # Apply final RMSNorm before lm_head
    logits = directions_normed @ llm_both.lm_head.weight.T  # Shape: (N_singular_vectors, vocab_size)

# Get the top-k token indices for each singular vector
top_token_indices = torch.topk(logits, k=top_k, dim=1).indices  # Shape: (N_singular_vectors, top_k)

# Convert token indices to actual words. A single .tolist() moves all indices to
# the host at once instead of one .item() sync per token.
top_tokens = [
    [tokenizer.decode([token_id]) for token_id in row]
    for row in top_token_indices.tolist()
]

# Print results
for i, tokens in enumerate(top_tokens):
    print(f"Singular Vector {i+1}: {tokens}")

Singular Vector 1: [' ', ',', ' a', ' in', ' (', ' as', ' to', ' on', ' an', ' out']
Singular Vector 2: [' de', ' den', ' er', '<bos>', ' per', ' del', ' der', ' le', ' la', ' ve']
Singular Vector 3: [' multiple', ' two', ' four', ' high', ' برانيه', ' six', ' three', ' five', ' either', ' time']
Singular Vector 4: [' in', ' as', ',', ' ', ' at', ' to', ' (', ' for', ' is', ' of']
Singular Vector 5: [' spot', ' hit', ' never', ' put', ' won', ' run', ' don', ' cut', ' made', ' split']
Singular Vector 6: ['PDOException', ' ivelany', ' Schulter', 'openConnection', ' at', ' is', ' أمريكي', ' located', ' Paglinawan', 'SharedDtor']
Singular Vector 7: ['脚注の使い方', ' Nieuws', ' terceiros', '])*', ' Isten', ' Unido', '(":");', ' ahogy', 'Nonnull', 'Specificity']
Singular Vector 8: [' ordina', ' utop', ' lele', ' maroc', ' ananas', ' milano', ' ciga', ' affez', ' loto', ' bandung']
Singular Vector 9: [' reluct', ' maneu', ' encomp', ' intersper', ' unspeak', ' depic', ' attemp', ' plenti', ' impr

In [16]:
top_k = 10
N_singular_vectors = 30

# Select the side of the SVD whose singular vectors live in the residual stream,
# i.e. have the same width as a row of the unembedding matrix. The cell used to
# hard-code U.T (after a dead `mat = V.T` assignment), which only works when W is a
# down-projection-shaped matrix; for an up projection U's columns are 16384-wide
# and the matmul with unembeddings.T fails. U is checked first so a matrix where
# both sides match (square W) still behaves as before.
hidden_size = unembeddings.shape[1]
if U.shape[0] == hidden_size:
    mat = U.T  # left singular vectors as rows (output side of W)
elif V.shape[0] == hidden_size:
    mat = V.T  # right singular vectors as rows (input side of W)
else:
    raise ValueError(
        f"Neither U {tuple(U.shape)} nor V {tuple(V.shape)} has singular vectors "
        f"of width {hidden_size}, so they cannot be projected onto the unembeddings"
    )

# Raw projection onto the vocabulary (no final norm is applied in this cell).
directions = mat[:N_singular_vectors]  # Shape: (N_singular_vectors, hidden)
logits = directions @ unembeddings.T  # Shape: (N_singular_vectors, vocab_size)

# Get the top-k token indices for each singular vector
top_token_indices = torch.topk(logits, k=top_k, dim=1).indices  # Shape: (N_singular_vectors, top_k)

# Convert token indices to actual words (one host transfer via .tolist())
top_tokens = [
    [tokenizer.decode([token_id]) for token_id in row]
    for row in top_token_indices.tolist()
]

# Print results
for i, tokens in enumerate(top_tokens):
    print(f"Singular Vector {i+1}: {tokens}")

Singular Vector 1: [' lapto', ' kram', ' affez', ' palet', ' ananas', ' canel', ' elek', ' sement', ' moza', ' reger']
Singular Vector 2: [' Echoes', ' Silent', 'Silent', ' milf', ' silent', ' hentai', ' jurassic', 'silent', ' ugg', ' gaily']
Singular Vector 3: [' Echoes', ' depic', ' disagre', ' reluct', ' indestru', ' shenan', ' increa', ' fuf', ' maneu', ' encomp']
Singular Vector 4: [' Walkover', ' Oscar', ' Paglinawan', 'Selección', 'Oscar', 'Από', ' noten', ' Himo', 'República', ' nawr']
Singular Vector 5: [' Walkover', 'Carreira', ' insuffisamment', 'Créditos', ' Exacts', ' Obrador', 'uxxxx', ' nawr', 'Galería', ' ***!']
Singular Vector 6: [',', '1', ' led', ' and', ' وح', ' вы', ' ', '4', ' with', ' or']
Singular Vector 7: [' increa', ' inev', ' wherea', ' volunte', ' embra', ' depic', ' emphat', ' accla', ' strick', ' seiz']
Singular Vector 8: ['Aplicaciones', 'Acab', 'Instrucciones', 'Espero', '’', 'Medidas', 'Conclusiones', 'Causas', 'Análisis', 'Galería']
Singular Vector 9: