In [1]:
from transformers import AutoModel, AutoTokenizer
from transformer_lens import HookedTransformer
from typing import Callable
import delphi
import torch
from models import CrossCoder
import torch.nn.functional as F
from delphi.config import RunConfig

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Model A is the base checkpoint, model B the "-Instruct" checkpoint; the
# crosscoder below is loaded against both.
model_a_str = "Qwen/Qwen2.5-0.5B"
model_b_str = "Qwen/Qwen2.5-0.5B-Instruct"

# Each HookedTransformer is given the tokenizer of its own checkpoint.
tokenizer_a = AutoTokenizer.from_pretrained(model_a_str)
tokenizer_b = AutoTokenizer.from_pretrained(model_b_str)

modelA = HookedTransformer.from_pretrained(model_a_str, tokenizer=tokenizer_a, dtype="bfloat16")
modelB = HookedTransformer.from_pretrained(model_b_str, tokenizer=tokenizer_b, dtype="bfloat16")

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


Loaded pretrained model Qwen/Qwen2.5-0.5B into HookedTransformer




Loaded pretrained model Qwen/Qwen2.5-0.5B-Instruct into HookedTransformer


In [3]:
# Load crosscoder checkpoint "9" from the versioned model directory, bound to both models.
coder = CrossCoder.load(
    "9",
    modelA=modelA,
    modelB=modelB,
    path="models/some_model/version_1",
)

In [4]:
# Encoder bias, and encoder weights regrouped as (-1, 2, d_sae).
# NOTE(review): assumes axis 1 indexes the two models, i.e. that the crosscoder's
# concatenated input interleaves the models feature-by-feature — confirm against CrossCoder.
b_enc = coder.encoder_bias.data
W_enc = coder.encoder.data.reshape(-1, 2, coder.encoder.data.shape[1])

In [5]:
W_enc.size()

torch.Size([896, 2, 32768])

In [6]:
# Checkpoint location and name that override_load_hooks_sparse_coders reloads.
model_path = "models/some_model/version_1"
name = "9"
# NOTE(review): not referenced in the cells shown here; presumably the hooked
# layer — confirm it agrees with cfg.hookpoints.
layer = 18


def override_load_hooks_sparse_coders(
    model: AutoModel,
    cfg: RunConfig,
    compile: bool = False,
    model_index: int | None = None,
) -> tuple[dict[str, Callable], bool]:
    """Return delphi's ``{hookpoint: encode}`` mapping backed by the crosscoder.

    Reloads the CrossCoder checkpoint ``name`` from ``model_path`` (notebook
    globals) and builds an encoder that consumes the activations of a single
    model. delphi runs one model (``cfg.model``) and hooks one location, so
    only one of the crosscoder's two per-model encoder slices is used.

    Args:
        model: Model loaded by delphi. Unused; kept for interface compatibility.
        cfg: Run config. ``cfg.hookpoints`` must contain exactly one entry.
        compile: Unused; kept for interface compatibility.
        model_index: Which per-model encoder slice to use (0 for ``model_a_str``,
            1 for ``model_b_str``). When None, it is inferred by comparing
            ``cfg.model`` against those two names.

    Returns:
        ``({cfg.hookpoints[0]: encode}, False)``. The bool is the transcoder flag.

    Raises:
        AssertionError: if ``cfg.hookpoints`` does not have exactly one entry.
        ValueError: if ``model_index`` is not 0 or 1, or it is None and
            ``cfg.model`` matches neither model name.
    """
    # Validate cheaply before paying for a checkpoint load.
    assert len(cfg.hookpoints) == 1, "only support one hook location"

    if model_index is None:
        if cfg.model == model_a_str:
            model_index = 0
        elif cfg.model == model_b_str:
            model_index = 1
        else:
            raise ValueError(
                f"cannot infer crosscoder slice for {cfg.model!r}; "
                f"pass model_index=0 ({model_a_str}) or 1 ({model_b_str})"
            )
    if model_index not in (0, 1):
        raise ValueError(f"model_index must be 0 or 1, got {model_index!r}")

    coder = CrossCoder.load(name, modelA=modelA, modelB=modelB, path=model_path)

    W_enc = coder.encoder.data
    # Regroup to (-1, 2, d_sae) and keep one model's slice -> (d_in, d_sae).
    # Multiplying by the full 3-D tensor (as before) makes torch.matmul treat it
    # as a batch of (2, d_sae) matrices, which fails for activations whose last
    # dim is d_in.
    # NOTE(review): keeps the notebook's reshape(-1, 2, d_sae) convention, which
    # assumes the two models' inputs are interleaved feature-by-feature. If
    # CrossCoder instead concatenates [model A; model B], the slice should be
    # W_enc.reshape(2, -1, d_sae)[model_index]. TODO confirm against CrossCoder.
    # NOTE(review): presumably the full crosscoder pre-activation also includes
    # the other model's term; encoding from one model alone omits it.
    W_enc = W_enc.reshape(-1, 2, W_enc.shape[1])[:, model_index, :].contiguous()
    b_enc = coder.encoder_bias.data

    def encode(x: torch.Tensor) -> torch.Tensor:
        """Map activations ``x`` (last dim d_in) to top-k, ReLU'd crosscoder latents."""
        # Follow the activations' device/dtype; .to() returns the same tensor
        # when they already match, so this is free in the common case.
        weight = W_enc.to(device=x.device, dtype=x.dtype)
        bias = b_enc.to(device=x.device, dtype=x.dtype)
        return F.relu(coder.topk_constraint(x @ weight + bias))

    transcoder = False
    return {cfg.hookpoints[0]: encode}, transcoder

