<a href="https://colab.research.google.com/github/zwimpee/1L-Sparse-Autoencoder/blob/main/dev.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# *Sparse Autoencoders, Feature Extraction, Dictionary Learning, Interpretability, and Supersymmetry* 🚀 😄

## Introduction
This notebook tracks progress on a personal project about a peculiar similarity I think I may have noticed between two seemingly disparate worlds: $\textbf{Deep Learning}$ and $\textbf{Supersymmetry}$.

In particular, I noticed fairly recently (within the last month or so) that the interpretability work of the last 3-4 years, like the research at [Anthropic](https://www.anthropic.com/research) and especially the work by [Chris Olah](https://scholar.google.com/citations?user=6dskOSUAAAAJ&hl=en) on [Transformer Circuits](https://transformer-circuits.pub/), bears a striking similarity to the work [I have been a part of](https://arxiv.org/abs/1906.02971) through the research group at Brown led by [Dr. Sylvester James Gates (Jim)](https://twitter.com/dr_jimgates).

More specifically, I have caught on to a thread (a surprising one, at least to me) that could suggest that the mathematics of supersymmetry, and of quantum field theory more generally, could be applied directly to transformer-based neural networks (or perhaps to neural networks in general). That could serve two goals: gaining deeper insight into how a trained model comes to exhibit the emergent properties we observe as these models get larger, and more explicitly guiding these models toward being truly aligned with human interests and priorities, using some of the ablation/masking techniques studied in the works linked above.

I will expand on this connection later in this document, but first it is more important to get a better sense of the current formalism of the emerging field of *mechanistic interpretability*, so let's get started...

# 1. `Setup`

## 1.1 - `Setup` - Install Dependencies

In [81]:
!python3 -V

Python 3.10.12


In [82]:
!pip install transformer_lens
!pip install gradio
# !pip install tiktoken
# !pip install transformers
!pip install datasets



In [83]:
!pip install wandb -qU

In [None]:
# Log in to your W&B account
import wandb

In [None]:
wandb.login()

## 1.2 - `Setup` - Imports

In [97]:
import datasets
import gradio as gr
import json
import numpy as np
import plotly.express as px
import pprint
import pandas as pd
# import tiktoken
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformer_lens

from dataclasses import dataclass, field, fields
from datasets import load_dataset
from huggingface_hub import HfApi
from IPython.display import HTML
from functools import partial
from transformers import GPT2TokenizerFast
from tqdm.auto import tqdm
from transformer_lens import (
    # HookedTransformer,
    utils
)
from typing import Any, Dict


seed = 49
DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

## 1.3 - `Setup` - Config definition and initialization

In [99]:
@dataclass(kw_only=True)
class ModelConfig:
    name: str
    max_length: int
    batch_size: int
    dtype: Any
    seed: Any
    learning_rate: float = 1e-4


# Subclasses need their own @dataclass decorator; otherwise their annotations
# never become __init__ parameters and construction fails with
# "TypeError: ModelConfig.__init__() got an unexpected keyword argument ...".
# kw_only=True (Python 3.10+) lets fields without defaults follow the
# defaulted `learning_rate` inherited from ModelConfig.
@dataclass(kw_only=True)
class TransformerConfig(ModelConfig):
    vocabulary_size: int
    num_layers: int
    num_heads: int
    embed_dim: int
    hidden_dim: int
    nonlinearity: Any


@dataclass(kw_only=True)
class AutoEncoderConfig(ModelConfig):
    buffer_mult: int
    num_tokens: int = int(2e9)
    l1_coeff: float
    beta1: float
    beta2: float
    dict_mult: float
    seq_len: int
    d_mlp: int
    enc_dtype: Any
    remove_rare_dir: bool
    # Derived sizes, filled in below before construction.
    model_batch_size: int
    buffer_size: int
    buffer_batches: int


@dataclass
class DataConfig:
    name: str
    data: Any

@dataclass
class TrainingConfig:
    num_steps: int
    transformer_learning_rate: float
    autoencoder_learning_rate: float
    optimizer: Any
    batch_size: int
    device: Any


@dataclass
class Config:
    transformer: Any
    autoencoder: Any
    data: Any
    training: Any


# Model Configurations
# Transformer Configuration
model_name = "openai-community/gpt2"
tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
batch_size = 1
training_steps = int(1e5)
transformer_learning_rate = 1e-5
max_length = 2**7  # 2**7 = 128
vocab_size = len(tokenizer)  # 50257 for the GPT-2 tokenizer
embed_dim = max_length * 4  # 4x the max length (512): a reasonable starting point, not a principled choice
hidden_dim = 1028  # Can't remember if this is actually needed...
num_layers = 2  # Trying to reproduce results from https://transformer-circuits.pub/2023/monosemantic-features/index.html
block_size = 2**3  # 8
# Integer division: 128 // 8 == 16 heads (the original `%` gave 128 % 8 == 0 heads).
# Note that in standard multi-head attention each head attends over *all* positions;
# the heads split the embedding dimension (512 / 16 = 32 per head), not the tokens.
num_heads = max_length // block_size
transformer_config = {
    "seed": seed,
    "name": model_name,
    "batch_size": batch_size,
    "max_length": max_length,
    "vocabulary_size": vocab_size,
    "embed_dim": embed_dim,
    "hidden_dim": hidden_dim,
    "num_layers": num_layers,
    "num_heads": num_heads,
    "nonlinearity": nn.ReLU,  # Will change this to gelu/other later.
    "dtype": DTYPES["bf16"],
}
transformer_config = TransformerConfig(**transformer_config)

# Autoencoder Configuration
batch_size = 4096  # tokens?
buffer_mult = 384  #??
autoencoder_learning_rate = 1e-4
num_tokens = int(2e9)  #??
l1_coeff = 3e-4
beta1 = 0.9
beta2 = 0.99
dict_mult = 8  # <- scaling factor for dictionary learning
seq_len = max_length
d_mlp = transformer_config.hidden_dim #2048
enc_dtype = transformer_config.dtype #"fp32"
remove_rare_dir = False
autoencoder_config = {
    "seed": seed,
    # name / max_length / dtype / learning_rate are required by the ModelConfig base class.
    "name": "autoencoder",  # placeholder name
    "max_length": max_length,
    "dtype": enc_dtype,
    "learning_rate": autoencoder_learning_rate,
    "batch_size": batch_size,
    "buffer_mult": buffer_mult,
    "num_tokens": num_tokens,
    "l1_coeff": l1_coeff,
    "beta1": beta1,
    "beta2": beta2,
    "dict_mult": dict_mult,
    "seq_len": seq_len,
    "d_mlp": d_mlp,
    "enc_dtype": enc_dtype,
    "remove_rare_dir": remove_rare_dir,
}
autoencoder_config["model_batch_size"] = 64
autoencoder_config["buffer_size"] = autoencoder_config["batch_size"] * autoencoder_config["buffer_mult"]
autoencoder_config["buffer_batches"] = autoencoder_config["buffer_size"] // autoencoder_config["seq_len"]
autoencoder_config = AutoEncoderConfig(**autoencoder_config)


# - [X] TODO: Find dataset: https://huggingface.co/datasets
# - [ ] TODO: Grab tiny shakespeare dataset and process it for faster development iterations.
dataset_name = "piqa"
#dataset_name = "Skylion007/openwebtext"
# piqa provides "train", "validation" and "test" splits (there is no "sample" split).
data = load_dataset(dataset_name, split="train", trust_remote_code=True)

# # data = load_dataset("NeelNanda/c4-code-20k", split="train")
# tokenized_data = utils.tokenize_and_concatenate(data, tokenizer, max_length=512)
# tokenized_data = tokenized_data.shuffle(42)
# all_tokens = tokenized_data["tokens"]

# Configuration initialization.
config = Config(
    # - [X] TODO: Fill this out as much as possible at once during initialization!
    transformer=transformer_config,
    autoencoder=autoencoder_config,
    data=DataConfig(name=dataset_name, data=data),
    training=TrainingConfig(
        batch_size=batch_size,
        num_steps=training_steps,
        transformer_learning_rate=transformer_learning_rate,
        autoencoder_learning_rate=autoencoder_learning_rate,
        optimizer=None,
        device="cuda" if torch.cuda.is_available() else "cpu",
    ),
)

## 1.4 - `Models` - Define the Models


### 1.4.1 - Autoencoder

In [None]:
# 1. - [ ] TODO: Reimplement my own AutoEncoder from scratch.
class AutoEncoder(nn.Module):
    def __init__(self, config: AutoEncoderConfig, **kwargs):
        super().__init__()
        self.config = config

    def forward(self, x):
        ...

# 2. - [X] TODO: Understand the autoencoder, both in terms of architecture as well as in terms of how we are trying to use it.
#       This will be important in understanding the link between this aspect of MI and the mathematics that describes SUSY.
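As a reference point for the TODOs above, here is a minimal sketch of the kind of sparse autoencoder the analysis cells in section 3 assume: calling `encoder(x)` returns `(loss, x_reconstruct, hidden_acts, l2_loss, l1_loss)`, the module exposes `W_enc`, `b_enc`, `W_dec`, `b_dec` and `d_hidden`, and the encoder computes `relu((x - b_dec) @ W_enc + b_enc)` as in the comment in section 3.3. The class name `SparseAutoEncoderSketch`, the random initialisation, the unit-norm decoder rows and the exact loss normalisation are assumptions, not the author's implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SparseAutoEncoderSketch(nn.Module):
    """Hypothetical one-hidden-layer sparse autoencoder over MLP activations.

    The dictionary size is d_hidden = d_mlp * dict_mult.
    """

    def __init__(self, d_mlp: int, dict_mult: int, l1_coeff: float, seed: int = 0):
        super().__init__()
        g = torch.Generator().manual_seed(seed)
        self.d_hidden = d_mlp * dict_mult
        self.l1_coeff = l1_coeff
        self.W_enc = nn.Parameter(torch.randn(d_mlp, self.d_hidden, generator=g) / d_mlp**0.5)
        self.b_enc = nn.Parameter(torch.zeros(self.d_hidden))
        W_dec = torch.randn(self.d_hidden, d_mlp, generator=g)
        # Each row of W_dec is one dictionary direction, normalised to unit length.
        self.W_dec = nn.Parameter(W_dec / W_dec.norm(dim=-1, keepdim=True))
        self.b_dec = nn.Parameter(torch.zeros(d_mlp))

    def forward(self, x):
        # Encode: subtract the decoder bias, project, ReLU -> non-negative codes.
        hidden_acts = F.relu((x - self.b_dec) @ self.W_enc + self.b_enc)
        # Decode: linear combination of dictionary directions plus bias.
        x_reconstruct = hidden_acts @ self.W_dec + self.b_dec
        # Squared reconstruction error per example, averaged over the batch.
        l2_loss = (x_reconstruct - x).pow(2).sum(-1).mean(0)
        # L1 penalty on the codes pushes most of them to exactly zero.
        l1_loss = self.l1_coeff * hidden_acts.abs().sum()
        loss = l2_loss + l1_loss
        return loss, x_reconstruct, hidden_acts, l2_loss, l1_loss


# Toy usage: 32 fake "MLP activation" vectors of width 16.
sae = SparseAutoEncoderSketch(d_mlp=16, dict_mult=8, l1_coeff=3e-4)
x_toy = torch.randn(32, 16)
toy_loss, toy_recon, toy_hidden, toy_l2, toy_l1 = sae(x_toy)
```

The L1 coefficient trades reconstruction quality against sparsity; `dict_mult` sets how overcomplete the dictionary is relative to the MLP width.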




### 1.4.2 - Transformer

In [None]:
# TODOs:
# 1. - [X] TODO: Try and build the transformer from scratch without any help, as an exercise.
# 2. - [ ] TODO: After getting as far as possible from memory, reference some examples to finish implementing the code below.

###################################################################################

# Attempt #1:
# 1. - [X] TODO: Initialize the following classes:
        # 1.1 - [X] TODO: Initialize the class for Attention
        # 1.2 - [X] TODO: Initialize the class for MLP
        # 1.3 - [X] TODO: Initialize MHSA
        # 1.4 - [X] TODO: Initialize TransformerBlock(MHSA)
        # 1.5 - [X] TODO: Initialize TransformerLanguageModel
# 2. - [ ] TODO: DEBUG
#---------------------------------------------------------------------------------#
class Attention(nn.Module):
    def __init__(self, config: TransformerConfig, **kwargs):
        super().__init__()
        self.config = config
        ...

    def forward(self, x):
        ...

# - [X] TODO: 2. Implement the class for MLP
# This one should be pretty easy, hopefully...
class MLP(nn.Module):
    def __init__(self, config: TransformerConfig, **kwargs):
        super().__init__()
        self.config = config
        if kwargs.get("type", "attn")=="atn":
            self.fc = nn.Linear(config.hidden_dim, config.hidden_dim)
        else:
            self.fc = nn.Linear(config.hidden_dim, config.vocabulary_size)

    def forward(self, x):
        return self.fc(x)


class MHSA(nn.Module):
    def __init__(self, config: TransformerConfig, **kwargs):
        super().__init__()
        self.config = config
        ...

    def forward(self, x):
      ...

class TransformerBlock(MHSA):
    def __init__(self, config: TransformerConfig, **kwargs):
        super().__init__(config=config)
        self.config = config
        ...

    def forward(self, x):
        ...





# - [X] TODO: Implement the class for Transformer.
# - [X] TODO: Try and remember how to construct the rest of the transformer...
# - [ ] TODO: Look at the docs and then try to finish up what we have so far...
# - [ ] TODO: Integrate the AutoEncoder directly into the Transformer...
class Transformer(nn.Module):
    def __init__(self, config: TransformerConfig|Dict, **kwargs):
        super().__init__()
        if not isinstance(config, TransformerConfig):
          config = TransformerConfig(**config)
        self.config = config
        self.layers = nn.Sequential(
            nn.Embedding(num_embeddings=config.vocabulary_size, embedding_dim=config.embed_dim),
        )


    def forward(self, x):
        ...

    def generate(self, inp, **kwargs):
        ...

###################################################################################
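For the TODO about referencing examples to finish `Attention`/`MHSA`, here is a minimal sketch of standard causal multi-head self-attention (the GPT-2-style design, not necessarily what the author intends). The class name `CausalSelfAttentionSketch` is hypothetical; it splits `embed_dim` evenly across `num_heads`, so every head sees all positions, and masks out future positions. The check at the end perturbs the last token and confirms that earlier outputs do not change.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalSelfAttentionSketch(nn.Module):
    """Hypothetical multi-head causal self-attention over (batch, seq, embed_dim)."""

    def __init__(self, embed_dim: int, num_heads: int, max_length: int):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must split evenly across heads"
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)  # joint Q, K, V projection
        self.out = nn.Linear(embed_dim, embed_dim)      # recombine the heads
        mask = torch.tril(torch.ones(max_length, max_length, dtype=torch.bool))
        self.register_buffer("causal_mask", mask)

    def forward(self, x):
        B, T, D = x.shape
        q, k, v = self.qkv(x).split(D, dim=-1)
        # (B, T, D) -> (B, heads, T, head_dim)
        q, k, v = (t.view(B, T, self.num_heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
        scores = q @ k.transpose(-2, -1) / self.head_dim**0.5  # (B, heads, T, T)
        # Position i may only attend to positions j <= i.
        scores = scores.masked_fill(~self.causal_mask[:T, :T], float("-inf"))
        weights = F.softmax(scores, dim=-1)
        y = (weights @ v).transpose(1, 2).reshape(B, T, D)
        return self.out(y)


torch.manual_seed(0)
attn = CausalSelfAttentionSketch(embed_dim=32, num_heads=4, max_length=16)
x_toy = torch.randn(2, 10, 32)
y_toy = attn(x_toy)
x_perturbed = x_toy.clone()
x_perturbed[:, -1] += 1.0  # change only the final token
y_perturbed = attn(x_perturbed)
```

A `TransformerBlock` would typically wrap this with layer norms, a residual connection, and an MLP of width `hidden_dim`.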

## 1.5 - `Function Defs` - Utils


### 1.5.1 - `Function Defs` - "Standard Lib"

#### `Function Def` - Train the Language Model

In [None]:
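# - [X] TODO: Implement training logic
def train(model, tokenizer, data, config: TrainingConfig | None = None, **kwargs):
    # Prefer the device from the config, then from kwargs, then auto-detect.
    device = (config.device if config is not None else None) or kwargs.get(
        "device", "cuda" if torch.cuda.is_available() else "cpu"
    )
    model.to(device)

    # 80/10/10 split. `Dataset.select` keeps each split a `datasets.Dataset`;
    # slicing (`data[:n]`) would instead return a dict of column lists.
    n = len(data)
    train_data = data.select(range(0, round(0.8 * n)))
    val_data = data.select(range(round(0.8 * n), round(0.9 * n)))
    test_data = data.select(range(round(0.9 * n), n))

    eval_every_n_batches = int(kwargs.get("eval_every_n_batches", 1e4))
    max_steps = int(kwargs.get("num_steps", 1e5))
    max_length = kwargs.get("max_length", 512)

    optimizer = kwargs.get("optimizer")

    # Next-token prediction is multi-class classification over the vocabulary,
    # so cross-entropy (not binary cross-entropy) is the natural default.
    loss_fn = kwargs.get("loss_fn", F.cross_entropy)

    # GPT-2's tokenizer has no pad token; reuse EOS so padding="max_length" works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    for i in tqdm(range(len(train_data)), desc="Running training loop..."):
        model.train()
        train_example = train_data[i]

        # Assumes a "text" column (as in openwebtext); piqa instead has "goal", "sol1", "sol2".
        batch = tokenizer(train_example["text"], padding="max_length", truncation=True,
                          max_length=max_length, return_tensors="pt").to(device)
        logits = model(batch["input_ids"])  # expected shape: (batch, seq, vocab)

        # Position t predicts token t+1; padded targets become -100, which
        # F.cross_entropy ignores by default.
        labels = batch["input_ids"][:, 1:].masked_fill(batch["attention_mask"][:, 1:] == 0, -100)
        loss = loss_fn(logits[:, :-1].reshape(-1, logits.shape[-1]), labels.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % eval_every_n_batches == 0:
            model.eval()
            with torch.no_grad():
                for j in tqdm(range(len(val_data)), desc="Running validation loop..."):
                    val_example = val_data[j]
                    # - [ ] TODO: compute and log the validation loss.

        if i >= max_steps:
            break

    return test_data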
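In a causal language model the logits at position t are scored against the token at position t+1; getting this off-by-one wrong silently trains the model to copy its input. A small self-contained check with a hypothetical helper `next_token_loss`: padding is excluded by setting those labels to -100, the default `ignore_index` of `torch.nn.functional.cross_entropy`. The toy logits are hand-built, not model outputs.

```python
import torch
import torch.nn.functional as F


def next_token_loss(logits, input_ids, attention_mask):
    """Hypothetical helper: causal-LM cross-entropy that ignores padded targets."""
    # Drop the last logit (nothing to predict) and the first label (nothing predicts it).
    labels = input_ids[:, 1:].masked_fill(attention_mask[:, 1:] == 0, -100)
    return F.cross_entropy(logits[:, :-1].reshape(-1, logits.shape[-1]), labels.reshape(-1))


toy_vocab_size = 5
toy_input_ids = torch.tensor([[1, 3, 2, 0, 0]])
toy_attention_mask = torch.tensor([[1, 1, 1, 0, 0]])  # last two positions are padding

# "Perfect" toy model: a large logit on the *next* token at each position.
good_logits = torch.zeros(1, 5, toy_vocab_size)
for t in range(4):
    good_logits[0, t, toy_input_ids[0, t + 1]] = 20.0

# Off-by-one toy model: a large logit on the *current* token instead.
copy_logits = torch.zeros(1, 5, toy_vocab_size)
for t in range(5):
    copy_logits[0, t, toy_input_ids[0, t]] = 20.0

good_loss = next_token_loss(good_logits, toy_input_ids, toy_attention_mask)
copy_loss = next_token_loss(copy_logits, toy_input_ids, toy_attention_mask)
```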


#### `Function Def` - Evaluate the Language Model

In [None]:
@torch.no_grad()
def evaluate(model, data):
    ...

### 1.5.2 - `Function Defs` - Mechanistic Interpretability

#### 1.5.2.1 - `Function Defs` - Get Reconstruction Loss

In [None]:
def replacement_hook(mlp_post, hook, encoder):
    mlp_post_reconstr = encoder(mlp_post)[1]
    return mlp_post_reconstr

def mean_ablate_hook(mlp_post, hook):
    mlp_post[:] = mlp_post.mean([0, 1])
    return mlp_post

def zero_ablate_hook(mlp_post, hook):
    mlp_post[:] = 0.
    return mlp_post

@torch.no_grad()
def get_recons_loss(num_batches=5, local_encoder=None):
    if local_encoder is None:
        local_encoder = encoder
    loss_list = []
    for i in range(num_batches):
        tokens = all_tokens[torch.randperm(len(all_tokens))[:config.autoencoder.model_batch_size]]
        loss = model(tokens, return_type="loss")
        recons_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", 0), partial(replacement_hook, encoder=local_encoder))])
        # mean_abl_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", 0), mean_ablate_hook)])
        zero_abl_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", 0), zero_ablate_hook)])
        loss_list.append((loss, recons_loss, zero_abl_loss))
    losses = torch.tensor(loss_list)
    loss, recons_loss, zero_abl_loss = losses.mean(0).tolist()

    print(f"loss: {loss:.4f}, recons_loss: {recons_loss:.4f}, zero_abl_loss: {zero_abl_loss:.4f}")
    score = ((zero_abl_loss - recons_loss)/(zero_abl_loss - loss))
    print(f"Reconstruction Score: {score:.2%}")
    # print(f"{((zero_abl_loss - mean_abl_loss)/(zero_abl_loss - loss)).item():.2%}")
    return score, loss, recons_loss, zero_abl_loss
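`get_recons_loss` summarises three losses as `score = (zero_abl_loss - recons_loss) / (zero_abl_loss - loss)`: the fraction of the gap between zero-ablating the MLP and running it normally that is recovered by splicing in the autoencoder's reconstruction. It is 1.0 when the reconstruction is as good as the real MLP output and 0.0 when it is no better than zeroing the MLP out. A worked example with made-up numbers (not measured results; the helper name is hypothetical):

```python
def reconstruction_score(loss: float, recons_loss: float, zero_abl_loss: float) -> float:
    """Fraction of the zero-ablation loss gap recovered by the reconstruction."""
    return (zero_abl_loss - recons_loss) / (zero_abl_loss - loss)


# Hypothetical losses: clean 3.0, reconstructed 3.2, zero-ablated 5.0.
toy_score = reconstruction_score(loss=3.0, recons_loss=3.2, zero_abl_loss=5.0)
print(f"Reconstruction Score: {toy_score:.2%}")  # Reconstruction Score: 90.00%
```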

#### 1.5.2.2 - `Function Defs` - Get Frequencies

In [None]:
# Frequency
@torch.no_grad()
def get_freqs(model, all_tokens, autoencoder_config: AutoEncoderConfig, num_batches=25, local_encoder=None):
    # Parameters without defaults must come before those with defaults.
    if local_encoder is None:
        local_encoder = encoder
    # Accumulate on whatever device the encoder lives on, instead of hard-coding .cuda().
    device = local_encoder.W_enc.device
    act_freq_scores = torch.zeros(local_encoder.d_hidden, dtype=torch.float32, device=device)
    total = 0
    for i in tqdm(range(num_batches)):
        tokens = all_tokens[torch.randperm(len(all_tokens))[:autoencoder_config.model_batch_size]]

        _, cache = model.run_with_cache(tokens, stop_at_layer=1, names_filter=utils.get_act_name("post", 0))
        mlp_acts = cache[utils.get_act_name("post", 0)]
        mlp_acts = mlp_acts.reshape(-1, autoencoder_config.d_mlp)

        hidden = local_encoder(mlp_acts)[2]

        act_freq_scores += (hidden > 0).sum(0)
        total += hidden.shape[0]
    act_freq_scores /= total
    num_dead = (act_freq_scores == 0).float().mean()
    print("Num dead", num_dead)
    return act_freq_scores

#### 1.5.2.3 - `Function Defs` - Visualise Feature Utils

In [None]:
from html import escape
import colorsys

from IPython.display import display

SPACE = "·"
NEWLINE="↩"
TAB = "→"

def create_html(strings, values, max_value=None, saturation=0.5, allow_different_length=False, return_string=False):
    # escape strings to deal with tabs, newlines, etc.
    escaped_strings = [escape(s, quote=True) for s in strings]
    processed_strings = [
        s.replace("\n", f"{NEWLINE}<br/>").replace("\t", f"{TAB}&emsp;").replace(" ", "&nbsp;")
        for s in escaped_strings
    ]

    if isinstance(values, torch.Tensor) and len(values.shape)>1:
        values = values.flatten().tolist()

    if not allow_different_length:
        assert len(processed_strings) == len(values)

    # scale values
    if max_value is None:
        max_value = max(max(values), -min(values))+1e-3
    scaled_values = [v / max_value * saturation for v in values]

    # create html
    html = ""
    for i, s in enumerate(processed_strings):
        if i<len(scaled_values):
            v = scaled_values[i]
        else:
            v = 0
        if v < 0:
            hue = 0  # hue for red in HSV
        else:
            hue = 0.66  # hue for blue in HSV
        # HSV saturation must lie in [0, 1], so use |v|; the hue already encodes the sign.
        rgb_color = colorsys.hsv_to_rgb(hue, abs(v), 1)
        hex_color = "#%02x%02x%02x" % (
            int(rgb_color[0] * 255),
            int(rgb_color[1] * 255),
            int(rgb_color[2] * 255),
        )
        html += f'<span style="background-color: {hex_color}; border: 1px solid lightgray; font-size: 16px; border-radius: 3px;">{s}</span>'
    if return_string:
        return html
    else:
        display(HTML(html))
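`create_html` turns each value into an HSV colour: the hue encodes the sign (red for negative, blue for positive) and the saturation encodes the scaled magnitude. Because `colorsys.hsv_to_rgb` expects saturation in [0, 1], a negative saturation would push RGB components above 1 and produce invalid hex codes, so the magnitude has to be used. A standalone version of that mapping (the helper name `value_to_hex` is hypothetical):

```python
import colorsys


def value_to_hex(v: float, max_value: float, saturation: float = 0.5) -> str:
    """Map a signed activation to a background colour: blue if positive, red if negative."""
    scaled = v / max_value * saturation
    hue = 0.0 if scaled < 0 else 0.66  # red for negative, blue for positive
    # Saturation = |scaled| keeps every RGB component within [0, 1].
    r, g, b = colorsys.hsv_to_rgb(hue, abs(scaled), 1)
    return "#%02x%02x%02x" % (int(r * 255), int(g * 255), int(b * 255))


print(value_to_hex(0.0, 1.0))   # #ffffff (zero activation -> white)
print(value_to_hex(-1.0, 1.0))  # #ff7f7f (most negative -> light red)
print(value_to_hex(1.0, 1.0))   # #7f84ff (most positive -> light blue)
```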

def basic_feature_vis(encoder, text, feature_index, max_val=0):
    feature_in = encoder.W_enc[:, feature_index]
    feature_bias = encoder.b_enc[feature_index]
    _, cache = model.run_with_cache(text, stop_at_layer=1, names_filter=utils.get_act_name("post", 0))
    mlp_acts = cache[utils.get_act_name("post", 0)][0]
    feature_acts = F.relu((mlp_acts - encoder.b_dec) @ feature_in + feature_bias)
    if max_val==0:
        max_val = max(1e-7, feature_acts.max().item())
        # print(max_val)
    # if min_val==0:
    #     min_val = min(-1e-7, feature_acts.min().item())
    return basic_token_vis_make_str(model, text, feature_acts, max_val)
def basic_token_vis_make_str(model, strings, values, max_val=None):
    if not isinstance(strings, list):
        strings = model.to_str_tokens(strings)
    values = utils.to_numpy(values)
    if max_val is None:
        max_val = values.max()
    # if min_val is None:
    #     min_val = values.min()
    header_string = f"<h4>Max Range <b>{values.max():.4f}</b> Min Range: <b>{values.min():.4f}</b></h4>"
    header_string += f"<h4>Set Max Range <b>{max_val:.4f}</b></h4>"
    # values[values>0] = values[values>0]/ma|x_val
    # values[values<0] = values[values<0]/abs(min_val)
    body_string = create_html(strings, values, max_value=max_val, return_string=True)
    return header_string + body_string
# display(HTML(basic_token_vis_make_str(tokens[0, :10], mlp_acts[0, :10, 7], 0.1)))
# # %%
# The `with gr.Blocks() as demo:` syntax just creates a variable called demo containing all these components
import gradio as gr
try:
    demos[0].close()
except:
    pass
demos = [None]
def make_feature_vis_gradio(model, feature_id, starting_text=None, batch=None, pos=None):
    if starting_text is None:
        starting_text = model.to_string(all_tokens[batch, 1:pos+1])
    try:
        demos[0].close()
    except:
        pass
    with gr.Blocks() as demo:
        gr.HTML(value=f"Hacky Interactive Neuroscope for gelu-1l")
        # The input elements
        with gr.Row():
            with gr.Column():
                text = gr.Textbox(label="Text", value=starting_text)
                # Precision=0 makes it an int, otherwise it's a float
                # Value sets the initial default value
                feature_index = gr.Number(
                    label="Feature Index", value=feature_id, precision=0
                )
                # # If empty, these two map to None
                max_val = gr.Number(label="Max Value", value=None)
                # min_val = gr.Number(label="Min Value", value=None)
                inputs = [text, feature_index, max_val]
        with gr.Row():
            with gr.Column():
                # The output element
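                out = gr.HTML(label="Neuron Acts", value=basic_feature_vis(encoder, starting_text, feature_id))
        # basic_feature_vis takes the encoder first; gradio only supplies the three inputs.
        for inp in inputs:
            inp.change(lambda text, feature_index, max_val: basic_feature_vis(encoder, text, feature_index, max_val), inputs, out)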
    demo.launch(share=True)
    demos[0] = demo

##### `Function Def` - Inspecting Top Logits

In [None]:
SPACE = "·"
NEWLINE="↩"
TAB = "→"
def process_token(s):
    if isinstance(s, torch.Tensor):
        s = s.item()
    if isinstance(s, np.int64):
        s = s.item()
    if isinstance(s, int):
        s = model.to_string(s)
    s = s.replace(" ", SPACE)
    s = s.replace("\n", NEWLINE+"\n")
    s = s.replace("\t", TAB)
    return s

def process_tokens(l):
    if isinstance(l, str):
        l = model.to_str_tokens(l)
    elif isinstance(l, torch.Tensor) and len(l.shape)>1:
        l = l.squeeze(0)
    return [process_token(s) for s in l]

def process_tokens_index(l):
    if isinstance(l, str):
        l = model.to_str_tokens(l)
    elif isinstance(l, torch.Tensor) and len(l.shape)>1:
        l = l.squeeze(0)
    return [f"{process_token(s)}/{i}" for i,s in enumerate(l)]

def create_vocab_df(logit_vec, make_probs=False, full_vocab=None):
    if full_vocab is None:
        full_vocab = process_tokens(model.to_str_tokens(torch.arange(model.cfg.d_vocab)))
    vocab_df = pd.DataFrame({"token": full_vocab, "logit": utils.to_numpy(logit_vec)})
    if make_probs:
        vocab_df["log_prob"] = utils.to_numpy(logit_vec.log_softmax(dim=-1))
        vocab_df["prob"] = utils.to_numpy(logit_vec.softmax(dim=-1))
    return vocab_df.sort_values("logit", ascending=False)
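`create_vocab_df` sorts by raw logit. Since softmax is monotonic, that is the same order as sorting by probability, and `log_prob` is just every logit shifted by the same constant (the log-sum-exp of the vector). A compact torch-only variant for a quick top-k look, using a hypothetical helper `top_tokens` and a made-up four-token vocabulary:

```python
import torch


def top_tokens(logit_vec: torch.Tensor, vocab: list, k: int = 3):
    """Return the k highest-logit tokens with their probabilities (hypothetical helper)."""
    probs = logit_vec.softmax(dim=-1)
    top = torch.topk(logit_vec, k)
    return [(vocab[i], probs[i].item()) for i in top.indices.tolist()]


vocab_toy = ["·the", "·and", "·I", "·he"]
logits_toy = torch.tensor([1.0, 3.0, 2.0, 0.0])
toy_top = top_tokens(logits_toy, vocab_toy, k=2)
print([tok for tok, _ in toy_top])  # ['·and', '·I']
```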

#### 1.5.2.4 - `Function Defs` - Make Token DataFrame

In [None]:
def list_flatten(nested_list):
    return [x for y in nested_list for x in y]
def make_token_df(tokens, len_prefix=5, len_suffix=1):
    str_tokens = [process_tokens(model.to_str_tokens(t)) for t in tokens]
    unique_token = [[f"{s}/{i}" for i, s in enumerate(str_tok)] for str_tok in str_tokens]

    context = []
    batch = []
    pos = []
    label = []
    for b in range(tokens.shape[0]):
        # context.append([])
        # batch.append([])
        # pos.append([])
        # label.append([])
        for p in range(tokens.shape[1]):
            prefix = "".join(str_tokens[b][max(0, p-len_prefix):p])
            if p==tokens.shape[1]-1:
                suffix = ""
            else:
                # Python slicing already stops at the end of the list, so the last token can appear in a suffix.
                suffix = "".join(str_tokens[b][p+1:p+1+len_suffix])
            current = str_tokens[b][p]
            context.append(f"{prefix}|{current}|{suffix}")
            batch.append(b)
            pos.append(p)
            label.append(f"{b}/{p}")
    # print(len(batch), len(pos), len(context), len(label))
    return pd.DataFrame(dict(
        str_tokens=list_flatten(str_tokens),
        unique_token=list_flatten(unique_token),
        context=context,
        batch=batch,
        pos=pos,
        label=label,
    ))

# 2. `Implementation` - "Standard Model" 😉 Training and Evaluation

## 2.1 - Load the Models

### 2.1.1 - Loading the Transformer

In [None]:
# - [X] TODO: Load the model.
model = Transformer(config=config.transformer)

# - [X] TODO: Understand HookedTransformer from TransformerLens: https://github.com/neelnanda-io/TransformerLens
# model = HookedTransformer.from_pretrained("gelu-1l").to(DTYPES[config["enc_dtype"]])

### 2.1.2 - Loading the Autoencoder

In [None]:
# - [ ] TODO: Reimplement autoencoder loading using custom model.
# auto_encoder_run = "run1" # @param ["run1", "run2"]
# encoder = AutoEncoder.load_from_hf(auto_encoder_run)

## 2.2 - Train the model

In [None]:
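config.training.optimizer = torch.optim.AdamW(
    params=model.parameters(), lr=config.training.transformer_learning_rate)
test_data = train(
    model=model,
    tokenizer=tokenizer,
    data=data,
    config=config.training,
    # Shallow dict of the dataclass fields (dataclasses.asdict would deep-copy the optimizer).
    **{f.name: getattr(config.training, f.name) for f in fields(config.training)},
)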

In [None]:
data

## 2.3 - Inspect training metrics and perform abbreviated evaluation

In [None]:
eval_args = {
    # - [ ] TODO: Figure out what evaluation arguments are needed...
}
evaluate(model, test_data, **eval_args)

# 3. `Implementation` - Mechanistic Interpretability Analysis

### 3.1 - Using the Autoencoder


In [None]:
_ = get_recons_loss(num_batches=5, local_encoder=encoder)

### 3.2 - Rare Features Are All The Same

For each feature, we can compute the frequency at which it is non-zero (per token, averaged across a number of batches) and plot a histogram.

In [None]:
freqs = get_freqs(model=model, all_tokens=all_tokens, autoencoder_config=config.autoencoder, num_batches=50, local_encoder=encoder)

In [None]:
# Add 1e-6.5 so that dead features show up as log_freq -6.5
log_freq = (freqs + 10**-6.5).log10()
px.histogram(utils.to_numpy(log_freq), title="Log Frequency of Features", histnorm='percent')
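`get_freqs` counts, for each dictionary feature, the fraction of tokens on which it is non-zero; features with frequency exactly 0 are dead, and the `10**-6.5` offset makes them land at -6.5 on the log-scale histogram instead of at negative infinity. The same computation on a hand-made batch of codes (toy values; the helper name `activation_frequencies` is hypothetical):

```python
import torch


def activation_frequencies(hidden: torch.Tensor) -> torch.Tensor:
    """Fraction of rows (tokens) on which each column (feature) is non-zero."""
    return (hidden > 0).float().mean(0)


# 4 tokens x 3 features of toy autoencoder codes.
hidden_toy = torch.tensor([
    [0.0, 1.2, 0.0],
    [0.0, 0.0, 0.0],
    [0.0, 0.3, 0.5],
    [0.0, 2.0, 0.0],
])
freqs_toy = activation_frequencies(hidden_toy)   # -> [0.0, 0.75, 0.25]
num_dead_toy = (freqs_toy == 0).float().mean()   # 1 of 3 features never fires
log_freq_toy = (freqs_toy + 10**-6.5).log10()    # the dead feature sits at about -6.5
```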

We see that it's clearly bimodal! Let's define rare features as those with freq < 1e-4 and look at the cosine similarity of each feature with the average rare feature: we see that almost all rare features correspond to this feature!

In [None]:
is_rare = freqs < 1e-4
rare_enc = encoder.W_enc[:, is_rare]
rare_mean = rare_enc.mean(-1)
px.histogram(utils.to_numpy(rare_mean @ encoder.W_enc / rare_mean.norm() / encoder.W_enc.norm(dim=0)), title="Cosine Sim with Ave Rare Feature", color=utils.to_numpy(is_rare), labels={"color": "is_rare", "count": "percent", "value": "cosine_sim"}, marginal="box", histnorm="percent", barmode='overlay')
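The cosine-similarity expression above divides the dot product of each encoder column with the mean rare direction by both norms. A toy check that it matches `torch.nn.functional.cosine_similarity`, using random weights in which three columns are deliberately built as near-copies of one direction to mimic the "rare features are all the same" observation (all sizes and values are made up):

```python
import torch
import torch.nn.functional as F

g = torch.Generator().manual_seed(0)
d_mlp_toy, d_hidden_toy = 8, 6
W_enc_toy = torch.randn(d_mlp_toy, d_hidden_toy, generator=g)
shared = torch.randn(d_mlp_toy, generator=g)
is_rare_toy = torch.tensor([True, True, True, False, False, False])
# Make the "rare" columns near-copies of one shared direction.
W_enc_toy[:, is_rare_toy] = shared[:, None] + 0.01 * torch.randn(d_mlp_toy, int(is_rare_toy.sum()), generator=g)

rare_mean_toy = W_enc_toy[:, is_rare_toy].mean(-1)
# Same expression as the cell above: cosine of each encoder column with the mean rare direction.
cos_sim = rare_mean_toy @ W_enc_toy / rare_mean_toy.norm() / W_enc_toy.norm(dim=0)
# Library equivalent, broadcasting the mean direction across columns.
cos_sim_ref = F.cosine_similarity(W_enc_toy, rare_mean_toy[:, None], dim=0)
```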

### 3.3 - Interpreting A Feature

Let's go and investigate a non-rare feature, feature 7.

In [None]:
feature_id = 7 # @param {type:"number"}
batch_size = 128 # @param {type:"number"}

print(f"Feature freq: {freqs[feature_id].item():.4f}")

Let's run the model on some text and then use the autoencoder to process the MLP activations.

In [None]:
tokens = all_tokens[:batch_size]
_, cache = model.run_with_cache(tokens, stop_at_layer=1, names_filter=utils.get_act_name("post", 0))
mlp_acts = cache[utils.get_act_name("post", 0)]
mlp_acts_flattened = mlp_acts.reshape(-1, config.autoencoder.d_mlp)
loss, x_reconstruct, hidden_acts, l2_loss, l1_loss = encoder(mlp_acts_flattened)
# This is equivalent to:
# hidden_acts = F.relu((mlp_acts_flattened - encoder.b_dec) @ encoder.W_enc + encoder.b_enc)
print("hidden_acts.shape", hidden_acts.shape)

We can now sort and display the top tokens, and we see that this feature activates on text like " and I" (ditto for other connectives and pronouns)! It seems interpretable!

**Aside:** Note on how to read the context column:

A line like "·himself·as·democratic·socialist·and|·he|·favors" means that the preceding 5 tokens are " himself as democratic socialist and", the current token is " he" and the next token is " favors".  · are spaces, ↩ is a newline.

This gets a bit confusing for this feature, since the pipe separators look a lot like a capital I.
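The context column described in the aside can be reproduced for a single position with a small pure-Python function that mirrors how `make_token_df` builds it: up to `len_prefix` preceding tokens, the current token, and up to `len_suffix` following tokens, joined with pipes. The function name `context_string` is hypothetical, and the token list is the aside's example, already processed so that spaces show as `·`.

```python
def context_string(str_tokens, p, len_prefix=5, len_suffix=1):
    """Build 'prefix|current|suffix' for position p (hypothetical helper)."""
    prefix = "".join(str_tokens[max(0, p - len_prefix):p])
    suffix = "".join(str_tokens[p + 1:p + 1 + len_suffix])  # empty at the last position
    return f"{prefix}|{str_tokens[p]}|{suffix}"


toy_str_tokens = ["·himself", "·as", "·democratic", "·socialist", "·and", "·he", "·favors"]
print(context_string(toy_str_tokens, 5))
# ·himself·as·democratic·socialist·and|·he|·favors
```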


In [None]:
token_df = make_token_df(tokens)
token_df["feature"] = utils.to_numpy(hidden_acts[:, feature_id])
token_df.sort_values("feature", ascending=False).head(20).style.background_gradient("coolwarm")

It's easy to misread evidence like the above, so it's useful to take some text and edit it and see how this changes the model's activations. Here's a hacky interactive tool to play around with some text.

In [None]:
model.cfg

In [None]:
s = "The 1899 Kentucky gubernatorial election was held on November 7, 1899. The Republican incumbent, William Bradley, was term-limited. The Democrats chose William Goebel. Republicans nominated William Taylor. Taylor won by a vote of 193,714 to 191,331. The vote was challenged on grounds of voter fraud, but the Board of Elections, though stocked with pro-Goebel members, certified the result. Democratic legislators began investigations, but before their committee could report, Goebel was shot by an unknown assassin (event pictured) on January 30, 1900. Democrats voided enough votes to swing the election to Goebel, Taylor was deposed, and Goebel was sworn into office on January 31. He died on February 3. The lieutenant governor of Kentucky, J. C. W. Beckham, became governor, and battled Taylor in court. Beckham won on appeal, and Taylor fled to Indiana, fearing arrest as an accomplice. The only persons convicted in connection with the killing were later pardoned; the assassin's identity remains a mystery"
t = model.to_tokens(s)
print(t)

In [None]:

starting_text = "Hero and I will head to Samantha and Mark's, then he and she will. Then I or you" # @param {type:"string"}
make_feature_vis_gradio(model=model, feature_id=feature_id, starting_text=starting_text)

# 4. `Conjecture` - Connections with Supersymmetry

## 4.1 - Notation

## 4.2 - Motivation

## 4.3 Formalism

# 5. `Conjecture` - Next Steps

# 6. References