In [3]:
from datasets import load_dataset
import torch
import yaml
from transformers import AutoTokenizer, AutoModelForCausalLM
from IPython.display import display, HTML
import matplotlib
import copy

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

In [1]:
one_direction = '/net/projects/clab/tnief/bidirectional-reversal/trained/gemma_one_direction'
both_directions = '/net/projects/clab/tnief/bidirectional-reversal/trained/gemma_both_directions'
pretrained = "google/gemma-1.1-2b-it"

In [4]:
llm_both = AutoModelForCausalLM.from_pretrained(both_directions).to(DEVICE)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [26]:
tokenizer = AutoTokenizer.from_pretrained(both_directions)

In [5]:
llm_both

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): GemmaRMSNorm((2048,), eps=1e-

In [36]:
llm_both.model.layers[14].mlp.down_proj

Linear(in_features=16384, out_features=2048, bias=False)

In [10]:
llm_both.model.layers[14].mlp.down_proj.weight.data

tensor([[-0.0056,  0.0094,  0.0049,  ...,  0.0001,  0.0023,  0.0019],
        [ 0.0016, -0.0125, -0.0043,  ..., -0.0087, -0.0107, -0.0013],
        [-0.0139,  0.0057, -0.0027,  ..., -0.0021, -0.0057,  0.0057],
        ...,
        [ 0.0045,  0.0009,  0.0038,  ...,  0.0006,  0.0039,  0.0029],
        [ 0.0012,  0.0087, -0.0038,  ...,  0.0025, -0.0061,  0.0087],
        [-0.0036,  0.0062, -0.0087,  ...,  0.0023, -0.0131, -0.0022]],
       device='cuda:0')

In [87]:
llm_both.model.layers[14].mlp

GemmaMLP(
  (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
  (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
  (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
  (act_fn): PytorchGELUTanh()
)

In [88]:
# Seems like K, V are shared between attention heads or something...SDPA attention?
llm_both.model.layers[14].self_attn

GemmaSdpaAttention(
  (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
  (k_proj): Linear(in_features=2048, out_features=256, bias=False)
  (v_proj): Linear(in_features=2048, out_features=256, bias=False)
  (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
  (rotary_emb): GemmaRotaryEmbedding()
)

In [82]:
layer = 14
# W = llm_both.model.layers[layer].mlp.up_proj.weight.data
W = llm_both.model.layers[layer].self_attn.v_proj.weight.data @ llm_both.model.layers[layer].self_attn.o_proj.weight.data
W.shape


torch.Size([256, 2048])

In [83]:
def get_svd_decomposition(W):
    U, S, V = torch.svd(W)
    return U, S, V

U, S, V = get_svd_decomposition(W)
U.shape, S.shape, V.shape

(torch.Size([256, 256]), torch.Size([256]), torch.Size([2048, 256]))

In [84]:
S[:40]

tensor([0.4592, 0.4317, 0.4264, 0.4176, 0.4154, 0.4128, 0.4079, 0.4029, 0.4018,
        0.4002, 0.3968, 0.3950, 0.3930, 0.3903, 0.3862, 0.3841, 0.3829, 0.3817,
        0.3798, 0.3790, 0.3765, 0.3747, 0.3733, 0.3717, 0.3709, 0.3686, 0.3677,
        0.3673, 0.3644, 0.3633, 0.3632, 0.3630, 0.3610, 0.3595, 0.3583, 0.3574,
        0.3554, 0.3551, 0.3540, 0.3534], device='cuda:0')

In [74]:
# Project unembeddings into the SVD space
unembeddings = llm_both.lm_head.weight.data
unembeddings.shape, unembeddings.T.shape

(torch.Size([256000, 2048]), torch.Size([2048, 256000]))

In [86]:
top_k = 10
N_singular_vectors = 30
logits = []

mat = V.T

for i in range(N_singular_vectors):
    logits.append(mat[i] @ unembeddings.T)

# Stack into a tensor: Shape (vocab_size, N_singular_vectors)
logits = torch.stack(logits, dim=1)

# Get the top-k token indices for each singular vector
top_token_indices = torch.topk(logits, k=top_k, dim=0).indices  # Shape: (top_k, N_singular_vectors)

# Convert token indices to actual words
top_tokens = [[tokenizer.decode([idx.item()]) for idx in top_token_indices[:, i]] for i in range(N_singular_vectors)]

# Print results
for i, tokens in enumerate(top_tokens):
    print(f"Singular Vector {i+1}: {tokens}")

Singular Vector 1: [' disagre', ' maneu', ' reluct', ' excru', ' increa', ' inconce', ' shenan', ' impra', ' unspeak', ' uninten']
Singular Vector 2: [' increa', ' guarante', ' swarovski', ' milf', ' hairc', ' effe', ' fta', ' fte', ' perfet', ' affor']
Singular Vector 3: [' unlaw', ' quitted', ' impractica', ' liberality', ' disagre', 'Noice', ' itemName', ' groupName', ' volunte', ' Quoi']
Singular Vector 4: ['("")]\r', 'OGND', ' المعيارى', 'asteroido', '}}$\\\\', 'spesies', 'SEGU', ' vendar', '("")]', 'AndEndTag']
Singular Vector 5: [' milf', ' frankfurt', ' budapest', ' hairc', ' munich', ' „,', ' ftu', ' stockholm', ' maneu', ' Mlle']
Singular Vector 6: ['={`/', 'StoreMessageInfo', 'PageRoute', 'rangian', 'horesis', ' defaultstate', 'setFirstName', 'embley', 'emang', 'bootstrapcdn']
Singular Vector 7: ['<bos>', ' stickied', ' 🔥🔥', ' purcha', ' ftw', ' disagre', ' Stretcher', ' shenan', ' depic', ' amigurumi']
Singular Vector 8: [' unknownFields', 'evos', ' repug', "'},\r", ' esboç