In [1]:
import os
os.environ["HF_HOME"] = "/oscar/scratch/pcurtin1"

from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
from transformer_lens import HookedTransformer
from typing import Callable
import delphi
import torch
from models import CrossCoder
import torch.nn.functional as F
from delphi.config import CacheConfig, ConstructorConfig, RunConfig, SamplerConfig
from delphi.utils import assert_type
from pathlib import Path
from delphi.__main__ import (
    non_redundant_hookpoints,
    create_neighbours,
    process_cache,
    populate_cache,
    log_results,
)

import asyncio
import time
from delphi.log.result_analysis import get_metrics, load_data


INFO 04-22 22:09:38 [__init__.py:239] Automatically detected platform cuda.


In [2]:
# Location and identifier of the trained CrossCoder checkpoint; both are read
# by `CrossCoder.load(...)` inside `override_load_hooks_sparse_coders`.
model_path = "models/some_model/version_1"
name = "9"

# NOTE(review): `layer` is not referenced by any cell visible in this notebook
# (the run configuration hooks "layers.17") — confirm it is still needed.
layer = 18


def override_load_hooks_sparse_coders(model: AutoModel, cfg: RunConfig, compile: bool = False) -> tuple[dict[str, Callable], bool]:
    """Build a ``{hookpoint: encode_fn}`` mapping backed by a local CrossCoder.

    The CrossCoder checkpoint is located through the notebook-level globals
    ``name`` and ``model_path``. Two HookedTransformer models
    (Qwen2.5-0.5B and Qwen2.5-0.5B-Instruct) are instantiated only because
    ``CrossCoder.load`` takes them as arguments; they are dropped right after.

    Args:
        model: Accepted for signature compatibility; not used by this function.
        cfg: Run configuration. Only ``cfg.hookpoints`` is read, and it must
            contain exactly one entry.
        compile: Accepted for signature compatibility; not used by this function.

    Returns:
        A tuple ``(hookpoint_to_encode, transcoder)`` where
        ``hookpoint_to_encode`` maps the single hookpoint name to the encoding
        callable and ``transcoder`` is always ``False``.

    Raises:
        AssertionError: If ``cfg.hookpoints`` does not hold exactly one entry.
    """
    # Validate the configuration first: everything below downloads/loads two
    # language models plus the CrossCoder, so a bad config should fail fast.
    assert len(cfg.hookpoints) == 1, "only support one hook location"

    model_a_str = "Qwen/Qwen2.5-0.5B"
    model_b_str = "Qwen/Qwen2.5-0.5B-Instruct"

    tokenizer_a = AutoTokenizer.from_pretrained(model_a_str)
    tokenizer_b = AutoTokenizer.from_pretrained(model_b_str)

    modelA = HookedTransformer.from_pretrained(
        model_a_str, tokenizer=tokenizer_a, dtype="bfloat16"
    )
    modelB = HookedTransformer.from_pretrained(
        model_b_str, tokenizer=tokenizer_b, dtype="bfloat16"
    )

    coder = CrossCoder.load(
        name, modelA=modelA, modelB=modelB, path=model_path
    )

    # The language models were only needed to construct the CrossCoder.
    del modelA, modelB

    # Split the encoder weight into two leading slices and keep it on the GPU.
    # NOTE(review): presumably slice 0 belongs to model A and slice 1 to
    # model B — confirm against the CrossCoder weight layout.
    W_enc = coder.encoder.data
    W_enc = W_enc.reshape(2, -1, W_enc.shape[1]).to("cuda")
    b_enc = coder.encoder_bias.data.to("cuda")

    def encode(x):
        """Encode activations with the first encoder slice only."""
        return F.relu(coder.topk_constraint(
            x @ W_enc[0] + b_enc
        ))

    transcoder = False

    return {cfg.hookpoints[0]: encode}, transcoder

In [3]:
from transformers import BitsAndBytesConfig

def load_artifacts(run_cfg: RunConfig):
    """Load the subject model and the CrossCoder-backed encoders for a run.

    Args:
        run_cfg: Run configuration; ``model``, ``load_in_8bit`` and
            ``hf_token`` are read here, and the whole object is forwarded to
            ``override_load_hooks_sparse_coders``.

    Returns:
        A 4-tuple ``(hookpoints, hookpoint_to_sparse_encode, model, transcode)``
        where ``hookpoints`` is the list of keys of the encoder mapping.
    """
    use_8bit = run_cfg.load_in_8bit

    # Weight dtype: float16 for 8-bit loading, otherwise bfloat16 when the GPU
    # supports it, otherwise let transformers decide ("auto").
    if use_8bit:
        weights_dtype = torch.float16
    elif torch.cuda.is_bf16_supported():
        weights_dtype = torch.bfloat16
    else:
        weights_dtype = "auto"

    # Only build a quantisation config when 8-bit loading was requested.
    quant_cfg = BitsAndBytesConfig(load_in_8bit=use_8bit) if use_8bit else None

    model = AutoModel.from_pretrained(
        run_cfg.model,
        device_map={"": "cuda"},
        quantization_config=quant_cfg,
        torch_dtype=weights_dtype,
        token=run_cfg.hf_token,
    )

    encoders, transcode = override_load_hooks_sparse_coders(
        model, run_cfg, compile=True
    )
    hookpoints = list(encoders.keys())

    return hookpoints, encoders, model, transcode


