# Setup

In [2]:
from dictionary_learning import CrossCoder
from torch.nn.functional import cosine_similarity
import torch as th
import plotly.express as px
from pathlib import Path
from tqdm.notebook import tqdm

# This notebook only evaluates a trained crosscoder, so gradient tracking is
# switched off globally.
th.set_grad_enabled(False)
# Name of the sub-directory under ./results that receives this notebook's outputs.
exp_name = "eval_crosscoder"

In [3]:
# Checkpoint of the crosscoder under evaluation. Its last three path components
# are also used to name the results directory.
crosscoder_path = "/dlabscratch1/jminder/repositories/representation-structure-comparison/checkpoints/trainer_0/l13-mu1e-02/checkpoints/ae_50000.pt"
# Command-line style flags forwarded to the argument parser (e.g. ["--is-legacy"]).
extra_args = []
# Suffix appended to the results directory name to tell runs apart.
exp_id = "test"
# Torch device string; "auto" selects CUDA when available and CPU otherwise.
device = "cuda"
# Seed applied to the torch RNGs before the train/validation split.
seed = 42
# Names of the two models and the layer whose cached activations are compared;
# presumably these select sub-directories of `activation_dir` — verify against
# the latent-analysis cell.
base_model = "gemma-2-2b"
instruct_model = "gemma-2-2b-it"
layer = 13
# Root directory of the cached activations.
activation_dir = Path(
    "/dlabscratch1/jminder/repositories/representation-structure-comparison/activations"
)
# Number of samples held out for validation by the random split.
validation_size = 10**6
# DataLoader batch size and number of worker processes.
batch_size = 2048
workers = 12

In [None]:
from argparse import ArgumentParser

# Parse notebook-level flags from `extra_args` (set in the configuration cell).
parser = ArgumentParser()
parser.add_argument("--is-legacy", action="store_true", default=False)
args = parser.parse_args(extra_args)
print(args)
is_legacy = args.is_legacy

if is_legacy:
    # NOTE(review): weights_only=False unpickles arbitrary Python objects; only
    # load checkpoints that come from a trusted source.
    state_dict = th.load(crosscoder_path, map_location="cpu", weights_only=False)

    # Drop everything up to and including the first "_orig_mod." marker
    # (presumably added by torch.compile — confirm). Keys that do not contain
    # the marker are kept unchanged instead of raising IndexError.
    fixed_state_dict = {
        k.split("_orig_mod.", 1)[-1]: v for k, v in state_dict.items()
    }

    num_layers, activation_dim, dict_size = fixed_state_dict["encoder.weight"].shape
    # Legacy checkpoints carry an extra leading dimension on the encoder bias
    # (presumably one bias per layer — confirm); it is summed out here.
    fixed_state_dict["encoder.bias"] = fixed_state_dict["encoder.bias"].sum(dim=0)

    coder = CrossCoder(
        activation_dim,
        dict_size,
        num_layers,
    )
    coder.load_state_dict(fixed_state_dict)
else:
    coder = CrossCoder.from_pretrained(crosscoder_path)
    num_layers, activation_dim, dict_size = coder.encoder.weight.shape
print(coder)

# Results directory: ./results/<exp_name>/<last three checkpoint path parts>_<exp_id>
save_path = (
    Path("./results")
    / exp_name
    / ("_".join(crosscoder_path.split("/")[-3:]) + f"_{exp_id}")
)
save_path_extra = save_path / "extra"
save_path_extra.mkdir(exist_ok=True, parents=True)
save_path_html = save_path / "html"
# parents=True so this line does not depend on the mkdir above having created
# `save_path` first.
save_path_html.mkdir(exist_ok=True, parents=True)

# Turn the configured device string into a torch.device; "auto" prefers CUDA.
if device == "auto":
    device = th.device("cuda" if th.cuda.is_available() else "cpu")
else:
    device = th.device(device)

# Weight analysis

## Compare feature norms