In [7]:
from transformers import BitsAndBytesConfig

def load_artifacts(run_cfg: RunConfig):
    """Load ``run_cfg.model`` onto CUDA together with the crosscoder encode hooks.

    The hooks come from ``override_load_hooks_sparse_coders`` (called with
    ``compile=True``).

    Returns:
        ``(hookpoints, hookpoint_to_sparse_encode, model, transcode)``: the hook
        names, the ``{hookpoint: encode}`` dict, the loaded ``AutoModel``, and
        the transcoder flag reported by the hook loader.
    """
    use_8bit = run_cfg.load_in_8bit

    # 8-bit loading pairs with float16; otherwise prefer bfloat16 when the GPU
    # supports it and let transformers choose ("auto") when it does not.
    if use_8bit:
        dtype = torch.float16
    else:
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else "auto"

    quantization_config = BitsAndBytesConfig(load_in_8bit=use_8bit) if use_8bit else None

    model = AutoModel.from_pretrained(
        run_cfg.model,
        device_map={"": "cuda"},
        quantization_config=quantization_config,
        torch_dtype=dtype,
        token=run_cfg.hf_token,
    )

    encoders, transcode = override_load_hooks_sparse_coders(model, run_cfg, compile=True)

    return list(encoders), encoders, model, transcode


In [8]:
from delphi.utils import assert_type
from pathlib import Path
from delphi.__main__ import non_redundant_hookpoints, create_neighbours, process_cache, populate_cache, log_results

async def run(
    run_cfg: RunConfig,
):
    """Run delphi's stages (latent cache, neighbours, explain/score) for the crosscoder.

    Outputs go under ``./results[/<run_cfg.name>]``. Model and encoder loading
    goes through this notebook's ``load_artifacts``. Each stage runs only for
    the hookpoints returned by delphi's ``non_redundant_hookpoints``, which is
    passed that stage's output directory and whether the stage name appears in
    ``run_cfg.overwrite``.
    """
    results_dir = Path.cwd() / "results"
    if run_cfg.name:
        results_dir = results_dir / run_cfg.name
    results_dir.mkdir(parents=True, exist_ok=True)
    run_cfg.save_json(results_dir / "run_config.json", indent=4)

    latents_path, explanations_path, scores_path, neighbours_path, visualize_path = (
        results_dir / sub
        for sub in ("latents", "explanations", "scores", "neighbours", "visualize")
    )

    latent_range = None
    if run_cfg.max_latents:
        latent_range = torch.arange(run_cfg.max_latents)

    hookpoints, hookpoint_to_sparse_encode, model, transcode = load_artifacts(run_cfg)
    tokenizer = AutoTokenizer.from_pretrained(run_cfg.model, token=run_cfg.hf_token)

    # Stage 1: cache latent activations. `pending` is deliberately reused for
    # every stage, so each stage's selection is dropped when the next one is made.
    pending = assert_type(
        dict,
        non_redundant_hookpoints(
            hookpoint_to_sparse_encode, latents_path, "cache" in run_cfg.overwrite
        ),
    )
    if pending:
        populate_cache(run_cfg, model, pending, latents_path, tokenizer, transcode)

    # Drop the local references to the model and encoders before the later stages.
    del model, hookpoint_to_sparse_encode

    # Stage 2: neighbour construction, only for the "neighbours" non-activating source.
    if run_cfg.constructor_cfg.non_activating_source != "neighbours":
        print("Skipping neighbour creation")
    else:
        pending = assert_type(
            list,
            non_redundant_hookpoints(
                hookpoints, neighbours_path, "neighbours" in run_cfg.overwrite
            ),
        )
        if pending:
            create_neighbours(run_cfg, latents_path, neighbours_path, pending)

    # Stage 3: explanations and scores.
    pending = assert_type(
        list,
        non_redundant_hookpoints(hookpoints, scores_path, "scores" in run_cfg.overwrite),
    )
    if pending:
        await process_cache(
            run_cfg,
            latents_path,
            explanations_path,
            scores_path,
            pending,
            tokenizer,
            latent_range,
        )

    if run_cfg.verbose:
        log_results(scores_path, visualize_path, run_cfg.hookpoints)


INFO 04-22 14:09:55 [__init__.py:239] Automatically detected platform cuda.


2025-04-22 14:09:59,220	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
