In [1]:
import http.server
import json
import os
import pprint
import socketserver
import threading
import webbrowser
from typing import NamedTuple, Optional, Union

import einops
import plotly.express as px
import torch
import torch.nn.functional as F
# from tiny_dashboard.dashboard_implementations import \
#     CrosscoderOnlineFeatureDashboard
from torch import nn
from transformer_lens import HookedTransformer

from crosscoder import CrossCoder
from utils import get_gsm8k_dataset

torch.set_grad_enabled(False)  # important for memory saving


device = 'cuda:0'


In [2]:
base_model = HookedTransformer.from_pretrained(
    "Qwen/Qwen2.5-1.5B",
    device=device,
    dtype=torch.bfloat16,
)

math_model = HookedTransformer.from_pretrained(
    "Qwen/Qwen2.5-Math-1.5B",
    device=device,
    dtype=torch.bfloat16,
)




Loaded pretrained model Qwen/Qwen2.5-1.5B into HookedTransformer




Loaded pretrained model Qwen/Qwen2.5-Math-1.5B into HookedTransformer


In [5]:
cross_coder = CrossCoder.load('version_2', 1)
collect_layer = 14

train_questions, train_answers = get_gsm8k_dataset(split='train')

{'batch_size': 100,
 'beta1': 0.9,
 'beta2': 0.999,
 'buffer_mult': 128,
 'd_in': 1536,
 'dec_init_norm': 0.08,
 'device': 'cuda:0',
 'dict_size': 16384,
 'enc_dtype': 'fp32',
 'hook_point': 'blocks.14.hook_resid_pre',
 'l1_coeff': 2,
 'log_every': 100,
 'lr': 5e-05,
 'model_batch_size': 4,
 'model_name': 'qwen2_5_1_5B',
 'num_tokens': 5000000,
 'save_every': 30000,
 'seed': 42,
 'seq_len': 1024,
 'site': 'resid_pre',
 'wandb_entity': 'binhnt',
 'wandb_project': 'crosscoder'}



You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



In [6]:
# dataset
eos = base_model.tokenizer.special_tokens_map['eos_token']
train_questions, train_answers = get_gsm8k_dataset(split='train')
merged_prompts = [q + f" {eos} " + a for q, a in zip(train_questions, train_answers)]
all_tokens = base_model.tokenizer(merged_prompts, truncation=True, padding=False, max_length=512)  # for feature viz it's best to disable padding

In [7]:
all_tokens_padded = base_model.tokenizer(merged_prompts, truncation=True, padding=True, return_tensors='pt')

In [8]:
# Check histograms similar to Anthropic's paper
norms = cross_coder.W_dec.norm(dim=-1)
norms.shape

relative_norms = norms[:, 1] / norms.sum(dim=-1)
relative_norms.shape

fig = px.histogram(
    relative_norms.detach().cpu().numpy(),
    title="Qwen2.5-1.5B Base vs Math Model Diff",
    labels={"value": "Relative decoder norm strength"},
    nbins=200,
)

fig.update_layout(showlegend=False)
fig.update_yaxes(title_text="Number of Latents")

# Update x-axis ticks
fig.update_xaxes(
    tickvals=[0, 0.25, 0.5, 0.75, 1.0],
    ticktext=['0', '0.25', '0.5', '0.75', '1.0']
)

fig.show()


In [9]:
# Now let's check the cosine similarity of the "shared" decoder vectors between both models:

shared_latent_mask = (relative_norms < 0.7) & (relative_norms > 0.3)
shared_latent_mask.shape

torch.Size([16384])

In [10]:
cosine_sims = (cross_coder.W_dec[:, 0, :] * cross_coder.W_dec[:, 1, :]).sum(dim=-1) / (cross_coder.W_dec[:, 0, :].norm(dim=-1) * cross_coder.W_dec[:, 1, :].norm(dim=-1))
cosine_sims.shape

torch.Size([16384])

