In [1]:
import os
import gc
import sys
import json
import torch
from torch import nn

import numpy as np
import pandas as pd
from tqdm import tqdm

from IPython.display import HTML, display
import matplotlib
import matplotlib.pyplot as plt
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go


from transformer_lens import HookedTransformer
import transformer_lens.utils as tutils

from model import MLP, AutoEncoder
from utils import get_freqs, find_similar_decoder_weights, find_similar_encoder_weights
from buffer import Buffer
from config import Config

from datasets import load_dataset

from make_json import make_json, single_neuron


In [2]:
save_dir = "mlps/2023-11-16_19-19-53" # eg

config = Config.from_json(file_path=os.path.join(save_dir, "cfg.json"))
config.device = 'mps'
print(config)

original_model = HookedTransformer.from_pretrained(config.original_model)
original_model.to(config.device)

Config(save_dir='mlps/2023-11-16_19-19-53', device='mps', original_model='gelu-1l', dataset='NeelNanda/c4-code-tokenized-2b', batch_size=2048, buffer_mult=384, num_tokens=3000000000, seq_len=128, layer_idx=0, in_hook='ln2.hook_normalized', out_hook='hook_mlp_out', lr=0.0001, l1_coeff=0.0003, l1_warmup=None, l1_sqrt=False, beta1=0.9, beta2=0.99, weight_decay=0.001, d_hidden_mult=32, d_in=512, act='gelu', leq_renorm=True, per_neuron_coeff=False, model_batch_size=256, buffer_size=786432, buffer_batches=6144)
Loaded pretrained model gelu-1l into HookedTransformer
Moving model to device:  mps


HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resi

In [3]:
data = load_dataset("roneneldan/TinyStories", split="validation")
tokenized_data = tutils.tokenize_and_concatenate(data, original_model.tokenizer, max_length=128)
tokenized_data = tokenized_data.shuffle(42)
all_tokens = tokenized_data["tokens"]

Repo card metadata block was not found. Setting CardData to empty.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
  table = cls._concat_blocks(blocks, axis=0)


In [4]:
buffer = Buffer(cfg=config, model=original_model, device=config.device, return_mid=True)

model_path = os.path.join(save_dir, "mlp_final.pt")

model = MLP(cfg=config)
model.load_state_dict(torch.load(model_path, map_location=config.device))
model.eval()
model.to(config.device)

freqs = get_freqs(original_model=original_model,
                  local_encoder=model,
                  all_tokens=buffer.all_tokens,
                  cfg=config,
                  num_batches=50,
                  )

Shuffling the data
torch.Size([786432, 2048]) torch.Size([32768, 2048])
self.mid_h.shape torch.Size([32768, 2048])
torch.Size([786432, 2048]) torch.Size([32768, 2048])
self.mid_h.shape torch.Size([32768, 2048])
torch.Size([786432, 2048]) torch.Size([32768, 2048])
self.mid_h.shape torch.Size([32768, 2048])
torch.Size([786432, 2048]) torch.Size([32768, 2048])
self.mid_h.shape torch.Size([32768, 2048])
torch.Size([786432, 2048]) torch.Size([32768, 2048])
self.mid_h.shape torch.Size([32768, 2048])
torch.Size([786432, 2048]) torch.Size([32768, 2048])
self.mid_h.shape torch.Size([32768, 2048])
torch.Size([786432, 2048]) torch.Size([32768, 2048])
self.mid_h.shape torch.Size([32768, 2048])
torch.Size([786432, 2048]) torch.Size([32768, 2048])
self.mid_h.shape torch.Size([32768, 2048])
torch.Size([786432, 2048]) torch.Size([32768, 2048])
self.mid_h.shape torch.Size([32768, 2048])
torch.Size([786432, 2048]) torch.Size([32768, 2048])
self.mid_h.shape torch.Size([32768, 2048])
torch.Size([786432, 2

100%|█████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:07<00:00,  6.97it/s]


In [6]:
# load autoencoder
ae_dir = "autoencoders/clean_armadillo"
ae_config = json.load(open(os.path.join(ae_dir, "cfg.json"), "r"))
ae_model_path = os.path.join(ae_dir, "autoencoder.pt")
ae = AutoEncoder(
    d_hidden=ae_config["d_mlp"] * ae_config["expansion_factor"],
    l1_coeff=ae_config["l1_coeff"],
    d_in=ae_config["d_mlp"],
)
ae.load_state_dict(torch.load(ae_model_path, map_location=config.device))

<All keys matched successfully>