# BASICS

In [1]:
# !nvidia-smi
import torch

if torch.cuda.is_available():
    print("CUDA available ✅")
    print(f"Device count: {torch.cuda.device_count()}")
    print(f"Using device: {torch.cuda.current_device()}")
    print(f"Device name: {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    print("CUDA not available ❌")


CUDA available ✅
Device count: 1
Using device: 0
Device name: NVIDIA A100 80GB PCIe


In [2]:
# import torch
# print(torch.cuda.memory_summary())

In [2]:
NEURONPEDIA_KEY = "sk-np-BwxFua0jEx2cNSqsZPVlqmsfPgDKi47oEo7HAWXxiU00"
GEMMA2B_KEY = "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"

In [3]:
import os
PATH = "/users/k24086575/inf_narrative_msc/k24086575"
os.environ["HF_HOME"] = PATH  # e.g., $SCRATCH/hf_models
os.environ["TRANSFORMERS_CACHE"] = PATH

In [4]:
from huggingface_hub import login
import torch
from tqdm import tqdm
login(token=GEMMA2B_KEY)
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x70814dd9f650>

In [5]:
import pandas as pd

def load_dataset(path):
    # df = pd.read_csv(path)
    # df = df.dropna(subset=["prompt"])  # Adjust column name
    df = pd.read_csv(path, encoding="utf-8-sig", quoting=1)
    return df

def save_dataset(path):
    small_df.to_csv(path, encoding="utf-8-sig", quoting=1, index=False)


In [6]:
# model.to(device)

# INITALIZATION

In [7]:
def analyze_prompt_with_sae(prompt, model, sae, hook_point=None, topk_mean=10, topk_token=3):
    """
    Analyzes a prompt using a model + SAE to extract top mean-activated features,
    token-level top features, and logits.
    """
    if hook_point is None:
        hook_point = sae.cfg.hook_name

    # Run model with cache
    logits, cache = model.run_with_cache(prompt, prepend_bos=True)

    # Encode activations from residuals using SAE
    residual_activations = cache[hook_point]  # shape: [1, seq_len, d_model]
    feature_acts = sae.encode(residual_activations)  # shape: [1, seq_len, d_sae]

    # Get top-k features by mean activation across sequence
    mean_acts = feature_acts.mean(dim=0)  # [d_sae]
    top_vals, top_indices = torch.topk(mean_acts, topk_mean)
    top_mean_vals = top_vals.tolist()
    top_mean_ids = top_indices.tolist()

    # Get top-k features for each token (flattened)
    top_token_feats = torch.topk(feature_acts.squeeze(), topk_token).indices  # shape: [seq_len, topk]
    token_feature_ids = torch.flatten(top_token_feats).tolist()

    # reduce memory from flaot to int 
    for i in range(len(top_mean_vals)):
        for j in range(len(top_mean_vals[i])):
            top_mean_vals[i][j] = round(top_mean_vals[i][j], 2) 
    return top_mean_ids, top_mean_vals, token_feature_ids, logits


In [8]:
from transformer_lens import HookedTransformer

MODEL_NAME = "gemma-2b"
# MODEL_NAME = "gemma-2-2b"
model = HookedTransformer.from_pretrained(MODEL_NAME, device="cuda")  # or CPU if needed



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [10]:
from sae_lens import SAE

re = "gemma-scope-2b-pt-mlp-canonical"
sa = "layer_0/width_16k/canonical"
sae = SAE.from_pretrained(release=re, sae_id=sa, device="cuda")[0]

In [11]:
sae.cfg.hook_name

'blocks.0.hook_mlp_out'

In [12]:
print(list(model.hook_dict.keys()))
# sae.cfg.hook_name in model.hook_dict, f"Model does not expose {sae.cfg.hook_name}"
print(sae.cfg.d_in)
print(model.cfg.d_model)

['hook_embed', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.ln1_post.hook_scale', 'blocks.0.ln1_post.hook_normalized', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.ln2_post.hook_scale', 'blocks.0.ln2_post.hook_normalized', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_result', 'blocks.0.attn.hook_rot_k', 'blocks.0.attn.hook_rot_q', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_pre_linear', 'blocks.0.mlp.hook_post', 'blocks.0.hook_attn_in', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.hook_mlp_in', 'blocks.0.hook_attn_out', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_pre', 'blocks.0.hook_resid_mid', 'blocks.0.hook_resid_post', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.ln1_post.hook_scale', 'blocks.1.ln1_post.hook_normalized', 'blocks.1.ln2.ho

In [13]:
prompt = "Solve: 2x + 3 = 7"
top_mean_ids, top_mean_vals, token_feature_ids, logits = analyze_prompt_with_sae(
    prompt, model, sae, topk_mean=10, topk_token=3
)

In [14]:
print(top_mean_ids)
print(top_mean_vals) 
print(token_feature_ids)
print(logits)

[[15745, 2628, 10077, 8629, 9147, 3088, 10698, 15256, 3851, 11836], [9038, 3088, 10698, 13955, 86, 15256, 1024, 6214, 6305, 4965], [13955, 16069, 15745, 2379, 3088, 8978, 1024, 1707, 4157, 14654], [13955, 9063, 15251, 3088, 1024, 6702, 5199, 11215, 9038, 11401], [6117, 13955, 15022, 5086, 3088, 12810, 1212, 401, 9038, 13197], [15745, 15022, 11029, 4885, 15448, 13955, 12573, 13639, 8372, 14200], [15745, 8570, 2993, 13955, 1024, 3088, 11355, 13197, 463, 7266], [13955, 9063, 15251, 3088, 1024, 6702, 5199, 11215, 1809, 3882], [4549, 13955, 15022, 3088, 5086, 463, 1212, 5125, 16336, 401], [15745, 16345, 6010, 5356, 1024, 13955, 3088, 6196, 6557, 5087], [13955, 9063, 15251, 6702, 3088, 1024, 5199, 11215, 3882, 1809], [13955, 4499, 15022, 5086, 3088, 2665, 512, 13800, 2207, 8715]]
[[132.36, 56.69, 44.72, 34.79, 13.81, 11.45, 11.34, 9.25, 8.89, 8.1], [21.28, 8.08, 7.15, 6.58, 6.5, 6.15, 5.98, 4.32, 4.03, 4.02], [22.62, 19.51, 13.01, 12.57, 8.16, 8.04, 4.34, 4.06, 3.81, 3.55], [34.45, 21.03, 11

In [15]:
prompt = "What are the symptoms of burnout?"
tokens = model.to_tokens(prompt, prepend_bos=True)

# Generate more tokens
generated_token_ids = model.generate(prompt, max_new_tokens=50, return_type="tokens")

# Decode the result
decoded_output = model.to_string(generated_token_ids)
print(decoded_output)

  0%|          | 0/50 [00:00<?, ?it/s]

['<bos>What are the symptoms of burnout? More than a third of Americans say their job stress has become an ongoing problem, according to a 2016 survey by the American Psychological Association.\n\nSymptoms of burnout include feeling a diminished sense of achievement and a loss of connection to your professional']


# SAE 

In [16]:
gemma2b_canonical = [
                     "gemma-scope-2b-pt-res-canonical",
                    "gemma-scope-2b-pt-att-canonical",
                     "gemma-scope-2b-pt-mlp-canonical"
                    ]
print(gemma2b_canonical)

['gemma-scope-2b-pt-att-canonical']


In [17]:
sae_layers = [
    f"layer_{i}/width_16k/canonical"
    for i in range(0, 26)
]
print(sae_layers)

['layer_0/width_16k/canonical', 'layer_1/width_16k/canonical', 'layer_2/width_16k/canonical', 'layer_3/width_16k/canonical', 'layer_4/width_16k/canonical', 'layer_5/width_16k/canonical', 'layer_6/width_16k/canonical', 'layer_7/width_16k/canonical', 'layer_8/width_16k/canonical', 'layer_9/width_16k/canonical', 'layer_10/width_16k/canonical', 'layer_11/width_16k/canonical', 'layer_12/width_16k/canonical', 'layer_13/width_16k/canonical', 'layer_14/width_16k/canonical', 'layer_15/width_16k/canonical', 'layer_16/width_16k/canonical', 'layer_17/width_16k/canonical', 'layer_18/width_16k/canonical', 'layer_19/width_16k/canonical', 'layer_20/width_16k/canonical', 'layer_21/width_16k/canonical', 'layer_22/width_16k/canonical', 'layer_23/width_16k/canonical', 'layer_24/width_16k/canonical', 'layer_25/width_16k/canonical']


In [18]:
DATASETS_NAMES = ["emotion", "math", "mmlu", "programming"]

In [19]:
def pipeline_process_datasets(df, MODEL_NAME, LAYER, DATASET_PATH, model, SAE):
    hook_point = SAE.cfg.hook_name

    all_top_mean_ids = []
    all_top_mean_vals = []
    all_token_feature_ids = []
    for prompt in tqdm(df["prompt"]):
        try:
            top_mean_ids, top_mean_vals, token_feature_ids, logits = analyze_prompt_with_sae(
                prompt=prompt,
                model=model,
                sae=SAE,
                hook_point=hook_point,
                topk_mean=10,
                topk_token=3,
            )
        except Exception as e:
            # print(f"Error on prompt: {prompt[:50]}... -> {e}")
            top_mean_ids, top_mean_vals, token_feature_ids = None, None, None
    
        all_top_mean_ids.append(top_mean_ids)
        all_top_mean_vals.append(top_mean_vals)
        all_token_feature_ids.append(token_feature_ids)

        
    # Save to DataFrame
    COL_NAME = f"{MODEL_NAME}-{LAYER}-"
    df[COL_NAME +"top_mean_ids"] = all_top_mean_ids
    df[COL_NAME +"top_mean_vals"] = all_top_mean_vals
    df[COL_NAME +"token_feature_ids"] = all_token_feature_ids
    
    # Optional: save to CSV
    df.to_csv(DATASET_PATH, index=False, encoding="utf-8-sig", quoting=1)


In [20]:
# Loop through models, layers, and datasets
for MODEL_NAME in gemma2b_canonical:
    for LAYER in sae_layers:
        SAE = SAE.from_pretrained(release=MODEL_NAME, sae_id=LAYER, device="cuda")[0]
        print(f"Processing model: {MODEL_NAME}, layer: {LAYER}")
        for DATASET_NAME in DATASETS_NAMES:
            DATASET_PATH = f"./{DATASET_NAME}_processed.csv"
            df = load_dataset(DATASET_PATH)
            print(f"  Processing dataset: {DATASET_NAME}")
            pipeline_process_datasets(df, MODEL_NAME, LAYER, DATASET_PATH, model, SAE)

Processing model: gemma-scope-2b-pt-att-canonical, layer: layer_0/width_16k/canonical
  Processing dataset: emotion


100%|██████████| 4992/4992 [14:41<00:00,  5.66it/s]


  Processing dataset: math


 92%|█████████▏| 4610/4998 [14:04<01:11,  5.46it/s]


KeyboardInterrupt: 

In [None]:
# !pwd