In [11]:
fig = px.histogram(
    cosine_sims[shared_latent_mask].to(torch.float32).detach().cpu().numpy(),
    #title="Cosine similarity of decoder vectors between models",
    log_y=True,  # Sets the y-axis to log scale
    range_x=[-1, 1],  # Sets the x-axis range from -1 to 1
    nbins=100,  # Adjust this value to change the number of bins
    labels={"value": "Cosine similarity of decoder vectors between models"}
)

fig.update_layout(showlegend=False)
fig.update_yaxes(title_text="Number of Latents (log scale)")

fig.show()

We notice reasonably high alignment, with a lot of high positive cosine sim.



When we trained our crosscoder, we normalized both the base and chat model activations such that they both have avg norm sqrt(d_model). In training, this is implemented by estimating scaling constants such that norm(scale * act) = sqrt(d_model) over a subset of the training distribution. I'll just hard code them in this demo.

This means we also need to normalize the activations during analysis. Further, since we'll be splicing the reconstructed activations back into the forward pass of the model, we need to "unscale" the reconstructed activations too. We can alternatively fold this into the weights, as below:


In [12]:
import copy
folded_cross_coder = copy.deepcopy(cross_coder)


def fold_activation_scaling_factor(cross_coder, base_scaling_factor, chat_scaling_factor):
    cross_coder.W_enc.data[0, :, :] = cross_coder.W_enc.data[0, :, :] * base_scaling_factor
    cross_coder.W_enc.data[1, :, :] = cross_coder.W_enc.data[1, :, :] * chat_scaling_factor

    cross_coder.W_dec.data[:, 0, :] = cross_coder.W_dec.data[:, 0, :] / base_scaling_factor
    cross_coder.W_dec.data[:, 1, :] = cross_coder.W_dec.data[:, 1, :] / chat_scaling_factor

    cross_coder.b_dec.data[0, :] = cross_coder.b_dec.data[0, :] / base_scaling_factor
    cross_coder.b_dec.data[1, :] = cross_coder.b_dec.data[1, :] / chat_scaling_factor
    return cross_coder


In [13]:
# Estimating normalizing factor
from buffer import Buffer
buff = Buffer(cross_coder.cfg, base_model, math_model, all_tokens_padded)

Estimating norm scaling factor: 100%|██████████| 100/100 [00:18<00:00,  5.50it/s]
Estimating norm scaling factor: 100%|██████████| 100/100 [00:17<00:00,  5.56it/s]


Refreshing the buffer!


In [14]:
base_estimated_scaling_factor, math_estimated_scaling_factor = buff.normalisation_factor.detach().cpu().numpy()
#del buff
print(base_estimated_scaling_factor, math_estimated_scaling_factor)

0.5624546 0.34796977


In [15]:
folded_cross_coder = fold_activation_scaling_factor(folded_cross_coder, base_estimated_scaling_factor, math_estimated_scaling_factor)
folded_cross_coder = folded_cross_coder.to(torch.bfloat16)

In [16]:
from functools import partial

def splice_act_hook(act, hook, spliced_act):
    act[:, 1:, :] = spliced_act # Drop BOS
    return act

def zero_ablation_hook(act, hook):
    act[:] = 0
    return act