In [4]:
async def run(
    run_cfg: RunConfig,
):
    """Execute the cache -> neighbours -> explain/score pipeline for one run.

    Outputs are written beneath ``<cwd>/results[/<run_cfg.name>]``. Each stage
    first asks ``non_redundant_hookpoints`` which hookpoints still need work
    (taking ``run_cfg.overwrite`` into account) and is skipped when the answer
    is empty.

    Args:
        run_cfg: Full run configuration; it is also saved as
            ``run_config.json`` in the results directory.
    """
    results_dir = Path.cwd() / "results"
    if run_cfg.name:
        results_dir = results_dir / run_cfg.name
    results_dir.mkdir(parents=True, exist_ok=True)

    run_cfg.save_json(results_dir / "run_config.json", indent=4)

    # One sub-directory per pipeline artefact.
    (
        latents_path,
        explanations_path,
        scores_path,
        neighbours_path,
        visualize_path,
    ) = (
        results_dir / sub
        for sub in ("latents", "explanations", "scores", "neighbours", "visualize")
    )

    # Optional cap on the latents processed; None means no explicit range.
    latent_range = None
    if run_cfg.max_latents:
        latent_range = torch.arange(run_cfg.max_latents)

    hookpoints, hookpoint_to_sparse_encode, model, transcode = load_artifacts(run_cfg)
    tokenizer = AutoTokenizer.from_pretrained(run_cfg.model, token=run_cfg.hf_token)

    # Stage 1: populate the latent activation cache.
    to_cache = assert_type(
        dict,
        non_redundant_hookpoints(
            hookpoint_to_sparse_encode, latents_path, "cache" in run_cfg.overwrite
        ),
    )
    if to_cache:
        populate_cache(run_cfg, model, to_cache, latents_path, tokenizer, transcode)

    # The model and encoders are not used past this point; drop the local references.
    del model, hookpoint_to_sparse_encode

    # Stage 2: neighbours, only when they are the non-activating example source.
    if run_cfg.constructor_cfg.non_activating_source != "neighbours":
        print("Skipping neighbour creation")
    else:
        to_neighbour = assert_type(
            list,
            non_redundant_hookpoints(
                hookpoints, neighbours_path, "neighbours" in run_cfg.overwrite
            ),
        )
        if to_neighbour:
            create_neighbours(run_cfg, latents_path, neighbours_path, to_neighbour)

    # Stage 3: explanation and scoring.
    to_score = assert_type(
        list,
        non_redundant_hookpoints(
            hookpoints, scores_path, "scores" in run_cfg.overwrite
        ),
    )
    if to_score:
        await process_cache(
            run_cfg,
            latents_path,
            explanations_path,
            scores_path,
            to_score,
            tokenizer,
            latent_range,
        )

    if run_cfg.verbose:
        log_results(scores_path, visualize_path, run_cfg.hookpoints)


In [None]:
# Configure the pipeline (activation caching, example sampling, example
# construction, overall run), execute it, then sanity-check the scores.
# `run_cfg`, `latent_df` and `processed_df` are reused by the cells below.
cache_cfg = CacheConfig(
    dataset_repo="EleutherAI/fineweb-edu-dedup-10b",
    dataset_split="train[:50000]",
    dataset_column="text",
    batch_size=8,
    cache_ctx_len=256,
    n_splits=5,
    n_tokens=5_000_000,
)
sampler_cfg = SamplerConfig(
    train_type="quantiles",
    test_type="quantiles",
    n_examples_train=30,
    n_examples_test=20,
    n_quantiles=10,
)
constructor_cfg = ConstructorConfig(
    min_examples=50,
    example_ctx_len=32,
    n_non_activating=20,
    # "random" means `run` takes the "Skipping neighbour creation" branch.
    non_activating_source="random",
    faiss_embedding_cache_enabled=True,
    faiss_embedding_cache_dir=".embedding_cache",
)
run_cfg = RunConfig(
    name="fineweb",
    # Both the activation cache and the scores are recomputed on every run.
    overwrite=["cache", "scores"],
    model="Qwen/Qwen2.5-0.5B",
    # NOTE(review): `override_load_hooks_sparse_coders` in this notebook reads
    # only `cfg.hookpoints`, so this value does not appear to select the
    # encoder actually used here — confirm whether it can be changed/removed.
    sparse_model="EleutherAI/sae-pythia-160m-32k",
    # Exactly one hookpoint: the overriding loader asserts len(hookpoints) == 1.
    hookpoints=["layers.17"],
    explainer_model="Qwen/Qwen2.5-7B-Instruct",  # "meta-llama/Llama-3.2-3B-Instruct",
    # explainer_model_max_len=4096,
    # max_latents=5000,
    seed=21,
    num_gpus=torch.cuda.device_count(),
    filter_bos=True,
    verbose=True,
    sampler_cfg=sampler_cfg,
    constructor_cfg=constructor_cfg,
    cache_cfg=cache_cfg,
)

