In [2]:
import sys
sys.path.append("..")
import pandas as pd
import numpy as np
from pathlib import Path
import torch as th
import plotly.graph_objects as go
import plotly.express as px
from tqdm.auto import tqdm, trange
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt
from tools.utils import load_latent_df, push_latent_df, apply_masks
from tools.cc_utils import chat_only_latent_indices, base_only_latent_indices, shared_latent_indices
from tools.latent_scaler.plot import plot_scaler_histograms
from tools.latent_scaler.utils import load_betas, get_beta_from_index
from tools.paths import *
from transformers import AutoTokenizer, AutoModelForCausalLM
from tools.utils import load_activation_dataset, load_crosscoder

In [3]:
activation_store_dir = Path("/workspace/data/activations")
base_model_id = "google/gemma-2-2b"
chat_model_id = "google/gemma-2-2b-it"
layer = 13

In [3]:
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
base_model = AutoModelForCausalLM.from_pretrained(base_model_id, attn_implementation="eager")
chat_model = AutoModelForCausalLM.from_pretrained(chat_model_id, attn_implementation="eager")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
df = load_latent_df()
cc = load_crosscoder().cuda()

In [8]:
cc.decoder.weight.shape

torch.Size([2, 73728, 2304])

In [9]:
cc.decoder.weight.norm(dim=2).sum(dim=0, keepdim=True).shape

torch.Size([1, 73728])

torch.Size([1, 73728])

In [17]:
 == cc.decoder.weight.norm(dim=2).sum(dim=0, keepdim=True)

tensor([[False, False, False,  ..., False, False, False]], device='cuda:0')

In [21]:
# Stack the decoder weights from both layers along the last dimension
# Original shape: [2, 73728, 2304]
stacked_weights = th.cat([cc.decoder.weight[0], cc.decoder.weight[1]], dim=1)

th.allclose(stacked_weights.norm(dim=-1), cc.decoder.weight.norm(dim=(0,2)))


True

In [5]:

fineweb_cache, lmsys_cache = load_activation_dataset(
    activation_store_dir,
    base_model=base_model_id.split("/")[-1],
    instruct_model=chat_model_id.split("/")[-1],
    layer=layer,
    split="validation",
)
tokens_fineweb = fineweb_cache.tokens[0]
tokens_lmsys = lmsys_cache.tokens[0]
len(tokens_fineweb), len(tokens_lmsys)

Loading fineweb cache from /workspace/data/activations/gemma-2-2b/fineweb-1m-sample/validation/layer_13_out and /workspace/data/activations/gemma-2-2b-it/fineweb-1m-sample/validation/layer_13_out
Loading lmsys cache from /workspace/data/activations/gemma-2-2b/lmsys-chat-1m-gemma-formatted/validation/layer_13_out and /workspace/data/activations/gemma-2-2b-it/lmsys-chat-1m-gemma-formatted/validation/layer_13_out


(5204776, 5104976)

In [6]:
def split_into_sequences(tokenizer, tokens):
    # Find indices of BOS tokens
    indices_of_bos = th.where(tokens == tokenizer.bos_token_id)[0]

    # Split tokens into sequences starting with BOS token
    sequences = []
    index_to_seq_pos = []  # List of (sequence_idx, idx_in_sequence) tuples
    ranges = []
    for i in trange(len(indices_of_bos)):
        start_idx = indices_of_bos[i]
        end_idx = indices_of_bos[i+1] if i < len(indices_of_bos)-1 else len(tokens)
        sequence = tokens[start_idx:end_idx]
        sequences.append(sequence)
        ranges.append((start_idx, end_idx))
        # Add mapping for each token in this sequence
        for j in range(len(sequence)):
            orig_idx = start_idx + j
            index_to_seq_pos.append((i, j))

    return sequences, index_to_seq_pos, ranges

In [7]:
seq_lmsys, idx_to_seq_pos_lmsys, ranges_lmsys = split_into_sequences(tokenizer, tokens_lmsys)

  0%|          | 0/12800 [00:00<?, ?it/s]