In [None]:
# Norm of every decoder vector, taken over the last dimension of decoder.weight.
norms = coder.decoder.weight.norm(dim=-1)
# Signed relative norm difference between index 0 and index 1, rescaled from
# [-1, 1] to [0, 1]. Per the plot labels below, 1 means the latent is present
# only at index 0 ("Base only"), 0 only at index 1 ("IT only"), 0.5 is shared.
norm_diffs = ((norms[0] - norms[1]) / norms.max(dim=0).values + 1) / 2
sorted_norm_diffs = norm_diffs.sort(descending=True)

fig = px.histogram(
    x=sorted_norm_diffs.values,
    title="Relative difference in decoder feature norms",
    orientation="v",
    nbins=50,
)
fig.update_layout(
    annotations=[
        dict(x=0, y=0, xref="x", yref="paper", text="  <b>(IT only)</b>", showarrow=False, yanchor="top", xanchor="left"),
        dict(x=0.5, y=0, xref="x", yref="paper", text="    <b>(shared)</b>", showarrow=False, yanchor="top", xanchor="left"),
        dict(x=1, y=0, xref="x", yref="paper", text="  <b>(Base only)</b>", showarrow=False, yanchor="top", xanchor="left"),
    ]
)
fig.update_xaxes(tickvals=[0, 0.5, 1])
# In a histogram the x axis carries the measured value and the y axis the bin
# count; the titles match the encoder-norm histogram.
fig.update_xaxes(title="Relative Norm Difference")
fig.update_yaxes(title="Count")
fig.show()
fig.write_html(save_path_html / "decoder_norm_diffs.html")
fig.write_image(save_path / "decoder_norm_diffs.png")


In [None]:
# Recompute the rescaled relative decoder-norm difference (0 = "IT only",
# 0.5 = shared, 1 = "Base only", per the labels below) and sort it descending.
norms = coder.decoder.weight.norm(dim=-1)
norm_diffs = ((norms[0] - norms[1]) / norms.max(dim=0).values + 1) / 2
sorted_norm_diffs = th.sort(norm_diffs, descending=True)

fig = px.line(
    y=sorted_norm_diffs.values,
    title="Relative difference in decoder feature norms",
)

# One label per reference level on the y axis, pinned to the left edge (x=0).
fig.update_layout(
    annotations=[
        dict(
            x=0,
            y=level,
            xref="x",
            yref="paper",
            text=label,
            showarrow=False,
            yanchor="middle",
            xanchor="left",
        )
        for level, label in (
            (0, "<b>(IT only)</b>"),
            (0.5, "<b>(shared)</b>"),
            (1, "<b>(Base only)</b>"),
        )
    ]
)
# Hover shows the original latent index rather than its rank in the sort.
fig.update_traces(
    hovertemplate='Feature Index: %{text}<br>Value: %{y}',
    text=sorted_norm_diffs.indices,
)
fig.update_xaxes(title="Sorted Features (Highest to Lowest Difference)")
fig.update_yaxes(tickvals=[0, 0.5, 1], title="Relative Norm Difference")
fig.show()
fig.write_html(save_path_html / "decoder_norm_diffs_line.html")
fig.write_image(save_path_extra / "decoder_norm_diffs_line.png")


In [None]:
# Encoder counterpart of the decoder-norm comparison: norms are taken over
# dim 1 of encoder.weight, then the index-0 vs index-1 difference is rescaled
# to [0, 1] in the same way.
enc_norms = coder.encoder.weight.norm(dim=1)
enc_norm_diffs = ((enc_norms[0] - enc_norms[1]) / enc_norms.max(dim=0).values + 1) / 2

fig = px.histogram(
    x=enc_norm_diffs,
    title="Relative difference in encoder feature norms",
    orientation="v",
    nbins=50,
)
fig.update_xaxes(title="Relative Norm Difference")
fig.update_yaxes(title="Count")