def get_ce_recovered_metrics(tokens, model_A, model_B, cross_coder):
    # get clean loss
    ce_clean_A = model_A(tokens, return_type="loss")
    ce_clean_B = model_B(tokens, return_type="loss")

    # get zero abl loss
    ce_zero_abl_A = model_A.run_with_hooks(
        tokens,
        return_type="loss",
        fwd_hooks = [(cross_coder.cfg["hook_point"], zero_ablation_hook)],
    )
    ce_zero_abl_B = model_B.run_with_hooks(
        tokens,
        return_type="loss",
        fwd_hooks = [(cross_coder.cfg["hook_point"], zero_ablation_hook)],
    )

    # bunch of annoying set up for splicing
    _, cache_A = model_A.run_with_cache(
        tokens,
        names_filter=cross_coder.cfg["hook_point"],
        return_type=None,
        )
    resid_act_A = cache_A[cross_coder.cfg["hook_point"]]

    _, cache_B = model_B.run_with_cache(
        tokens,
        names_filter=cross_coder.cfg["hook_point"],
        return_type=None,
        )
    resid_act_B = cache_B[cross_coder.cfg["hook_point"]]

    cross_coder_input = torch.stack([resid_act_A, resid_act_B], dim=0)
    cross_coder_input = cross_coder_input[:, :, 1:, :] # Drop BOS
    cross_coder_input = einops.rearrange(
        cross_coder_input,
        "n_models batch seq_len d_model -> (batch seq_len) n_models d_model",
    )

    cross_coder_output = cross_coder.decode(cross_coder.encode(cross_coder_input))
    cross_coder_output = einops.rearrange(
        cross_coder_output,
        "(batch seq_len) n_models d_model -> n_models batch seq_len d_model", batch = tokens.shape[0]
    )
    cross_coder_output_A = cross_coder_output[0]
    cross_coder_output_B = cross_coder_output[1]

    # get spliced loss
    ce_loss_spliced_A = model_A.run_with_hooks(
        tokens,
        return_type="loss",
        fwd_hooks = [(cross_coder.cfg["hook_point"], partial(splice_act_hook, spliced_act=cross_coder_output_A))],
    )
    ce_loss_spliced_B = model_B.run_with_hooks(
        tokens,
        return_type="loss",
        fwd_hooks = [(cross_coder.cfg["hook_point"], partial(splice_act_hook, spliced_act=cross_coder_output_B))],
    )

    # compute % CE recovered metric
    ce_recovered_A = 1 - ((ce_loss_spliced_A - ce_clean_A) / (ce_zero_abl_A - ce_clean_A))
    ce_recovered_B = 1 - ((ce_loss_spliced_B - ce_clean_B) / (ce_zero_abl_B - ce_clean_B))

    metrics = {
        "ce_loss_spliced_A": ce_loss_spliced_A.item(),
        "ce_loss_spliced_B": ce_loss_spliced_B.item(),
        "ce_clean_A": ce_clean_A.item(),
        "ce_clean_B": ce_clean_B.item(),
        "ce_zero_abl_A": ce_zero_abl_A.item(),
        "ce_zero_abl_B": ce_zero_abl_B.item(),
        "ce_diff_A": (ce_loss_spliced_A - ce_clean_A).item(),
        "ce_diff_B": (ce_loss_spliced_B - ce_clean_B).item(),
        "ce_recovered_A": ce_recovered_A.item(),
        "ce_recovered_B": ce_recovered_B.item(),
    }
    return metrics

tokens = all_tokens_padded.input_ids[torch.randperm(len(all_tokens_padded))[:100]]
ce_metrics = get_ce_recovered_metrics(tokens, base_model, math_model, folded_cross_coder)

In [17]:
for (k, v) in ce_metrics.items():
    print(f"{k}: {v}")
del tokens

ce_loss_spliced_A: 11.75
ce_loss_spliced_B: 13.0625
ce_clean_A: 7.28125
ce_clean_B: 16.75
ce_zero_abl_A: 12.8125
ce_zero_abl_B: 11.5
ce_diff_A: 4.46875
ce_diff_B: -3.6875
ce_recovered_A: 0.19140625
ce_recovered_B: 0.296875


In [18]:
import copy
folded_cross_coder = copy.deepcopy(cross_coder)

def fold_activation_scaling_factor(cross_coder, base_scaling_factor, chat_scaling_factor):
    cross_coder.W_enc.data[0, :, :] = cross_coder.W_enc.data[0, :, :] * base_scaling_factor
    cross_coder.W_enc.data[1, :, :] = cross_coder.W_enc.data[1, :, :] * chat_scaling_factor

    # cross_coder.W_dec.data[:, 0, :] = cross_coder.W_dec.data[:, 0, :] / base_scaling_factor
    # cross_coder.W_dec.data[:, 1, :] = cross_coder.W_dec.data[:, 1, :] / chat_scaling_factor

    # cross_coder.b_dec.data[0, :] = cross_coder.b_dec.data[0, :] / base_scaling_factor
    # cross_coder.b_dec.data[1, :] = cross_coder.b_dec.data[1, :] / chat_scaling_factor
    return cross_coder