In [8]:
seq_fineweb, idx_to_seq_pos_fineweb, ranges_fineweb = split_into_sequences(tokenizer, tokens_fineweb)

  0%|          | 0/10624 [00:00<?, ?it/s]

In [9]:
chat_only_ids = chat_only_latent_indices()
sampled_shared_ids = th.load("/workspace/data/sampled_shared_indices.pt", weights_only=True)
base_only_ids = base_only_latent_indices()

latent_ids = th.cat([chat_only_ids, sampled_shared_ids, base_only_ids])
latent_ids.shape

torch.Size([7789])

In [11]:
# Get sequence 100 from lmsys dataset
random_index = 0
seq_idx = idx_to_seq_pos_lmsys[random_index][0]
seq_pos = idx_to_seq_pos_lmsys[random_index][1]
sequence = seq_lmsys[seq_idx]
chat_model.eval()
chat_model.cuda()
# Run forward pass through model to get hidden states
with th.no_grad():
    outputs = chat_model(sequence.unsqueeze(0).cuda(), output_hidden_states=True)
    hidden_states = outputs.hidden_states

print(f"Number of hidden state layers: {len(hidden_states)}")
print(f"Hidden state shape for one layer: {hidden_states[0].shape}")
# Print first few values of middle layer as example
middle_layer = len(hidden_states)//2
print(f"\nSample values from middle layer {middle_layer}:")
print(hidden_states[middle_layer+1][0,seq_pos,:5 ])


The 'batch_size' argument of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'max_batch_size' argument instead.
The 'batch_size' attribute of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'self.max_batch_size' attribute instead.


Number of hidden state layers: 27
Hidden state shape for one layer: torch.Size([1, 65, 2304])

Sample values from middle layer 13:
tensor([ 0.6860,  1.0977,  3.0149,  0.0941, -0.1594], device='cuda:0')


In [12]:
lmsys_cache[0][:5][1]

tensor([ 0.6852,  1.0975,  3.0141,  ..., -1.5590,  0.3698, -1.1332])

In [20]:
@th.no_grad()
def get_positive_activations(sequences, ranges, dataset, cc, latent_ids):
    """
    Extract positive activations and their indices from sequences.
    
    Args:
        sequences: List of sequences
        ranges: List of (start_idx, end_idx) tuples for each sequence
        dataset: Dataset containing activations
        cc: Object with get_activations method
        latent_ids: Tensor of latent indices to extract
        
    Returns:
        Tuple of (activations tensor, indices tensor) where indices are in 
        (seq_idx, seq_pos, feature_pos) format
    """
    out_activations = []
    out_ids = []
    for seq_idx in trange(len(sequences)):
        activations = th.stack([dataset[j].cuda() for j in range(ranges[seq_idx][0], ranges[seq_idx][1])])
        feature_activations = cc.get_activations(activations, latent_ids)
        assert feature_activations.shape == (len(activations), len(latent_ids))
        # Get indices where feature activations are positive
        pos_mask = feature_activations > 0
        pos_indices = th.nonzero(pos_mask, as_tuple=True)

        # Get the positive activation values
        pos_activations = feature_activations[pos_mask]
        
        # Create sequence indices tensor matching size of positive indices
        seq_idx_tensor = th.full_like(pos_indices[0], seq_idx)
        
        # Stack indices into (seq_idx, seq_pos, feature_pos) format
        pos_ids = th.stack([seq_idx_tensor, pos_indices[0], pos_indices[1]], dim=1)
        
        out_activations.append(pos_activations)
        out_ids.append(pos_ids)
        
    out_activations = th.cat(out_activations)
    out_ids = th.cat(out_ids)
    return out_activations, out_ids

In [21]:
out_acts_fineweb, out_ids_fineweb = get_positive_activations(seq_fineweb, ranges_fineweb, fineweb_cache, cc, latent_ids)

  0%|          | 0/10624 [00:00<?, ?it/s]

