In [1]:
import os
import gc
import sys
import json
import torch
from torch import nn

import numpy as np
import pandas as pd
from tqdm import tqdm

from IPython.display import HTML, display
import matplotlib
import matplotlib.pyplot as plt
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go


from transformer_lens import HookedTransformer
import transformer_lens.utils as tutils

from model import MLP, AutoEncoder
from utils import get_freqs, find_similar_decoder_weights, find_similar_encoder_weights
from buffer import Buffer
from config import Config

from datasets import load_dataset

from make_json import make_json, single_neuron


In [2]:
# save_dir = "mlps/2023-11-17_13-08-07" # driven bird.
# save_dir = "mlps/2023-11-19_17-16-13" # clean flower
# save_dir = "mlps/2023-11-21_09-10-40" # laced durian. Tiny stories l1_coeff 0.0002.
# save_dir = "mlps/2023-11-21_23-00-47" # morning universe. Tiny stories l1_coeff 0.00001. tiny.
save_dir = "mlps/2023-11-25_11-14-39" # lemon dragon. Tiny stories. ReLU.
save_dir = "mlps/2023-11-26_17-53-28" # sunny hill. ReLU. l1 0.0002. <-- this is good.

config = Config.from_json(file_path=os.path.join(save_dir, "cfg.json"))
config.device = 'mps'
print(config)

original_model = HookedTransformer.from_pretrained(config.original_model)
original_model.to(config.device)

Config(save_dir='mlps/2023-11-26_17-53-28', device='mps', original_model='tiny-stories-2L-33M', dataset='roneneldan/TinyStories', batch_size=1024, buffer_mult=384, num_tokens=4000000000, seq_len=128, layer_idx=1, in_hook='ln2.hook_normalized', out_hook='hook_mlp_out', lr=5e-05, l1_coeff=0.0002, l1_warmup=None, l1_sqrt=False, beta1=0.9, beta2=0.99, weight_decay=0.0001, d_hidden_mult=16, d_in=1024, act='relu', leq_renorm=True, per_neuron_coeff=False, model_batch_size=128, buffer_size=393216, buffer_batches=3072)
Loaded pretrained model tiny-stories-2L-33M into HookedTransformer
Moving model to device:  mps


HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-1): 2 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hoo

In [3]:
data = load_dataset("roneneldan/TinyStories", split="validation")
tokenized_data = tutils.tokenize_and_concatenate(data, original_model.tokenizer, max_length=128)
tokenized_data = tokenized_data.shuffle(42)
all_tokens = tokenized_data["tokens"]

Repo card metadata block was not found. Setting CardData to empty.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
  table = cls._concat_blocks(blocks, axis=0)


In [4]:
buffer = Buffer(cfg=config, model=original_model, device=config.device, return_mid=True)

model_path = os.path.join(save_dir, "mlp_final.pt")

model = MLP(cfg=config)
model.load_state_dict(torch.load(model_path, map_location=config.device))
model.eval()
model.to(config.device)

gmlp_freqs = get_freqs(original_model=original_model,
                  local_encoder=model,
                  all_tokens=buffer.all_tokens,
                    batch_size=config.model_batch_size,
                    layer_idx=config.layer_idx,
                    in_hook=config.in_hook,
                    d_in=config.d_in,
                    device=config.device,
                  num_batches=50,
                  )

Shuffling the data
Buffer initialised


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:11<00:00,  4.48it/s]


In [5]:
gmlp_freqs = gmlp_freqs.cpu().numpy()

print(gmlp_freqs.max(), gmlp_freqs.min(), gmlp_freqs.mean(), gmlp_freqs.std())

gmlp_freqs = np.log10(gmlp_freqs + 10**-6.5)

px.histogram(gmlp_freqs, title="Log Frequency of Features", histnorm='percent', nbins=100)

0.4487439 0.0 0.037707705 0.04513597


In [6]:
# load autoencoder
ae_dir = "autoencoders/clean_armadillo"
ae_config = json.load(open(os.path.join(ae_dir, "cfg.json"), "r"))
ae_model_path = os.path.join(ae_dir, "autoencoder.pt")
ae = AutoEncoder(
    d_hidden=ae_config["d_mlp"] * ae_config["expansion_factor"],
    l1_coeff=ae_config["l1_coeff"],
    d_in=ae_config["d_mlp"],
)
ae.load_state_dict(torch.load(ae_model_path, map_location=config.device))
ae.to(config.device)

