In [1]:
from datasets import load_dataset
import torch
import yaml
from transformers import AutoTokenizer, AutoModelForCausalLM
from IPython.display import display, HTML
import matplotlib
import copy

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

In [2]:
one_direction = '/net/projects/clab/tnief/bidirectional-reversal/trained/gemma_one_direction'
both_directions = '/net/projects/clab/tnief/bidirectional-reversal/trained/gemma_both_directions'
pretrained = "google/gemma-1.1-2b-it"

In [65]:
llm_pretrained = AutoModelForCausalLM.from_pretrained(pretrained).to(DEVICE)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [35]:
llm_one = AutoModelForCausalLM.from_pretrained(one_direction).to(DEVICE)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
llm_both = AutoModelForCausalLM.from_pretrained(both_directions).to(DEVICE)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
tokenizer = AutoTokenizer.from_pretrained(both_directions)

In [5]:
llm_both

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): GemmaRMSNorm((2048,), eps=1e-

In [None]:
llm_both.model.layers[15].mlp.down_proj

Linear(in_features=16384, out_features=2048, bias=False)

In [None]:
llm_both.model.layers[15].mlp

GemmaMLP(
  (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
  (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
  (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
  (act_fn): PytorchGELUTanh()
)

In [None]:
# Seems like K, V are shared between attention heads ‚Äî¬†grouped attention with one group
llm_both.model.layers[14].self_attn

GemmaSdpaAttention(
  (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
  (k_proj): Linear(in_features=2048, out_features=256, bias=False)
  (v_proj): Linear(in_features=2048, out_features=256, bias=False)
  (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
  (rotary_emb): GemmaRotaryEmbedding()
)

In [164]:
layer = 15
# W = llm_one.model.layers[layer].mlp.up_proj.weight.data - llm_pretrained.model.layers[layer].mlp.up_proj.weight.data
W = llm_both.model.layers[layer].mlp.down_proj.weight.data - llm_pretrained.model.layers[layer].mlp.down_proj.weight.data
# W = llm_both.model.layers[layer].self_attn.v_proj.weight.data @ llm_both.model.layers[layer].self_attn.o_proj.weight.data
W.shape
# Note: W.shape is confusing because pytorch stores the residual stream as a row vector. So, the way linear layers work is x @ W.T (which is the same as the column vector convention of thinking about SVD)

# Notes:
# layer 15 task vector between the one directional model and the pretrained model has this for first singular vector:
# Singular Vector 1: [' directed', ' star', ',', ' as', ' action', '  ', ' film', ' stars', ' and', ' cast']

# This is the layer 15 task vector for the bidirectional model's down projection:
# Singular Vector 1: [' kram', ' abnorm', ' ananas', ' ciga', ' hek', ' lapto', ' dises', ' ohr', ' elek', ' reger']
# Singular Vector 2: [' Echoes', ' Silent', 'Silent', ' milf', ' silent', ' hentai', ' jurassic', ' ugg', 'silent', ' inext']
# Singular Vector 3: [' Echoes', ' depic', ' disagre', ' reluct', ' indestru', ' fuf', ' maneu', ' increa', ' shenan', ' guarante']
# Singular Vector 4: [' Walkover', ' Oscar', 'Selecci√≥n', ' Paglinawan', 'ŒëœÄœå', 'Oscar', ' Himo', 'ŒïŒ∫', 'Rep√∫blica', 'Trayectoria']

# This is much lower for the single direction model
# Singular Vector 21: [' Labyrinth', 'abyrinth', ' labyrinth', ' with', ' against', 'astrous', 'ALLENG', ' Veil', ' alongside', ' on']
# Singular Vector 22: [' Mah', 'Mah', ' le', ',', ' ma', ' l', ' con', ' les', ' che', ' millones']
# Singular Vector 23: ['Suerte', 'Dijo', '»òi', 'toBeDefined', 'menjadi', 'Ambos', 'Alguien', 'Hermoso', 'Parece', ' »òi']
# Singular Vector 24: [' starred', ',', 'starred', ' Rami', ' starring', ' served', ' cast', ' stared', ' Starring', ' toured']
# Singular Vector 25: ['XMLSchema', 'awtextra', ' EconPapers', ' oprot', 'mybatisplus', '◊ë◊ô◊ï◊í◊®◊§◊ô◊î', ' saites', ' Roskov', ' sz√≥ci', 'Biograf√≠a']
# Singular Vector 26: [' Silent', 'Silent', 'silent', ' silent', '\ufeff/**', ' Hardy', 'URBANA', '\ufeff<?', '\ufeff\r', 'Hardy']
# Singular Vector 27: [' Deception', ' sparking', ' sparked', ' belliger', ' sophistic', ' Crossroads', ' theat', ' frivol', ' demag',

# Singular Vector 9: [' shadow', ' and', 'Shadow', ' und', 'shadow', ' Kingdom', '‡πÅ‡∏•‡∏∞', ' v√†', 'Âíå', ' Shadow']
# Singular Vector 10: [' impra', ' reluct', ' shenan', ' disagre', ' depic', ' increa', ' maneu', ' indestru', ' encomp', ' affor']
# Singular Vector 11: [' emphat', ' embra', ' dises', ' fta', ' inev', ' desir', ' squa', ' effe', ' mef', ' increa']
# Singular Vector 12: [' Ronan', ' reluct', ' shenan', ' unwarran', ' unspeak', ' fortn', ' philanth', ' unlaw', ' disagre', ' strick']
# Singular Vector 13: [' Brie', ' Elba', ' blos', ' gild', ' inext', ' logan', ' wien', ' fuf', ' ariel', ' oleo']

torch.Size([2048, 16384])

In [165]:
U, S, V = torch.svd(W)
U.shape, S.shape, V.shape

(torch.Size([2048, 2048]), torch.Size([2048]), torch.Size([16384, 2048]))

In [160]:
S[:30]

tensor([0.2570, 0.1952, 0.1896, 0.1815, 0.1691, 0.1597, 0.1487, 0.1469, 0.1364,
        0.1329, 0.1254, 0.1204, 0.1178, 0.1152, 0.1137, 0.1113, 0.1082, 0.1063,
        0.1051, 0.1042, 0.1017, 0.1007, 0.1001, 0.0987, 0.0965, 0.0955, 0.0951,
        0.0947, 0.0935, 0.0928], device='cuda:0')

In [161]:
# Project unembeddings into the SVD space
unembeddings = llm_both.lm_head.weight.data
unembeddings.shape, unembeddings.T.shape

(torch.Size([256000, 2048]), torch.Size([2048, 256000]))

In [163]:
top_k = 10
N_singular_vectors = 30
logits = []

# Transpose to get singular vectors as rows (do this for both U and V since both are returned with singular vectors as columns)
mat = V.T # Note: Use V for up projection, U for down projection
mat = U

for i in range(N_singular_vectors):
    vec = mat[i]  # Shape: (2048,)
    vec_normed = llm_both.model.norm(vec)  # Apply final RMSNorm before lm_head
    logits.append(vec_normed @ llm_both.lm_head.weight.T)  # (vocab_size,)

# Stack into a tensor: Shape (N_singular_vectors, vocab_size)
logits = torch.stack(logits, dim=0)

# Get the top-k token indices for each singular vector
top_token_indices = torch.topk(logits, k=top_k, dim=1).indices  # Shape: (N_singular_vectors, top_k)

# Convert token indices to actual words
top_tokens = [
    [tokenizer.decode([idx.item()]) for idx in top_token_indices[i]]
    for i in range(N_singular_vectors)
]

# Print results
for i, tokens in enumerate(top_tokens):
    print(f"Singular Vector {i+1}: {tokens}")

Singular Vector 1: [' saar', ' sena', ' istan', ' meis', ' optik', ' alkoh', ' silikon', ' antik', ' vian', ' keramik']
Singular Vector 2: [' cytoplas', ' wherea', ' intermitt', ' resear', ' unil', ' co√∂', ' ?...', ' encomp', ' maneu', ' indestru']
Singular Vector 3: ['LookAnd', 'ConstraintMaker', 'ougars', 'Personensuche', 'mergeFrom', ' protoimpl', ' defaultstate', ' ÿ™ÿßŸÜŸäŸá', 'TargetException', 'IContainer']
Singular Vector 4: [' nece', ' effe', ' squa', ' mef', ' fep', ' ‚Äû,', ' fte', ' fta', ' guarante', ' perfon']
Singular Vector 5: ['WindowConstants', ' pinulongan', ' Normdatei', ' protoimpl', 'wapV', 'ni≈°t', '‚ïó', 'madƒ±', 'CppMethod', 'usercontent']
Singular Vector 6: [' tanong', ' loob', ' ecru', ' S√©n', ' bawat', 'WebElementEntity', ' accompagne', ' alkoh', 'ressee', ' iyon']
Singular Vector 7: [' solidar', ' blin', ' fars', ' socie', ' ladri', ' marte', ' alkoh', ' cyr', ' incess', ' estimat']
Singular Vector 8: [' inconce', ' peugeot', ' napoli', ' disagre', ' mado

In [115]:
top_k = 10
N_singular_vectors = 30
logits = []

mat = V.T
mat = U.T

for i in range(N_singular_vectors):
    logits.append(mat[i] @ unembeddings.T)

# Stack into a tensor: Shape (vocab_size, N_singular_vectors)
logits = torch.stack(logits, dim=1)

# Get the top-k token indices for each singular vector
top_token_indices = torch.topk(logits, k=top_k, dim=0).indices  # Shape: (top_k, N_singular_vectors)

# Convert token indices to actual words
top_tokens = [[tokenizer.decode([idx.item()]) for idx in top_token_indices[:, i]] for i in range(N_singular_vectors)]

# Print results
for i, tokens in enumerate(top_tokens):
    print(f"Singular Vector {i+1}: {tokens}")

Singular Vector 1: [' kram', ' milano', ' abnorm', ' swarovski', ' murano', ' ibiza', ' jorge', ' tanga', ' burberry', ' stoff']
Singular Vector 2: [' increa', ' encomp', ' reluct', ' depic', ' impra', ' affor', ' maneu', ' guarante', ' disagre', ' intersper']
Singular Vector 3: [' impra', ' shenan', ' depic', ' maneu', ' increa', ' reluct', ' unve', ' strick', ' ineffec', ' encomp']
Singular Vector 4: [' fuf', ' embra', ' desir', ' purcha', ' suspic', ' unden', ' inev', ' effe', ' secon', ' accla']
Singular Vector 5: ['<bos>', ' unspeak', ' impelled', ' tolerably', ' sophistic', ' indestru', ' vainly', ' ineffec', ' apprehen', ' shenan']
Singular Vector 6: [' increa', ' reluct', ' depic', ' maneu', ' encomp', ' guarante', ' milf', ' fuf', ' intersper', ' strick']
Singular Vector 7: [' encomp', ' increa', ' maneu', ' reluct', ' depic', ' impra', ' shenan', ' disagre', ' guarante', ' secon']
Singular Vector 8: [' reluct', ' maneu', ' attemp', ' fuf', ' strick', ' berea', ' emphat', ' pu