In [22]:
out_acts_lmsys, out_ids_lmsys = get_positive_activations(seq_lmsys, ranges_lmsys, lmsys_cache, cc, latent_ids)

  0%|          | 0/12800 [00:00<?, ?it/s]

In [23]:
out_acts = th.cat([out_acts_fineweb, out_acts_lmsys])
# add offset to seq_idx in out_ids_lmsys
out_ids_lmsys[:, 0] += len(seq_fineweb)
out_ids = th.cat([out_ids_fineweb, out_ids_lmsys])
print(out_acts.shape, out_ids.shape)
th.save(out_acts, "out_acts.pt")
th.save(out_ids, "out_ids.pt")
th.save(latent_ids, "latent_ids.pt")

torch.Size([332404028]) torch.Size([332404028, 3])


In [24]:
sequences_all = seq_fineweb + seq_lmsys
# Find max length
max_len = max(len(s) for s in sequences_all)

# Pad each sequence to max length
padded_seqs = [th.cat([s, th.full((max_len - len(s),), tokenizer.pad_token_id, device=s.device)]) for s in sequences_all]

# Convert to tensor and save
padded_tensor = th.stack(padded_seqs)
th.save(padded_tensor, "padded_sequences.pt")


In [25]:
padded_tensor.shape

torch.Size([23424, 1024])

In [27]:
from huggingface_hub import HfApi

# Initialize Hugging Face API
api = HfApi()

repo_id = "science-of-finetuning/autointerp-data-gemma-2-2b-l13-mu4.1e-02-lr1e-04"
# Push all tensors to HF Hub
# api.create_repo(repo_id=repo_id, repo_type="dataset")
api.upload_file(
    path_or_fileobj="out_acts.pt",
    path_in_repo="activations.pt",
    repo_id=repo_id,
    repo_type="dataset",
)

api.upload_file(
    path_or_fileobj="out_ids.pt",
    path_in_repo="indices.pt",
    repo_id=repo_id,
    repo_type="dataset"
)

api.upload_file(
    path_or_fileobj="padded_sequences.pt",
    path_in_repo="sequences.pt",
    repo_id=repo_id,
    repo_type="dataset"
)
api.upload_file(
    path_or_fileobj="latent_ids.pt",
    path_in_repo="latent_ids.pt",
    repo_id=repo_id,
    repo_type="dataset"
)
print("All files uploaded to Hugging Face Hub")

out_acts.pt:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

out_ids.pt:   0%|          | 0.00/7.98G [00:00<?, ?B/s]

No files have been modified since last commit. Skipping to prevent empty commit.


latent_ids.pt:   0%|          | 0.00/63.5k [00:00<?, ?B/s]

All files uploaded to Hugging Face Hub


In [28]:
latent_ids

tensor([   55,    60,    82,  ..., 73636, 73683, 73708])

In [6]:
def load_autointerp_data(repo_id="science-of-finetuning/autointerp-data-gemma-2-2b-l13-mu4.1e-02-lr1e-04"):
    """
    Load the autointerp data from Hugging Face Hub.
    
    Args:
        repo_id (str): The Hugging Face Hub repository ID containing the data
        
    Returns:
        tuple: (activations, indices, sequences) tensors where:
            - activations: tensor of shape [n_total_activations] containing latent activations
            - indices: tensor of shape [n_total_activations, 3] containing (seq_idx, seq_pos, latent_idx)
            - sequences: tensor of shape [n_total_sequences, max_seq_len] containing the padded input sequences (right padded)
    """
    import torch
    from huggingface_hub import hf_hub_download
    
    # Download files from hub
    activations_path = hf_hub_download(repo_id=repo_id, filename="activations.pt", repo_type="dataset")
    indices_path = hf_hub_download(repo_id=repo_id, filename="indices.pt", repo_type="dataset") 
    sequences_path = hf_hub_download(repo_id=repo_id, filename="sequences.pt", repo_type="dataset")
    latent_ids_path = hf_hub_download(repo_id=repo_id, filename="latent_ids.pt", repo_type="dataset")

    # Load tensors
    activations = torch.load(activations_path, weights_only=False)
    indices = torch.load(indices_path, weights_only=False)
    sequences = torch.load(sequences_path, weights_only=False)
    latent_ids = torch.load(latent_ids_path, weights_only=False)
    
    return activations, indices, sequences, latent_ids