folded_cross_coder = fold_activation_scaling_factor(folded_cross_coder, base_estimated_scaling_factor, math_estimated_scaling_factor)

In [19]:
cross_coder.cfg

{'seed': 42,
 'batch_size': 100,
 'buffer_mult': 128,
 'lr': 5e-05,
 'num_tokens': 5000000,
 'l1_coeff': 2,
 'beta1': 0.9,
 'beta2': 0.999,
 'd_in': 1536,
 'dict_size': 16384,
 'seq_len': 1024,
 'enc_dtype': 'fp32',
 'model_name': 'qwen2_5_1_5B',
 'site': 'resid_pre',
 'device': 'cuda:0',
 'model_batch_size': 4,
 'log_every': 100,
 'save_every': 30000,
 'dec_init_norm': 0.08,
 'hook_point': 'blocks.14.hook_resid_pre',
 'wandb_project': 'crosscoder',
 'wandb_entity': 'binhnt'}

In [20]:
from sae_vis.model_fns import CrossCoderConfig, CrossCoder

encoder_cfg = CrossCoderConfig(d_in=base_model.cfg.d_model, d_hidden=cross_coder.cfg["dict_size"], apply_b_dec_to_input=False)
sae_vis_cross_coder = CrossCoder(encoder_cfg)
sae_vis_cross_coder.load_state_dict(folded_cross_coder.state_dict())
sae_vis_cross_coder = sae_vis_cross_coder.to("cuda:0")
sae_vis_cross_coder = sae_vis_cross_coder.to(torch.bfloat16)

from sae_vis.data_config_classes import SaeVisConfig
test_feature_idx = [2325,12698,1500]
sae_vis_config = SaeVisConfig(
    hook_point = folded_cross_coder.cfg["hook_point"],
    features = test_feature_idx,
    verbose = True,
    minibatch_size_tokens=4,
    minibatch_size_features=16,
)

In [21]:
from sae_vis.data_storing_fns import SaeVisData
sae_vis_data = SaeVisData.create(
    encoder = sae_vis_cross_coder,
    encoder_B = None,
    model_A = base_model,
    model_B = math_model,
    tokens = all_tokens_padded.input_ids[:128], # in practice, better to use more data
    cfg = sae_vis_config,
)

Forward passes to cache data for vis:   0%|          | 0/32 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
from IPython.display import IFrame

PORT = 8000

def display_vis_inline(filename: str, height: int = 850):
    """Serve visualization HTML file and display in notebook"""
    global PORT
    
    def serve(directory):
        global PORT
        os.chdir(directory)
        handler = http.server.SimpleHTTPRequestHandler
        
        try:
            with socketserver.TCPServer(("", PORT), handler) as httpd:
                print(f"Serving at http://localhost:{PORT}/{filename}")
                httpd.serve_forever()
        except OSError:
            PORT += 1
            serve(directory)

    thread = threading.Thread(target=serve, args=(os.getcwd(),), daemon=True)
    thread.start()
    
    # Instead of opening browser, display inline using IFrame
    return IFrame(src=f"http://localhost:{PORT}/{filename}", width='100%', height=height)

# Save and display visualization
filename = "_feature_vis_demo.html"
sae_vis_data.save_feature_centric_vis(filename)
display_vis_inline(filename)


Saving feature-centric vis:   0%|          | 0/3 [00:00<?, ?it/s]

Serving at http://localhost:8000/_feature_vis_demo.html


127.0.0.1 - - [17/Feb/2025 17:02:00] "GET /_feature_vis_demo.html HTTP/1.1" 200 -
127.0.0.1 - - [17/Feb/2025 17:02:05] "GET /_feature_vis_demo.html HTTP/1.1" 304 -


In [23]:
from IPython.display import IFrame

# Use the same port that was used to serve the file
IFrame(src=f"http://localhost:{PORT}/_feature_vis_demo.html", width='100%', height=850)