AutoEncoder()

In [7]:
ae_freqs = get_freqs(original_model=original_model,
                    local_encoder=ae,
                    all_tokens=buffer.all_tokens,
                    batch_size=config.model_batch_size,
                    layer_idx=config.layer_idx,
                    in_hook="mlp.hook_post",
                    d_in=ae_config["d_mlp"],
                    device=config.device,
                    num_batches=50,
                    )

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:18<00:00,  2.75it/s]


In [8]:
ae_freqs = ae_freqs.cpu().numpy()

print(ae_freqs.max(), ae_freqs.min(), ae_freqs.mean(), ae_freqs.std())

log_ae_freqs = np.log10(ae_freqs + 10**-6.5)

px.histogram(log_ae_freqs, title="Log Frequency of Features", histnorm='percent', nbins=100)

0.42307863 0.0 0.004675967 0.008898296


In [9]:
ae_freqs.shape

(16384,)

In [10]:
# compare GMLP encoder to original MLP encoder.
A = model.W_dec.detach().cpu() 
B = original_model.blocks[1].mlp.W_out.detach().cpu() 

A = torch.nn.functional.normalize(A, p=2, dim=1)
B = torch.nn.functional.normalize(B, p=2, dim=1)

# Compute cosine similarity for the original A
cosine_similarity = A @ B.T  # Shape: (16384, 4096)
max_similarities, most_similar_idxs = torch.max(cosine_similarity, dim=1)
hist_data_all = max_similarities.detach().flatten().numpy()

# find the features which are similar to neurons.
decoder_idxs_similar_to_orig = (max_similarities > 0.8).nonzero().squeeze(1)

px.histogram(hist_data_all, title="Decoder Features similarity to original", histnorm='percent')

In [11]:
A = model.W_dec.detach().cpu()
B = ae.W_dec.detach().cpu() @ original_model.blocks[1].mlp.W_out.detach().cpu()

# Normalize A and B
A = torch.nn.functional.normalize(A, p=2, dim=1)
B = torch.nn.functional.normalize(B, p=2, dim=1)

# Compute cosine similarity
cosine_similarity = A @ B.T  # Shape: (16384, 16384)

# Find the indices of the max values along each row
max_similarities, most_similar_idxs = torch.max(cosine_similarity, dim=1)

# px.histogram(max_similarities.detach().flatten(), title="Similarity", histnorm="percent")

# Extract max similarities for selected indices
selected_max_similarities = max_similarities[decoder_idxs_similar_to_orig]

# Flatten and convert to numpy arrays for plotting
max_similarities = max_similarities.detach().flatten().numpy()
selected_max_similarities = selected_max_similarities.detach().flatten().numpy()

# Create labels for the histogram
all_labels = ["all"] * len(max_similarities) + ["close_to_neurons"] * len(selected_max_similarities)

# Concatenate the max similarities
all_max_similarities = np.concatenate([max_similarities, selected_max_similarities])

# Create a DataFrame for plotting
df = pd.DataFrame({
    "Max Similarity": all_max_similarities,
    "Category": all_labels
})

# Plot the histogram
fig = px.histogram(df, x="Max Similarity", color="Category", barmode='overlay',
                   title="Similarity of GMLP features to closest AE feature")
fig.show()


In [13]:
# sorted_sims, sorted_idxs = torch.sort(max_similarities, descending=True)
# sorted_matches = most_similar_idxs[sorted_idxs]
# print(sorted_sims[:10])
# print(sorted_idxs[:10])
# print(sorted_matches[:10])

In [27]:
# feature_id = 4


batch_size = 256
n_batches = 20

tokens = all_tokens[:batch_size*n_batches]
# all_acts = []
# for i in tqdm(range(n_batches)):
#     batch = tokens[i*batch_size:(i+1)*batch_size]
#     in_hook = config.in_hook
#     _, cache = original_model.run_with_cache(batch,
#                                              stop_at_layer=config.layer_idx + 1,
#                                              names_filter=f"blocks.{config.layer_idx}.{in_hook}"
#                                             )
#     mlp_acts = cache[f"blocks.{config.layer_idx}.{in_hook}"]
#     mlp_acts = mlp_acts.reshape(-1, config.d_in)
#     hidden_acts = model.encode(mlp_acts)