# Label the three reference positions along the bottom of the x axis.
fig.update_layout(
    annotations=[
        dict(
            x=position,
            y=0,
            xref="x",
            yref="paper",
            text=label,
            showarrow=False,
            yanchor="top",
            xanchor="left",
        )
        for position, label in (
            (0, "<b>(IT only)</b>"),
            (0.5, "<b>(shared)</b>"),
            (1, "<b>(Base only)</b>"),
        )
    ]
)
fig.show()
fig.write_html(save_path_html / "encoder_norm_diffs.html")
fig.write_image(save_path / "encoder_norm_diffs.png")

In [None]:

# Put the encoder norm differences in the order given by the decoder norm
# differences (highest decoder difference first), so both curves share an x axis.
indexed_enc_norm_diffs = enc_norm_diffs[sorted_norm_diffs.indices]

fig = px.line(
    y=indexed_enc_norm_diffs,
    title="Relative difference in encoder feature norms, sorted by decoder feature norm difference",
)
# Hover shows the original latent index of each point.
fig.update_traces(
    hovertemplate='Feature Index: %{text}<br>Value: %{y}',
    text=sorted_norm_diffs.indices,
)
fig.update_xaxes(title="Sorted Features (Highest to Lowest Difference)")
fig.update_yaxes(title="Relative Norm Difference")
fig.show()
fig.write_html(save_path_html / "encoder_norm_diffs_line_sorted_by_decoder.html")
fig.write_image(save_path_extra / "encoder_norm_diffs_line_sorted_by_decoder.png")


In [None]:
# Encoder norm differences in their own descending order.
enc_sorted_norm_diffs = th.sort(enc_norm_diffs, descending=True)

fig = px.line(
    y=enc_sorted_norm_diffs.values,
    title="Relative difference in encoder feature norms",
)
# Hover shows the original latent index of each point.
fig.update_traces(
    hovertemplate='Feature Index: %{text}<br>Value: %{y}',
    text=enc_sorted_norm_diffs.indices,
)
fig.update_xaxes(title="Sorted Features (Highest to Lowest Difference)")
fig.update_yaxes(title="Relative Norm Difference")

# Label the three reference levels on the y axis, pinned to the left edge.
fig.update_layout(
    annotations=[
        dict(
            x=0,
            y=level,
            xref="x",
            yref="paper",
            text=label,
            showarrow=False,
            yanchor="top",
            xanchor="left",
        )
        for level, label in (
            (0, "<b>(IT only)</b>"),
            (0.5, "<b>(shared)</b>"),
            (1, "<b>(Base only)</b>"),
        )
    ]
)
fig.show()
fig.write_html(save_path_html / "encoder_norm_diffs_line.html")
fig.write_image(save_path_extra / "encoder_norm_diffs_line.png")


## Cosine similarity

### Decoder

In [None]:
# Cosine similarity between the index-0 and index-1 decoder vector of each latent.
decoder_cos_sims = cosine_similarity(coder.decoder.weight[0], coder.decoder.weight[1], dim=1)
print(decoder_cos_sims.shape)

decoder_cos_sims_sorted = decoder_cos_sims.sort(descending=True)

# Baseline: mean cosine similarity between decoder vectors of randomly paired
# latents (index-0 vector of one latent vs index-1 vector of another).
# The pairs come from a dedicated generator seeded with `seed`, so the baseline
# is reproducible and the global torch RNG state is left untouched.
num_samples = 10000
baseline_rng = th.Generator().manual_seed(seed)
random_idx = th.randint(
    0,
    coder.decoder.weight.shape[1],
    (num_samples, 2),
    generator=baseline_rng,
)
random_cos_sims = cosine_similarity(
    coder.decoder.weight[0][random_idx[:, 0]],
    coder.decoder.weight[1][random_idx[:, 1]],
    dim=1,
)
mean_random_cos_sim = random_cos_sims.mean().item()

