In [1]:
import os
import gc
import sys
import json
import torch
from torch import nn

import numpy as np
import pandas as pd
from tqdm import tqdm

from IPython.display import HTML, display
import matplotlib
import matplotlib.pyplot as plt
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go


from transformer_lens import HookedTransformer
import transformer_lens.utils as tutils

from model import MLP, AutoEncoder
from utils import get_freqs, find_similar_decoder_weights, find_similar_encoder_weights
from buffer import Buffer
from config import Config

from datasets import load_dataset

from make_json import make_json, single_neuron


In [2]:
device = 'cpu'

save_dir = "mlps/2023-11-17_13-08-07" # driven bird.
# save_dir = "mlps/2023-11-19_17-16-13" # clean flower

config = Config.from_json(file_path=os.path.join(save_dir, "cfg.json"))
config.device = 'mps'
print(config)

original_model = HookedTransformer.from_pretrained(config.original_model)
original_model.to(device)

Config(save_dir='mlps/2023-11-17_13-08-07', device='mps', original_model='tiny-stories-2L-33M', dataset='roneneldan/TinyStories', batch_size=1024, buffer_mult=384, num_tokens=3000000000, seq_len=128, layer_idx=1, in_hook='ln2.hook_normalized', out_hook='hook_mlp_out', lr=5e-05, l1_coeff=0.0003, l1_warmup=None, l1_sqrt=False, beta1=0.9, beta2=0.99, weight_decay=0.001, d_hidden_mult=16, d_in=1024, act='gelu', leq_renorm=True, per_neuron_coeff=False, model_batch_size=128, buffer_size=393216, buffer_batches=3072)
Loaded pretrained model tiny-stories-2L-33M into HookedTransformer
Moving model to device:  cpu


HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-1): 2 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hoo

In [3]:
model_path = os.path.join(save_dir, "mlp_final.pt")

model = MLP(cfg=config)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
model.to(device)


MLP(
  (act): GELU(approximate='none')
)

In [4]:
# load autoencoder
ae_dir = "autoencoders/clean_armadillo"
ae_config = json.load(open(os.path.join(ae_dir, "cfg.json"), "r"))
ae_model_path = os.path.join(ae_dir, "autoencoder.pt")
ae = AutoEncoder(
    d_hidden=ae_config["d_mlp"] * ae_config["expansion_factor"],
    l1_coeff=ae_config["l1_coeff"],
    d_in=ae_config["d_mlp"],
)
ae.load_state_dict(torch.load(ae_model_path, map_location=device))
ae.to(device)

AutoEncoder()

In [8]:
A = model.W_dec.detach().cpu()
B = ae.W_dec @ original_model.blocks[1].mlp.W_out.detach().cpu()

# Normalize A and B
A = torch.nn.functional.normalize(A, p=2, dim=1)
B = torch.nn.functional.normalize(B, p=2, dim=1)

# Compute cosine similarity
cosine_similarity = A @ B.T  # Shape: (16384, 16384)

# Find the indices of the max values along each row
max_similarities, most_similar_idxs = torch.max(cosine_similarity, dim=1)



In [10]:
px.histogram(max_similarities.detach().flatten(), title="Similarity", histnorm="percent")