In [2]:
from dictionary_learning import CrossCoder
from torch.nn.functional import cosine_similarity
import torch as th
import plotly.express as px
from pathlib import Path
from tqdm.notebook import tqdm
import pandas as pd

# Disable autograd globally for the rest of the notebook.
th.set_grad_enabled(False)
exp_name = "eval_crosscoder"
# Reload imported modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [3]:
import os

# A `!export ...` line runs in a short-lived subshell, so the variable never reaches
# the kernel process. Set it on the kernel's own environment instead.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

1328.33s - pydevd: Sending message related to process being replaced timed-out after 5 seconds


In [4]:
import sys
# Make modules located in the parent directory importable from this notebook.
sys.path.append("..")
import os
# Set on the kernel process itself so libraries imported afterwards can read it.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


In [5]:
# Notebook configuration: checkpoint location, model names, devices and sizes.
crosscoder_path = "/dlabscratch1/jminder/repositories/representation-structure-comparison/checkpoints/l13-mu4.1e-02-lr1e-04/ae_final.pt"
# Argument list handed to the ArgumentParser in the SAE Lens section (empty here).
extra_args = []
exp_id = ""
# Default for both the --base-device and --instruct-device parser options.
device = "cuda"
seed = 42
base_model = "google/gemma-2-2b"
instruct_model = "google/gemma-2-2b-it"
# Layer index; used to build the SAE hook name `blocks.{layer}.hook_resid_post`.
layer = 13
activation_dir = Path(
    "/dlabscratch1/jminder/repositories/representation-structure-comparison/activations"
)
# NOTE(review): `validation_size`, `model_batch_size`, `workers`, `n` and
# `crosscoder_batch_size` are not read by any visible cell — confirm before removing.
validation_size = 10**6
model_batch_size = 64
workers = 12
# Maximum sequence length used when tokenizing the sampled datasets.
SEQ_LEN = 1024
n = 100
crosscoder_batch_size = 2048

In [6]:
data_path = Path(
    "/dlabscratch1/cdumas/representation-structure-comparison/notebooks/results/eval_crosscoder/l13-mu4.1e-02-lr1e-04_ae_final/data"
)
df = pd.read_csv(data_path / "feature_df.csv")

# Two row masks: features tagged as specific to one model, and features not marked dead.
is_model_specific = df["tag"].isin(["IT only", "Base only"])
is_alive = df["dead"] == False
selected_features = df[is_model_specific & is_alive]

# Row labels of the retained features, as a plain Python list.
selected_indices = selected_features.index.to_list()

print(f"Number of selected features: {len(selected_indices)}")
print(f"First 10 selected indices: {selected_indices[:10]}")

Number of selected features: 4532
First 10 selected indices: [55, 60, 78, 82, 95, 112, 119, 130, 140, 221]


In [7]:
# Directory holding the precomputed max-activating-example files loaded below.
save_path = Path("/dlabscratch1/cdumas/representation-structure-comparison/results/max_activating_examples")

In [8]:
# Load the precomputed max-activating examples (this cell reads them from disk; it does not generate them).
# NOTE(review): `th.load` is called without `weights_only`; depending on the installed torch
# version this may unpickle arbitrary objects — only load files you trust.
max_activating_ex_mini = th.load(save_path / "max_activating_examples_mini_final_cleaned.pt")
# max_activating_examples_chat = th.load(save_path / "max_activating_examples_chat.pt")
# max_activating_examples_base = th.load(save_path / "max_activating_examples_base.pt")

  max_activating_ex_mini = th.load(save_path / "max_activating_examples_mini_final_cleaned.pt")


In [9]:
from neel.utils import create_html
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(instruct_model)
# Take the first example stored for the first feature and split it into its three parts.
first_feature_examples = list(max_activating_ex_mini.values())[0]
ex_val, ex_toks, ex_act = first_feature_examples[0]

In [10]:
# Sanity check: the token list and the activation list should have the same length.
len(ex_toks), len(ex_act)

(554, 554)

In [11]:
# presumably renders the tokens with their activations as HTML (neel.utils helper) — TODO confirm
create_html(ex_toks, ex_act, allow_different_length=False)

In [14]:
from feature_dashboard import FeatureCentricDashboard

In [21]:
# Reuse the `tokenizer` loaded in an earlier cell rather than instantiating a second copy.
dashboard = FeatureCentricDashboard(
    max_activation_examples=max_activating_ex_mini,
    tokenizer=tokenizer,
)
dashboard.display()