#     hidden_acts = hidden_acts[:, feature_id].detach().cpu()
#     all_acts.append(hidden_acts)
    
# all_acts = torch.cat(all_acts, dim=0)
# all_acts.shape

In [28]:
# act_data = single_neuron(tokens=tokens, acts=all_acts.reshape(-1, 128), ft_id=feature_id, original_model=original_model)

# def style_snippet(snippet_idx):
#     tokens_with_activations = act_data['snippets'][snippet_idx]["token_activation_pairs"]
#     max_act = act_data['snippets'][snippet_idx]["max_activation"]
    
#     # Function to map activation to color
#     def activation_to_color(activation):
#         if activation < 0:
#             return '#FFFFFF'
#         normalized_activation = activation / max_act*0.6
#         return plt.cm.Reds(normalized_activation)
    
#     styled_text = ''.join(f'<span style="background-color: {matplotlib.colors.rgb2hex(activation_to_color(activation))}; margin-right: 0px;">{token}</span>'
#                           for token, activation in tokens_with_activations)
#     return styled_text

# print("feature id: ", feature_id)
# print()
# # Display each snippet with its number and max activation
# for i, snippet in enumerate(act_data['snippets']):
#     styled_text = style_snippet(i)
#     # snippet_info = f'<div style="margin-bottom: 10px;"><strong>Snippet number:</strong> {i}<br><strong>Max activation:</strong> {snippet["max_activation"]}<br>{styled_text}</div>'
#     # display(HTML(snippet_info))
#     snippet_info = f'<div style="word-wrap: break-word; margin-bottom: 10px;"><strong>Snippet number:</strong> {i}<br><strong>Max activation:</strong> {snippet["max_activation"]}<br>{styled_text}</div>'
#     display(HTML(snippet_info))

In [33]:
# look at top activations for AE

feature_id = 5

all_acts = []
for i in tqdm(range(n_batches)):
    batch = tokens[i*batch_size:(i+1)*batch_size]
    in_hook = "mlp.hook_post" ### <--- Specific for AE.
    _, cache = original_model.run_with_cache(batch,
                                             stop_at_layer=config.layer_idx + 1,
                                             names_filter=f"blocks.{config.layer_idx}.{in_hook}"
                                            )
    mlp_acts = cache[f"blocks.{config.layer_idx}.{in_hook}"]
    mlp_acts = mlp_acts.reshape(-1, config.d_in)
    hidden_acts = model.encode(mlp_acts)

    hidden_acts = hidden_acts[:, feature_id].detach().cpu()
    all_acts.append(hidden_acts)
    
all_acts = torch.cat(all_acts, dim=0)
all_acts

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:18<00:00,  1.10it/s]


tensor([0.0000, 0.2467, 0.0000,  ..., 0.0000, 0.2367, 0.0000])

In [36]:
all_acts.shape

torch.Size([2621440])

In [35]:
act_data = single_neuron(tokens=tokens, acts=all_acts.reshape(-1, 128), ft_id=feature_id, original_model=original_model)

def style_snippet(snippet_idx):
    tokens_with_activations = act_data['snippets'][snippet_idx]["token_activation_pairs"]
    max_act = act_data['snippets'][snippet_idx]["max_activation"]
    
    # Function to map activation to color
    def activation_to_color(activation):
        if activation < 0:
            return '#FFFFFF'
        normalized_activation = activation / max_act*0.6
        return plt.cm.Reds(normalized_activation)
    
    styled_text = ''.join(f'<span style="background-color: {matplotlib.colors.rgb2hex(activation_to_color(activation))}; margin-right: 0px;">{token}</span>'
                          for token, activation in tokens_with_activations)
    return styled_text

print("feature id: ", feature_id)
print()
# Display each snippet with its number and max activation
for i, snippet in enumerate(act_data['snippets']):
    styled_text = style_snippet(i)
    # snippet_info = f'<div style="margin-bottom: 10px;"><strong>Snippet number:</strong> {i}<br><strong>Max activation:</strong> {snippet["max_activation"]}<br>{styled_text}</div>'
    # display(HTML(snippet_info))
    snippet_info = f'<div style="word-wrap: break-word; margin-bottom: 10px;"><strong>Snippet number:</strong> {i}<br><strong>Max activation:</strong> {snippet["max_activation"]}<br>{styled_text}</div>'
    display(HTML(snippet_info))

IndexError: index 11874 is out of bounds for dimension 0 with size 5120