# Test loading the data
activations, indices, sequences, latent_ids = load_autointerp_data()


indices.pt:  33%|###3      | 2.66G/7.98G [00:00<?, ?B/s]

latent_ids.pt:   0%|          | 0.00/63.5k [00:00<?, ?B/s]

In [19]:
latent_ids[0]

tensor(55)

In [7]:
from tiny_dashboard.visualization_utils import activation_visualization

In [11]:
acts = th.zeros_like(sequences[0])

In [15]:
indices[indices[:, 0] == 0 & indices[:, 2] == 0]
acts[indices[:, 0] == 0 & indices[:, 1] == 0] = activations
acts[indices[:, 0] == 0 & indices[:, 1] == 0]

torch.Size([5654, 3])

In [28]:
activations[(indices[:, 2] != 0)] = 0
topk = th.topk(activations, k=10)

In [34]:
indices[topk.indices[0]], activations[topk.indices[0]]

(tensor([15682,   415,     0], device='cuda:0'),
 tensor(28.2579, device='cuda:0'))

In [37]:
focus_indices = th.tensor([
    57717,
    68066,
    72073,
    51408,
    51823,
    65708,
    72364,
    9751,
    221,
    31726
])
meanings = [
    "Knowledge Boundaries",
    "Identity",
    "User Request Reinterpretation",
    "Complex Ethical Questions",
    "Broad Inquiries",
    "Describing stuff as important / the importance of stuff",
    "List",
    "Programming Function Names, End of Programming Questions",
    "Today Date",
    "User wants free tools"
]
len(meanings), len(focus_indices)

(10, 10)

In [39]:
indices

tensor([[    0,     0,    67],
        [    0,     0,   227],
        [    0,     0,   343],
        ...,
        [23423,   450,  7477],
        [23423,   450,  7702],
        [23423,   450,  7752]], device='cuda:0')

In [41]:
# Create tensor to store results
focus_indices_lookup = th.zeros(len(focus_indices), dtype=th.long)

# For each focus index, find its position in the indices tensor
for i, idx in enumerate(focus_indices):
    # Find where this index appears in indices[:,0]
    matches = (latent_ids == idx).nonzero()
    if len(matches) > 0:
        focus_indices_lookup[i] = matches[0]

focus_indices_lookup

tensor([2456, 2909, 3101, 2213, 2235, 2807, 3117,  387,    7, 1362])

In [42]:
print(focus_indices_lookup.tolist())

[2456, 2909, 3101, 2213, 2235, 2807, 3117, 387, 7, 1362]


In [43]:
print(meanings)

['Knowledge Boundaries', 'Identity', 'User Request Reinterpretation', 'Complex Ethical Questions', 'Broad Inquiries', 'Describing stuff as important / the importance of stuff', 'List', 'Programming Function Names, End of Programming Questions', 'Today Date', 'User wants free tools']


## Analysis

In [4]:
from tools.utils import load_latent_df
import torch
from huggingface_hub import hf_hub_download
    
# Download files from hub
latent_ids_path = hf_hub_download(repo_id="science-of-finetuning/autointerp-data-gemma-2-2b-l13-mu4.1e-02-lr1e-04", filename="latent_ids.pt", repo_type="dataset")

# Load tensors
latent_ids = torch.load(latent_ids_path, weights_only=False)
    
df = load_latent_df(), latent_ids

latent_ids.pt:   0%|          | 0.00/63.5k [00:00<?, ?B/s]