In [None]:
# Distribution of per-latent decoder cosine similarities, with the
# random-pair baseline drawn as a dashed red vertical line.
fig = px.histogram(
    x=decoder_cos_sims,
    title="Cosine similarity between decoder feature vectors",
    orientation="v",
    nbins=50,
)
fig.update_xaxes(title="Cosine Similarity")
fig.update_yaxes(title="Count")
fig.add_vline(
    x=mean_random_cos_sim,
    line_dash="dash",
    line_color="red",
    annotation_text=f"Mean Random Cosine Similarity: {mean_random_cos_sim:.4f}",
)
fig.show()
fig.write_html(save_path_html / "decoder_cos_sims.html")
fig.write_image(save_path / "decoder_cos_sims.png")

In [None]:
# Decoder cosine similarities in descending order; hover shows the original
# latent index of each point.
fig = px.line(
    y=decoder_cos_sims_sorted.values,
    title="Cosine similarity between decoder feature vectors",
)
fig.update_traces(
    hovertemplate='Feature Index: %{text}<br>Value: %{y}',
    text=decoder_cos_sims_sorted.indices,
)
fig.update_xaxes(title="Sorted Features (Highest to Lowest Similarity)")
fig.update_yaxes(title="Cosine Similarity")

# Dashed red reference line at the random-pair baseline.
fig.add_hline(
    y=mean_random_cos_sim,
    line_dash="dash",
    line_color="red",
    annotation_text=f"Mean Random Cosine Similarity: {mean_random_cos_sim:.4f}",
)

fig.show()
fig.write_html(save_path_html / "decoder_cos_sims_line.html")
fig.write_image(save_path_extra / "decoder_cos_sims_line.png")


In [32]:
# Cosine similarity between the index-0 and index-1 encoder weights, reduced
# over dim 0 of each slice.
encoder_cos_sims = cosine_similarity(
    coder.encoder.weight[0],
    coder.encoder.weight[1],
    dim=0,
)
# Baseline for the encoder: the same random pairs as the decoder baseline
# (`random_idx`), here selecting columns of each encoder slice.
mean_random_cos_sim_encoder = (
    cosine_similarity(
        coder.encoder.weight[0][:, random_idx[:, 0]],
        coder.encoder.weight[1][:, random_idx[:, 1]],
        dim=0,
    )
    .mean()
    .item()
)

In [None]:
# Distribution of per-latent encoder cosine similarities, with the encoder
# random-pair baseline drawn as a dashed red vertical line.
fig = px.histogram(
    x=encoder_cos_sims,
    title="Cosine similarity between encoder feature vectors",
    orientation="v",
    nbins=50,
)
fig.update_xaxes(title="Cosine Similarity")
fig.update_yaxes(title="Count")
fig.add_vline(
    x=mean_random_cos_sim_encoder,
    line_dash="dash",
    line_color="red",
    annotation_text=f"Mean Random Cosine Similarity: {mean_random_cos_sim_encoder:.4f}",
)
fig.show()
fig.write_html(save_path_html / "encoder_cos_sims.html")
fig.write_image(save_path / "encoder_cos_sims.png")


In [None]:
print(encoder_cos_sims.shape)
# Encoder cosine similarities, ordered by descending decoder cosine similarity
# so the curve can be compared with the decoder plot.
fig = px.line(
    y=encoder_cos_sims[decoder_cos_sims_sorted.indices],
    title="Cosine similarity between encoder feature vectors, sorted by decoder feature similarity",
)
# Hover shows the original latent index of each point.
fig.update_traces(
    hovertemplate='Feature Index: %{text}<br>Value: %{y}',
    text=decoder_cos_sims_sorted.indices,
)
fig.update_xaxes(title="Sorted Features (Highest to Lowest Similarity)")
fig.update_yaxes(title="Cosine Similarity")

# The plotted values are encoder similarities, so the reference line is the
# encoder random-pair baseline, not the decoder one.
fig.add_hline(
    y=mean_random_cos_sim_encoder,
    line_dash="dash",
    line_color="red",
    annotation_text=f"Mean Random Cosine Similarity: {mean_random_cos_sim_encoder:.4f}",
)