VBox(children=(Combobox(value='', continuous_update=False, description='Feature:', options=('55', '60', '78', …

In [21]:
# Build the feature-centric dashboard over the loaded examples, reusing the
# `tokenizer` created earlier, and render it in the notebook.
dashboard = FeatureCentricDashboard(max_activation_examples=max_activating_ex_mini, tokenizer=tokenizer)
dashboard.display()

VBox(children=(Dropdown(description='Feature:', options=(55, 60, 78, 82, 95, 112, 119, 130, 140, 221, 222, 231…

# SAE LENS (PAIN)

In [12]:
from sae_lens import SAE, SAEConfig
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--base-device", type=str, default=device)
parser.add_argument("--instruct-device", type=str, default=device)
# `extra_args` is an empty list in the config cell, so the defaults above apply.
args = parser.parse_args(extra_args)

# No earlier cell defines `crosscoder`, `activation_dim` or `dict_size`; load the
# checkpoint here so this section also runs from a fresh kernel.
crosscoder = CrossCoder.from_pretrained(crosscoder_path)
# assumes encoder.weight is (n_models, d_in, d_sae), matching how `gen_state_dict`
# slices it into the SAE's `W_enc` — TODO confirm against the CrossCoder implementation
_, activation_dim, dict_size = crosscoder.encoder.weight.shape

# SAE Lens config for the base-model half of the crosscoder.
base_config_dict = {
    "architecture": "standard",
    "d_in": activation_dim,
    "d_sae": dict_size,
    "dtype": "float32",
    "model_name": base_model,
    "hook_name": f"blocks.{layer}.hook_resid_post",
    "hook_layer": layer,
    "hook_head_index": None,
    "activation_fn_str": "relu",
    "finetuning_scaling_factor": False,
    "sae_lens_training_version": None,
    "prepend_bos": True,
    "dataset_path": None,
    # Same sequence length as the tokenization step (SEQ_LEN is 1024 in the config cell).
    "context_size": SEQ_LEN,
    "dataset_trust_remote_code": False,
    "apply_b_dec_to_input": False,
    "normalize_activations": None,
    "device": "cpu",
    # "device": args.base_device,
}
base_config = SAEConfig.from_dict(base_config_dict)

# The instruct config is a shallow copy that only swaps the model name and device.
it_config_dict = base_config_dict.copy()
it_config_dict["model_name"] = instruct_model
it_config_dict["device"] = args.instruct_device
it_config = SAEConfig.from_dict(it_config_dict)


def gen_state_dict(model_idx):
    """Build an SAE-style state dict from one model's slice of the global `crosscoder`.

    Args:
        model_idx: index used to select this model's entry from the encoder
            weight, decoder bias and decoder weight.

    Returns:
        Dict with keys "b_enc", "W_enc", "b_dec", "W_dec". The encoder bias is
        taken whole; the other three are indexed by `model_idx`.
    """
    encoder, decoder = crosscoder.encoder, crosscoder.decoder
    state = dict(
        b_enc=encoder.bias,
        W_enc=encoder.weight[model_idx],
        b_dec=decoder.bias[model_idx],
        W_dec=decoder.weight[model_idx],
    )
    return state

## Base Visualization

In [13]:
# Build the SAE from the base-model config and load the crosscoder's index-0 slice into it.
base_sae = SAE(base_config)
base_sae.load_state_dict(gen_state_dict(0))
# presumably folds the decoder row norms into the other parameters (sae_lens helper) — TODO confirm
base_sae.fold_W_dec_norm()


In [14]:
from transformer_lens import HookedTransformer, utils
from sae_dashboard.sae_vis_data import SaeVisConfig
from sae_dashboard.sae_vis_runner import SaeVisRunner
from datasets import load_dataset, load_from_disk

# Load the base model as a HookedTransformer in float16 (the SAE itself is built in an earlier cell).
# NOTE(review): `n_devices=2` presumably splits the model across two devices — confirm two are available.
model = HookedTransformer.from_pretrained_no_processing(
    base_model, device=args.base_device, dtype="float16", n_devices=2
)


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model gemma-2-2b into HookedTransformer


In [15]:

# Number of documents drawn from each corpus before tokenization.
n_docs_per_dataset = 300
# Dedicated, seeded generator: the random subsets are reproducible across kernel
# restarts and the global torch RNG state is left untouched.
sample_rng = th.Generator().manual_seed(seed)

fineweb = load_dataset(
    "HuggingFaceFW/fineweb",
    name="sample-10BT",
    split="train",
    cache_dir=Path("/dlabscratch1/cdumas/.cache/huggingface/datasets/"),
)

# Random FineWeb subset; `.tolist()` hands `select` plain Python ints.
indices = th.randperm(len(fineweb), generator=sample_rng)[:n_docs_per_dataset].tolist()
fineweb = fineweb.select(indices)
fineweb_tokenized_data = utils.tokenize_and_concatenate(fineweb, model.tokenizer, max_length=SEQ_LEN)  # type: ignore

# Same procedure for the LMSYS chat data stored on disk.
lmsys = load_from_disk("/dlabscratch1/public/datasets/lmsys-chat-1m-formatted/")
indices = th.randperm(len(lmsys), generator=sample_rng)[:n_docs_per_dataset].tolist()
lmsys = lmsys.select(indices)
lmsys_tokenized_data = utils.tokenize_and_concatenate(lmsys, model.tokenizer, max_length=SEQ_LEN)  # type: ignore


Resolving data files:   0%|          | 0/23781 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/102 [00:00<?, ?it/s]

Map (num_proc=10):   0%|          | 0/300 [00:00<?, ? examples/s]

Map (num_proc=10):   0%|          | 0/300 [00:00<?, ? examples/s]

In [16]:
# Concatenate the tokenized FineWeb and LMSYS samples into one dataset (used as dashboard input later).
from datasets import concatenate_datasets
tokenized_data = concatenate_datasets([fineweb_tokenized_data, lmsys_tokenized_data])


In [17]:
# Directory with the precomputed feature-index tensors (one file per feature group).
# NOTE(review): this is the `l13-mu4.0e-02-lr1e-04` / `ae_90000` run, whereas `crosscoder_path`
# points at `l13-mu4.1e-02-lr1e-04/ae_final.pt` — confirm both refer to the same dictionary.
feature_index_dir = Path(
    "/dlabscratch1/cdumas/representation-structure-comparison/results/eval_crosscoder/checkpoints_l13-mu4.0e-02-lr1e-04_ae_90000.pt/data"
)
# How many features of each group go into the dashboard.
n_features_per_group = 10

base_features = th.load(feature_index_dir / "only_base_decoder_feature_indices.pt", weights_only=True).tolist()
it_features = th.load(feature_index_dir / "only_it_decoder_feature_indices.pt", weights_only=True).tolist()
shared_features = th.load(feature_index_dir / "shared_decoder_feature_indices.pt", weights_only=True).tolist()
all_features = (
    base_features[:n_features_per_group]
    + it_features[:n_features_per_group]
    + shared_features[:n_features_per_group]
)

In [21]:
# Peek at the first ten indices loaded from the base-only feature file.
base_features[:10]

[78, 95, 222, 263, 377, 418, 585, 593, 603, 652]

In [19]:
from sae_dashboard.sae_vis_data import SaeVisData


# Configure visualization: hook at the SAE's hook point, restricted to the features in `all_features`.
# NOTE(review): dtype here is "bfloat16" while the model is loaded with dtype="float16" and the
# SAE config uses "float32" — confirm this mix is intended.
config = SaeVisConfig(
    hook_point=base_sae.cfg.hook_name,
    features=all_features,
    device=args.base_device,
    dtype="bfloat16",
    minibatch_size_features=64,
    minibatch_size_tokens=16,
    verbose=True,
)

In [6]:
import ipywidgets as widgets
from IPython.display import display

def check_feature_type(feature):
    """Report which feature list contains `feature`.

    The lists are checked in the order base, IT, shared; the first match wins.

    Returns:
        "Base", "IT" or "Shared", or "Not found in any list" when no list
        contains the feature.
    """
    labelled_groups = (
        ("Base", base_features),
        ("IT", it_features),
        ("Shared", shared_features),
    )
    for label, members in labelled_groups:
        if feature in members:
            return label
    return "Not found in any list"

# Integer input box for the feature index to look up (starts at 0).
feature_input = widgets.IntText(
    value=0,
    description='Feature:',
    disabled=False
)

# Output area that the change handler prints into.
output = widgets.Output()

def on_value_change(change):
    """Widget callback: show which feature list the newly entered value belongs to.

    Args:
        change: change mapping from the widget; only its 'new' entry is read.
    """
    with output:
        # Clear the previous message so only the latest lookup is shown.
        output.clear_output()
        print(f"Feature {change['new']} is: {check_feature_type(change['new'])}")

# Call the handler whenever the input's `value` trait changes.
feature_input.observe(on_value_change, names='value')

# Show the input box with its output area underneath.
display(feature_input, output)


IntText(value=0, description='Feature:')

Output()

In [20]:
# Run the dashboard pipeline with the base SAE and model on the first 16 rows of the combined tokens.
data = SaeVisRunner(config).run(encoder=base_sae, model=model, tokens=tokenized_data['tokens'][:16])

# The resulting `data` is written to HTML by `save_feature_centric_vis` in a later cell.


Forward passes to cache data for vis:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/30 [00:00<?, ?it/s]

Feature batches:   0%|          | 0/1 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/30 [00:00<?, ?it/s]

In [22]:
from sae_dashboard.data_writing_fns import save_feature_centric_vis
# Write the base-model dashboard to "feature_dashboard_base.html"
# (`separate_files=True` presumably emits one file per feature — TODO confirm).
save_feature_centric_vis(sae_vis_data=data, filename="feature_dashboard_base.html", separate_files=True)

Saving feature-centric vis:   0%|          | 0/30 [00:00<?, ?it/s]

## IT Visualization

In [None]:
# it_sae = SAE(it_config)
# it_sae.load_state_dict(gen_state_dict(1))