# Wall-clock timing of the whole pipeline. The bare top-level `await` relies on
# the notebook kernel already running an event loop.
start_time = time.time()
await run(run_cfg)
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")

# Same directory that `run` writes its scores to for this run name.
scores_path = Path.cwd() / "results" / run_cfg.name / "scores"

latent_df, _ = load_data(scores_path, run_cfg.hookpoints)
processed_df = get_metrics(latent_df)
# Performs better than random guessing
# (mean accuracy per score type must exceed 0.55, otherwise the cell fails).
for score_type, df in processed_df.groupby("score_type"):
    accuracy = df["accuracy"].mean()
    assert accuracy > 0.55, f"Score type {score_type} has an accuracy of {accuracy}"

In [6]:
# Display the metrics table returned by `get_metrics(latent_df)`.
processed_df

Unnamed: 0,score_type,true_positives,true_negatives,false_positives,false_negatives,total_examples,total_positives,total_negatives,failed_count,precision,recall,f1_score,accuracy,true_positive_rate,true_negative_rate,false_positive_rate,false_negative_rate,positive_class_ratio,negative_class_ratio,auc
0,detection,47616,76029,23788,52507,199940,100123,99817,0,0.666853,0.475575,0.555201,0.618629,0.475575,0.761684,0.238316,0.524425,0.500765,0.499235,


In [12]:
# Persist the raw scoring rows next to the other outputs of this run.
# The directory is derived from `run_cfg.name` (the same way `run` and
# `scores_path` build it) rather than a hard-coded "results/fineweb", so the
# export cannot drift from the run's results directory if the run is renamed.
latent_df.to_csv(Path.cwd() / "results" / run_cfg.name / "perf.csv", index=False)

In [8]:
# Keep only rows that have both a ground-truth label and a prediction, then
# cast both columns to int. `latent_df` itself is left unmodified because
# `dropna` and `astype` each return a new frame.
df = (
    latent_df
    .dropna(subset=["activating", "prediction"])
    .astype({"activating": int, "prediction": int})
)

In [9]:
from sklearn.metrics import precision_score, recall_score, f1_score
import pandas as pd

def prf(group):
    """Return precision, recall and F1 for one group of scored examples.

    Args:
        group: Frame with an ``activating`` column (ground truth) and a
            ``prediction`` column.

    Returns:
        A ``pd.Series`` with the keys ``precision``, ``recall`` and ``f1``.
    """
    y_true = group["activating"]
    y_pred = group["prediction"]

    scorers = (
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    )
    # zero_division=0: an undefined metric (zero denominator) is reported as 0;
    # use 1 instead if such cases should count as perfect.
    return pd.Series(
        {label: scorer(y_true, y_pred, zero_division=0) for label, scorer in scorers}
    )


# Per-latent precision / recall / F1, one row per `latent_idx`.
# The label columns are selected *after* `groupby` so `prf` is never applied to
# the grouping column itself; applying to grouping columns is deprecated in
# recent pandas and emits a warning. The resulting table is unchanged.
metrics_by_pk = (
    df.dropna(subset=["activating", "prediction"], axis=0)
    .groupby("latent_idx")[["activating", "prediction"]]
    .apply(prf)
    .reset_index()
)





In [10]:
# Show latents ordered by F1, lowest first (`sort_values` sorts ascending by default).
metrics_by_pk.sort_values("f1")

Unnamed: 0,latent_idx,precision,recall,f1
4657,30411,0.000000,0.00,0.000000
3549,23055,0.000000,0.00,0.000000
1238,8033,0.000000,0.00,0.000000
2369,15183,0.000000,0.00,0.000000
3686,23931,0.000000,0.00,0.000000
...,...,...,...,...
1026,6717,0.950000,0.95,0.950000
222,1581,0.950000,0.95,0.950000
4206,27368,1.000000,0.95,0.974359
1773,11422,0.952381,1.00,0.975610