fig.show()
fig.write_html(save_path_html / "encoder_cos_sims_sorted_by_decoder_line.html")
fig.write_image(save_path_extra / "encoder_cos_sims_sorted_by_decoder_line.png")


In [None]:
# Encoder cosine similarities in their own descending order.
encoder_cos_sims_sorted = encoder_cos_sims.sort(descending=True)

fig = px.line(
    y=encoder_cos_sims_sorted.values,
    title="Cosine similarity between encoder feature vectors",
)
# Hover shows the original latent index of each point.
fig.update_traces(
    hovertemplate='Feature Index: %{text}<br>Value: %{y}',
    text=encoder_cos_sims_sorted.indices,
)
fig.update_xaxes(title="Sorted Features (Highest to Lowest Similarity)")
fig.update_yaxes(title="Cosine Similarity")

# Reference line: the encoder random-pair baseline, the same value the encoder
# histogram shows. This cell deliberately does not reassign
# `mean_random_cos_sim`, which holds the decoder baseline used by the decoder
# plots; overwriting it would change those plots when cells are re-run.
fig.add_hline(
    y=mean_random_cos_sim_encoder,
    line_dash="dash",
    line_color="red",
    annotation_text=f"Mean Random Cosine Similarity: {mean_random_cos_sim_encoder:.4f}",
)

fig.show()
fig.write_html(save_path_html / "encoder_cos_sims_line.html")
fig.write_image(save_path_extra / "encoder_cos_sims_line.png")


# Latent analysis

In [None]:
from dictionary_learning.trainers.crosscoder import CrossCoderTrainer
from dictionary_learning.training import run_validation
from dictionary_learning.cache import PairedActivationCache

# Wrap the already-loaded crosscoder in a trainer object so it can be passed
# to run_validation; the freshly constructed autoencoder is replaced by `coder`.
trainer = CrossCoderTrainer(
    num_layers=num_layers,
    activation_dim=activation_dim,
    dict_size=dict_size,
    device=device,
)
trainer.ae = coder

# Seed before the random train/validation split so the split is reproducible.
th.manual_seed(seed)
th.cuda.manual_seed_all(seed)

# Model names and layer come from the configuration cell. `args` only defines
# `is_legacy`, so reading them from `args` raises AttributeError.
base_model_dir = activation_dir / base_model
instruct_model_dir = activation_dir / instruct_model
base_model_fineweb = base_model_dir / "fineweb"
base_model_lmsys_chat = base_model_dir / "lmsys_chat"
instruct_model_fineweb = instruct_model_dir / "fineweb"
instruct_model_lmsys_chat = instruct_model_dir / "lmsys_chat"
submodule_name = f"layer_{layer}_out"

# One paired (base, instruct) activation cache per dataset, concatenated.
fineweb_cache = PairedActivationCache(
    base_model_fineweb / submodule_name, instruct_model_fineweb / submodule_name
)
lmsys_chat_cache = PairedActivationCache(
    base_model_lmsys_chat / submodule_name, instruct_model_lmsys_chat / submodule_name
)
dataset = th.utils.data.ConcatDataset([fineweb_cache, lmsys_chat_cache])
# Re-derive the activation dimension from the data itself.
# assumes each sample is (num_models, activation_dim) — TODO confirm
activation_dim = dataset[0].shape[1]

# Hold out the last `validation_size` samples of a random permutation.
train_dataset, validation_dataset = th.utils.data.random_split(
    dataset, [len(dataset) - validation_size, validation_size]
)
dataloader = th.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
    pin_memory=True,
)
validation_dataloader = th.utils.data.DataLoader(
    validation_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=workers,
    pin_memory=True,
)

# NOTE(review): the un-batched `validation_dataset` is passed here while
# `validation_dataloader` is built above and never used — confirm against
# run_validation's signature which of the two it expects.
run_validation(trainer, validation_dataset, log_queues=[])