<a href="https://colab.research.google.com/github/zwimpee/1L-Sparse-Autoencoder/blob/main/dev.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# *📖 Dictionary Learning, 🕵 Interpretability, 🌜 🌛 and Supersymmetry*

---

## Introduction

This notebook documents a personal exploration into intriguing parallels between the fields of **Deep Learning** and **Supersymmetry**. Recent developments in deep learning, specifically in the area of **mechanistic interpretability**, mirror theoretical constructs in **supersymmetry** (SUSY), suggesting that mathematical frameworks used to describe particle physics might enrich our understanding of neural networks.

## Motivation

In the realms of theoretical physics and artificial intelligence, complex systems are often broken down through similar mathematical lenses—albeit with different end goals. This project hypothesizes that the methods used to dissect and interpret neural networks in deep learning could benefit from supersymmetric theories, particularly through the lens of **holoraumy tensors**.

**Supersymmetry** posits a correspondence between two basic classes of particles: bosons and fermions. Holoraumy tensors, as discussed in this [paper](https://arxiv.org/abs/1906.02971), provide a framework for understanding transformations within supersymmetrical systems. These tensors describe the electromagnetic-duality rotations that are crucial for linking various properties of supermultiplets in four-dimensional, N=1 supersymmetry.


**Mechanistic Interpretability** (MI) focuses on understanding the precise mechanisms through which neural networks process inputs to produce outputs. This often involves the manipulation of high-dimensional spaces to extract meaningful patterns—akin to transforming and reducing dimensions in physical theories.

Recently, I've noticed that the advancements in interpretability, particularly those by [Chris Olah](https://scholar.google.com/citations?user=6dskOSUAAAAJ&hl=en) at [Anthropic](https://www.anthropic.com/research) on [Transformer Circuits](https://transformer-circuits.pub/), closely parallel the concepts I've worked with on the paper linked above as part of research group led by [Dr. Sylvester James Gates (Jim)](https://twitter.com/dr_jimgates) at Brown University.



These observations suggest that the mathematical principles of supersymmetry and quantum field theory could directly apply to transformer-based neural networks, potentially deepening our understanding of model mechanics and guiding models to align more closely with human interests through techniques like ablation and masking.

## Hypothesis and Objectives

**Hypothesis**: The electromagnetic-duality rotations and the structures of holoraumy tensors in supersymmetry can analogously describe transformations within deep learning models, particularly in how neural networks achieve dimensionality reduction and feature extraction.

**Objectives**:
1. **Define and Formalize**: Construct a detailed mathematical analogy between the operations in neural networks and supersymmetric transformations, focusing on the role of holoraumy tensors in describing system dynamics.
2. **Analogize and Model**: Develop a model to demonstrate how supersymmetric principles, particularly those involving dimensional transformations and symmetry, could theoretically underpin neural network operations.

## Definitions and Theoretical Background
### 1. Mechanistic Interpretability in Deep Learning
- #### 1.1 **Dictionary Learning**
A recent technique termed **dictionary learning** by the team at Anthropic involves dimensionality manipulation where inputs are projected into a high-dimensional space to capture latent features:
$$
  \mathbf{z} = f_{\text{encode}}(\mathbf{x}; \theta_e), \\
$$
$$
\mathbf{y} = f_{\text{decode}}(\mathbf{z}; \theta_d)
$$
Here, $\mathbf{x}$ represents the input, $\mathbf{z}$ the encoded latent space, and $\mathbf{y}$ the output, with $\theta_e$ and $\theta_d$ as the encoder and decoder parameters, respectively.

### 2. Supersymmetry and Holoraumy Tensors
- #### 2.1 - Supersymmetry
Supersymmetry posits a fundamental symmetry between bosons and fermions, often involving dimensional transformations to relate these particle types through algebraic structures:
$$
Q | \text{Boson} \rangle = | \text{Fermion} \rangle, \quad Q | \text{Fermion} \rangle = | \text{Boson} \rangle
$$
Where \(Q\) is the supersymmetry generator.


## Exploratory Framework

This research will unfold through several structured phases:
1. **Review of Supersymmetric Theories**: Summarize key supersymmetric theories, especially focusing on the role and computation of holoraumy tensors in 4D, N=1 supersymmetry as detailed in the aforementioned paper.
2. **Application to Neural Networks**: Propose and formulate how these supersymmetric concepts could map onto neural network operations, particularly in interpretability and model simplification.
3. **Empirical Analysis**: Develop simulations or theoretical models to test these analogies, assessing their utility in providing new insights into neural network behavior.
4. **Validation**: Critically evaluate whether these supersymmetric approaches can be effectively integrated into current deep learning frameworks, offering tangible benefits.

#### Summary and Next Steps

This notebook aims to rigorously explore the potential foundational link between deep learning and supersymmetry, adhering to strict scientific principles throughout. Our objective is to substantiate or refute our hypothesis with robust theoretical backing and empirical evidence.

The initial step is to replicate the results from the [paper](https://transformer-circuits.pub/2023/monosemantic-features) by Bricken et al. from Anthropic.

# 1. `Setup`

## 1.1 - `Dependencies`

In [1]:
!python3 -V

Python 3.10.12


In [2]:
!pip install tiktoken
!pip install transformers
!pip install datasets

Collecting transformer_lens
  Downloading transformer_lens-1.17.0-py3-none-any.whl (137 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m137.1/137.1 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting accelerate>=0.23.0 (from transformer_lens)
  Downloading accelerate-0.29.3-py3-none-any.whl (297 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m297.6/297.6 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting beartype<0.15.0,>=0.14.1 (from transformer_lens)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m739.7/739.7 kB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting better-abc<0.0.4,>=0.0.3 (from transformer_lens)
  Downloading better_abc-0.0.3-py3-none-any.whl (3.5 kB)
Collecting datasets>=2.7.1 (from transformer_lens)
  Downloading datasets-2.19.0-py3-none-any.whl (542 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m542.0/5

In [3]:
!pip install wandb -qU

In [4]:
# Log in to your W&B account
import wandb

In [5]:
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

## 1.2 - `Imports` and `Initialization`

In [33]:
import datasets
import gradio as gr
import json
import numpy as np
import os
import plotly.express as px
import pprint
import pandas as pd
import requests
import tiktoken
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass, field, fields
from datasets import load_dataset
from huggingface_hub import HfApi
from IPython.display import HTML
from functools import partial
from tqdm.auto import tqdm
from typing import Any, Dict, Union


seed = 42069
DTYPES = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset_name = "tinyshakespeare"
data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
data_cache_dir = "data"

In [None]:
def download_file(url: str, fname: str, chunk_size=1024):
    """Helper function to download a file from a given url"""
    resp = requests.get(url, stream=True)
    total = int(resp.headers.get("content-length", 0))
    with open(fname, "wb") as file, tqdm(
        desc=fname,
        total=total,
        unit="iB",
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for data in resp.iter_content(chunk_size=chunk_size):
            size = file.write(data)
            bar.update(size)

def download_tinyshakespeare(
    data_url=data_url,
    cache_dir=data_cache_dir
) -> tuple:
    """Downloads the TinyShakespeare dataset to DATA_CACHE_DIR"""
    os.makedirs(cache_dir, exist_ok=True)

    data_filename = os.path.join(cache_dir, "tiny_shakespeare.txt")
    if not os.path.exists(data_filename):
        print(f"Downloading {data_url} to {data_filename}...")
        download_file(data_url, data_filename)
    else:
        print(f"{data_filename} already exists, skipping download...")

    return data_filename

def tokenize_tinyshakespeare(tokenizer, raw_data_filepath, cache_dir=data_cache_dir):
    encode = lambda s: tokenizer.encode(s, allowed_special={'<|endoftext|>'})
    eot = tokenizer._special_tokens['<|endoftext|>'] # end of text token
    data_filename = os.path.join(cache_dir, "tiny_shakespeare.txt")
    text = open(data_filename, 'r').read()

    # let's treat every person's statement in the dialog as a separate document
    text = "<|endoftext|>" + text
    text = text.replace('\n\n', '\n\n<|endoftext|>')

    # encode the text
    tokens = encode(text)
    tokens_np = np.array(tokens, dtype=np.int32)
    return tokens_np

In [None]:
@dataclass
class ModelConfig:
    name: str
    max_length: int
    vocabulary_size: int
    hidden_dim: int
    batch_size: int
    learning_rate: float
    dtype: Any
    device: Any
    seed: Any

@dataclass
class TransformerConfig(ModelConfig):
    num_layers: int
    num_heads: int
    embed_dim: int
    nonlinearity: Any


@dataclass
class DataConfig:
    name: str
    data: np.ndarray

@dataclass
class TrainingConfig:
    train_steps: int
    eval_steps: int
    train_steps_per_eval: int
    optimizer: Any
    loss: Any
    batch_size: int
    device: Any


@dataclass
class Config:
    model: Any
    data: Any
    training: Any


# - [ ] TODO: Finish implementing Transformer correctly.
class Attention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.scale = embed_dim ** -0.5  # Scaling factor for attention scores
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query, key, value, mask=None):
        Q = self.query(query)
        K = self.key(key)
        V = self.value(value)

        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attention = self.softmax(scores)
        output = torch.matmul(attention, V)
        return output

class MLP(nn.Module):
    def __init__(self, config: TransformerConfig, **kwargs):
        super().__init__()
        self.config = config
        if kwargs.get("type", "attn")=="atn":
            self.fc = nn.Linear(config.hidden_dim, config.hidden_dim)
        else:
            self.fc = nn.Linear(config.hidden_dim, config.vocabulary_size)

    def forward(self, x):
        return self.fc(x)

class MHSA(nn.Module):
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.heads = nn.ModuleList([
            Attention(config.embed_dim, config.num_heads) for _ in range(config.num_heads)
        ])
        self.linear = nn.Linear(config.num_heads * config.embed_dim, config.embed_dim)

    def forward(self, x, mask=None):
        head_outputs = [head(x, x, x, mask) for head in self.heads]
        concatenated = torch.cat(head_outputs, dim=-1)
        return self.linear(concatenated)

class TransformerBlock(nn.Module):
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.attention = MHSA(config)
        self.mlp = MLP(config)
        self.norm1 = nn.LayerNorm(config.embed_dim)
        self.norm2 = nn.LayerNorm(config.embed_dim)

    def forward(self, x, mask=None):
        x2 = self.norm1(x)
        x = x + self.attention(x2, mask)
        x2 = self.norm2(x)
        x = x + self.mlp(x2)
        return x

class Transformer(nn.Module):
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(num_embeddings=config.vocabulary_size, embedding_dim=config.embed_dim)
        self.positional_encodings = nn.Parameter(torch.zeros(1, config.max_length, config.embed_dim))
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_layers)])
        self.final_layer = nn.Linear(config.embed_dim, config.vocabulary_size)

    def forward(self, x, mask=None):
        x = self.embedding(x) + self.positional_encodings[:, :x.size(1), :]
        for layer in self.layers:
            x = layer(x, mask)
        return self.final_layer(x)

In [30]:
# Data and Tokenizer Configurations
tokenizer = tiktoken.get_encoding("gpt2")
raw_data_filepath = download_tinyshakespeare()
data = tokenize_tinyshakespeare(tokenizer, raw_data_filepath)
# tokenized_data = utils.tokenize_and_concatenate(data, tokenizer, max_length=512)
# tokenized_data = tokenized_data.shuffle(42)
# all_tokens = tokenized_data["tokens"]

# Training Configuration
batch_size = 8
# transformer_batch_size = 1
# autoencoder_batch_size = 4096  # tokens?
training_steps = 1e5
eval_steps = 5e3
train_steps_per_eval = training_steps // 10
learning_rate = 1e-5
# transformer_learning_rate = 1e-5
# autoencoder_learning_rate = 1e-4

# Model Configurations
# Transformer Configuration
max_length = 2**7  # 2**7 = 128
vocabulary_size = 50257 # len(tokenizer)  # <s>I think this is the GPT-2 vocab size, or close to it</s>
embed_dim = max_length * 4  # Setting the embedding dimension to just be 4x the max length, seems like a relatively reasonable starting point, but I could be wrong...
hidden_dim = 1024  # Can't remember if this is actually needed...
num_layers = 2  # Trying to reproduce results from https://transformer-circuits.pub/2023/monosemantic-features/index.html
block_size = 2**5  # 32, i.e. 4 heads over 128 token sequence
num_heads = max_length % block_size  # I believe this will make it so each head (except the last one) computes the attention matrix across blocks of 8 tokens
transformer_config = {
    "name": "transformer",
    "max_length": max_length,
    "vocabulary_size": vocabulary_size,
    "hidden_dim": hidden_dim,
    # "batch_size": transformer_batch_size,
    # "learning_rate": transformer_learning_rate,
    "batch_size": batch_size,
    "learning_rate": learning_rate,
    "dtype": DTYPES["bfloat16"],
    "device": device,
    "seed": seed,
    "embed_dim": embed_dim,
    "num_layers": num_layers,
    "num_heads": num_heads,
    "nonlinearity": nn.ReLU,  # Will change this to gelu/other later.
}
transformer_config = TransformerConfig(**transformer_config)

# Configuration initialization.
config = Config(
    model=transformer_config,
    data=DataConfig(
        **{
            "name": dataset_name,
            "data": data
          }
    ),
    training=TrainingConfig(
        **{
            "train_steps": training_steps,
            "eval_steps": eval_steps,
            "train_steps_per_eval": train_steps_per_eval,
            "learning_rate": learning_rate,
            "batch_size": batch_size,
            "optimizer": None,
            "loss": F.binary_cross_entropy_with_logits,
            "device": device
          }
    )
)

data/tiny_shakespeare.txt already exists, skipping download...


In [None]:
model = Transformer(config=config.model)
config.training.optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=config.training.learning_rate
)

## 1.3 - `Function Defintions`


#### `Function Def` - Split the Dataset

In [None]:
def train_test_split(data, train_pct=0.8, test_pct=0.1):
    """
    Splits the data into training, validation, and testing datasets.

    Parameters:
        data (array-like): The dataset to split.
        train_pct (float): Percentage of the data to use for training.
        test_pct (float): Percentage of the data to use for testing.

    Returns:
        tuple: (train_data, val_data, test_data)
    """
    if not isinstance(data, np.ndarray):
        data = np.array(data)

    total_data_len = len(data)
    train_end = int(total_data_len * train_pct)
    test_end = int(total_data_len * (train_pct + test_pct))

    np.random.shuffle(data)  # Shuffle the data to ensure random splitting

    train_data = data[:train_end]
    val_data = data[train_end:test_end]
    test_data = data[test_end:]

    return train_data, val_data, test_data

#### `Function Def` - Train the Language Model

In [10]:
def train(
  model: Transformer,
  tokenizer: GPT2TokenizerFast,
  autoencoder: AutoEncoder,
  data: np.ndarray,
  config: TrainingConfig,
  **kwargs: Dict
):
    device = config.device if config else kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Data split
    train_data, val_data, test_data = train_test_split(data)

    # Configuration parameters
    eval_every_n_batches = config.train_steps_per_eval or kwargs.get("eval_every_n_batches", 10000)
    max_steps = config.train_steps or kwargs.get("num_steps", 100000)
    optimizer = config.optimizer or kwargs.get("optimizer", torch.optim.Adam(model.parameters(), lr=config.transformer_learning_rate if config else 1e-5))
    loss_fn = config.loss or kwargs.get("loss_fn", torch.nn.CrossEntropyLoss())

    model.train()
    for i in tqdm(range(int(max_steps)), desc="Running training loop..."):
        batch_idx = i % len(train_data)
        train_example = train_data[batch_idx]

        # model_input = tokenizer.encode(train_example["text"], padding="max_length", truncate=True, max_length=model.config.max_length, return_tensors="pt").to(device)
        model_input
        labels = model_input.clone()  # Assuming a model that outputs logits with the same shape as input
        model_output = model(model_input)

        loss = loss_fn(model_output, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if (i + 1) % eval_every_n_batches == 0:
            model.eval()
            validation_loss = []
            for val_example in val_data:
                val_input = tokenizer.encode(val_example["text"], padding="max_length", truncate=True, max_length=512, return_tensors="pt").to(device)
                with torch.no_grad():
                    val_output = model(val_input)
                    val_loss = loss_fn(val_output, val_input.clone())
                    validation_loss.append(val_loss.item())
            avg_val_loss = sum(validation_loss) / len(validation_loss)
            print(f"Validation Loss at step {i+1}: {avg_val_loss}")
            model.train()

        if i >= max_steps:
            break

    # Testing loop or additional analysis could be added here
    return model, {'training_loss': loss.item(), 'validation_loss': avg_val_loss}, test_data


#### `Function Def` - Evaluate the Language Model

In [11]:
@torch.no_grad()
def evaluate(model, data):
    ...

#### `Function Defs` - Get Reconstruction Loss

In [12]:
def replacement_hook(mlp_post, hook, encoder):
    mlp_post_reconstr = encoder(mlp_post)[1]
    return mlp_post_reconstr

def mean_ablate_hook(mlp_post, hook):
    mlp_post[:] = mlp_post.mean([0, 1])
    return mlp_post

def zero_ablate_hook(mlp_post, hook):
    mlp_post[:] = 0.
    return mlp_post

@torch.no_grad()
def get_recons_loss(model, all_tokens, num_batches, local_encoder):
    loss_list = []
    for i in range(num_batches):
        tokens = all_tokens[torch.randperm(len(all_tokens))[model.config.batch_size]]
        loss = model(tokens, return_type="loss")
        recons_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", 0), partial(replacement_hook, encoder=local_encoder))])
        # mean_abl_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", 0), mean_ablate_hook)])
        zero_abl_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", 0), zero_ablate_hook)])
        loss_list.append((loss, recons_loss, zero_abl_loss))
    losses = torch.tensor(loss_list)
    loss, recons_loss, zero_abl_loss = losses.mean(0).tolist()

    print(f"loss: {loss:.4f}, recons_loss: {recons_loss:.4f}, zero_abl_loss: {zero_abl_loss:.4f}")
    score = ((zero_abl_loss - recons_loss)/(zero_abl_loss - loss))
    print(f"Reconstruction Score: {score:.2%}")
    # print(f"{((zero_abl_loss - mean_abl_loss)/(zero_abl_loss - loss)).item():.2%}")
    return score, loss, recons_loss, zero_abl_loss

#### `Function Defs` - Get Frequencies

In [13]:
# Frequency
@torch.no_grad()
def get_freqs(model, all_tokens, num_batches, local_encoder, config: AutoEncoderConfig):
    act_freq_scores = torch.zeros(local_encoder.d_hidden, dtype=torch.float32).cuda()
    total = 0
    for i in tqdm(num_batches):
        tokens = all_tokens[torch.randperm(len(all_tokens))[:config.batch_size]]

        _, cache = model.run_with_cache(tokens, stop_at_layer=1, names_filter=utils.get_act_name("post", 0))
        mlp_acts = cache[utils.get_act_name("post", 0)]
        mlp_acts = mlp_acts.reshape(-1, d_mlp)

        hidden = local_encoder(mlp_acts)[2]

        act_freq_scores += (hidden > 0).sum(0)
        total+=hidden.shape[0]
    act_freq_scores /= total
    num_dead = (act_freq_scores==0).float().mean()
    print("Num dead", num_dead)
    return act_freq_scores

#### `Functions Defs` - Visualise Feature Utils

In [14]:
from html import escape
import colorsys

from IPython.display import display

SPACE = "·"
NEWLINE="↩"
TAB = "→"

def create_html(strings, values, max_value=None, saturation=0.5, allow_different_length=False, return_string=False):
    # escape strings to deal with tabs, newlines, etc.
    escaped_strings = [escape(s, quote=True) for s in strings]
    processed_strings = [
        s.replace("\n", f"{NEWLINE}<br/>").replace("\t", f"{TAB}&emsp;").replace(" ", "&nbsp;")
        for s in escaped_strings
    ]

    if isinstance(values, torch.Tensor) and len(values.shape)>1:
        values = values.flatten().tolist()

    if not allow_different_length:
        assert len(processed_strings) == len(values)

    # scale values
    if max_value is None:
        max_value = max(max(values), -min(values))+1e-3
    scaled_values = [v / max_value * saturation for v in values]

    # create html
    html = ""
    for i, s in enumerate(processed_strings):
        if i<len(scaled_values):
            v = scaled_values[i]
        else:
            v = 0
        if v < 0:
            hue = 0  # hue for red in HSV
        else:
            hue = 0.66  # hue for blue in HSV
        rgb_color = colorsys.hsv_to_rgb(
            hue, v, 1
        )  # hsv color with hue 0.66 (blue), saturation as v, value 1
        hex_color = "#%02x%02x%02x" % (
            int(rgb_color[0] * 255),
            int(rgb_color[1] * 255),
            int(rgb_color[2] * 255),
        )
        html += f'<span style="background-color: {hex_color}; border: 1px solid lightgray; font-size: 16px; border-radius: 3px;">{s}</span>'
    if return_string:
        return html
    else:
        display(HTML(html))

def basic_feature_vis(encoder, text, feature_index, max_val=0):
    feature_in = encoder.W_enc[:, feature_index]
    feature_bias = encoder.b_enc[feature_index]
    _, cache = model.run_with_cache(text, stop_at_layer=1, names_filter=utils.get_act_name("post", 0))
    mlp_acts = cache[utils.get_act_name("post", 0)][0]
    feature_acts = F.relu((mlp_acts - encoder.b_dec) @ feature_in + feature_bias)
    if max_val==0:
        max_val = max(1e-7, feature_acts.max().item())
        # print(max_val)
    # if min_val==0:
    #     min_val = min(-1e-7, feature_acts.min().item())
    return basic_token_vis_make_str(text, feature_acts, max_val)
def basic_token_vis_make_str(model, strings, values, max_val=None):
    if not isinstance(strings, list):
        strings = model.to_str_tokens(strings)
    values = utils.to_numpy(values)
    if max_val is None:
        max_val = values.max()
    # if min_val is None:
    #     min_val = values.min()
    header_string = f"<h4>Max Range <b>{values.max():.4f}</b> Min Range: <b>{values.min():.4f}</b></h4>"
    header_string += f"<h4>Set Max Range <b>{max_val:.4f}</b></h4>"
    # values[values>0] = values[values>0]/ma|x_val
    # values[values<0] = values[values<0]/abs(min_val)
    body_string = create_html(strings, values, max_value=max_val, return_string=True)
    return header_string + body_string
# display(HTML(basic_token_vis_make_str(tokens[0, :10], mlp_acts[0, :10, 7], 0.1)))
# # %%
# The `with gr.Blocks() as demo:` syntax just creates a variable called demo containing all these components
import gradio as gr
try:
    demos[0].close()
except:
    pass
demos = [None]
def make_feature_vis_gradio(model, feature_id, starting_text=None, batch=None, pos=None):
    if starting_text is None:
        starting_text = model.to_string(all_tokens[batch, 1:pos+1])
    try:
        demos[0].close()
    except:
        pass
    with gr.Blocks() as demo:
        gr.HTML(value=f"Hacky Interactive Neuroscope for gelu-1l")
        # The input elements
        with gr.Row():
            with gr.Column():
                text = gr.Textbox(label="Text", value=starting_text)
                # Precision=0 makes it an int, otherwise it's a float
                # Value sets the initial default value
                feature_index = gr.Number(
                    label="Feature Index", value=feature_id, precision=0
                )
                # # If empty, these two map to None
                max_val = gr.Number(label="Max Value", value=None)
                # min_val = gr.Number(label="Min Value", value=None)
                inputs = [text, feature_index, max_val]
        with gr.Row():
            with gr.Column():
                # The output element
                out = gr.HTML(label="Neuron Acts", value=basic_feature_vis(starting_text, feature_id))
        for inp in inputs:
            inp.change(basic_feature_vis, inputs, out)
    demo.launch(share=True)
    demos[0] = demo

##### `Function Def` - Inspecting Top Logits

In [15]:
SPACE = "·"
NEWLINE="↩"
TAB = "→"
def process_token(s):
    if isinstance(s, torch.Tensor):
        s = s.item()
    if isinstance(s, np.int64):
        s = s.item()
    if isinstance(s, int):
        s = model.to_string(s)
    s = s.replace(" ", SPACE)
    s = s.replace("\n", NEWLINE+"\n")
    s = s.replace("\t", TAB)
    return s

def process_tokens(l):
    if isinstance(l, str):
        l = model.to_str_tokens(l)
    elif isinstance(l, torch.Tensor) and len(l.shape)>1:
        l = l.squeeze(0)
    return [process_token(s) for s in l]

def process_tokens_index(l):
    if isinstance(l, str):
        l = model.to_str_tokens(l)
    elif isinstance(l, torch.Tensor) and len(l.shape)>1:
        l = l.squeeze(0)
    return [f"{process_token(s)}/{i}" for i,s in enumerate(l)]

def create_vocab_df(logit_vec, make_probs=False, full_vocab=None):
    if full_vocab is None:
        full_vocab = process_tokens(model.to_str_tokens(torch.arange(model.cfg.d_vocab)))
    vocab_df = pd.DataFrame({"token": full_vocab, "logit": utils.to_numpy(logit_vec)})
    if make_probs:
        vocab_df["log_prob"] = utils.to_numpy(logit_vec.log_softmax(dim=-1))
        vocab_df["prob"] = utils.to_numpy(logit_vec.softmax(dim=-1))
    return vocab_df.sort_values("logit", ascending=False)

#### `Function Defs` - Make Token DataFrame

In [16]:
def list_flatten(nested_list):
    return [x for y in nested_list for x in y]
def make_token_df(tokens, len_prefix=5, len_suffix=1):
    str_tokens = [process_tokens(model.to_str_tokens(t)) for t in tokens]
    unique_token = [[f"{s}/{i}" for i, s in enumerate(str_tok)] for str_tok in str_tokens]

    context = []
    batch = []
    pos = []
    label = []
    for b in range(tokens.shape[0]):
        # context.append([])
        # batch.append([])
        # pos.append([])
        # label.append([])
        for p in range(tokens.shape[1]):
            prefix = "".join(str_tokens[b][max(0, p-len_prefix):p])
            if p==tokens.shape[1]-1:
                suffix = ""
            else:
                suffix = "".join(str_tokens[b][p+1:min(tokens.shape[1]-1, p+1+len_suffix)])
            current = str_tokens[b][p]
            context.append(f"{prefix}|{current}|{suffix}")
            batch.append(b)
            pos.append(p)
            label.append(f"{b}/{p}")
    # print(len(batch), len(pos), len(context), len(label))
    return pd.DataFrame(dict(
        str_tokens=list_flatten(str_tokens),
        unique_token=list_flatten(unique_token),
        context=context,
        batch=batch,
        pos=pos,
        label=label,
    ))

# 2. `Implementation`

## 2.1 - `Training`

In [27]:
transformer, training_metrics, test_data = train(
    model=model,
    tokenizer=tokenizer,
    data=data,
    config=config.training
)

TypeError: unhashable type: 'slice'

## 2.2 - `Evaluation`

In [None]:
eval_args = {
    # - [ ] TODO: Figure out what evaluation arguments are needed...
}
evaluate(model, test_data, **eval_args)

# 3. `Mechanistic Interpretability Analysis`

### 3.1 - Using the Autoencoder


In [None]:
_ = get_recons_loss(num_batches=5, local_encoder=autoencoder)

### 3.2 - Rare Features Are All The Same

For each feature we can get the frequency at which it's non-zero (per token, averaged across a bunch of batches), and plot a histogram

In [None]:
freqs = get_freqs(num_batches = 50, local_encoder = autoencoder)

In [None]:
# Add 1e-6.5 so that dead features show up as log_freq -6.5
log_freq = (freqs + 10**-6.5).log10()
px.histogram(utils.to_numpy(log_freq), title="Log Frequency of Features", histnorm='percent')

We see that it's clearly bimodal! Let's define rare features as those with freq < 1e-4, and look at the cosine sim of each feature with the average rare feature - we see that almost all rare features correspond to this feature!

In [None]:
is_rare = freqs < 1e-4
rare_enc = autoencoder.W_enc[:, is_rare]
rare_mean = rare_enc.mean(-1)
px.histogram(utils.to_numpy(rare_mean @ autoencoder.W_enc / rare_mean.norm() / autoencoder.W_enc.norm(dim=0)), title="Cosine Sim with Ave Rare Feature", color=utils.to_numpy(is_rare), labels={"color": "is_rare", "count": "percent", "value": "cosine_sim"}, marginal="box", histnorm="percent", barmode='overlay')

### 3.3 - Interpreting A Feature

Let's go and investigate a non rare feature, feature 7

In [None]:
feature_id = 7 # @param {type:"number"}
batch_size = 128 # @param {type:"number"}

print(f"Feature freq: {freqs[7].item():.4f}")

Let's run the model on some text and then use the autoencoder to process the MLP activations

In [None]:
tokens = all_tokens[:batch_size]
_, cache = model.run_with_cache(tokens, stop_at_layer=1, names_filter=utils.get_act_name("post", 0))
mlp_acts = cache[utils.get_act_name("post", 0)]
mlp_acts_flattened = mlp_acts.reshape(-1, config.autoencoder.d_mlp)
loss, x_reconstruct, hidden_acts, l2_loss, l1_loss = autoencoder(mlp_acts_flattened)
# This is equivalent to:
# hidden_acts = F.relu((mlp_acts_flattened - encoder.b_dec) @ encoder.W_enc + encoder.b_enc)
print("hidden_acts.shape", hidden_acts.shape)

We can now sort and display the top tokens, and we see that this feature activates on text like " and I" (ditto for other connectives and pronouns)! It seems interpretable!

**Aside:** Note on how to read the context column:

A line like "·himself·as·democratic·socialist·and|·he|·favors" means that the preceding 5 tokens are " himself as democratic socialist and", the current token is " he" and the next token is " favors".  · are spaces, ↩ is a newline.

This gets a bit confusing for this feature, since the pipe separators look a lot like a capital I


In [None]:
token_df = make_token_df(tokens)
token_df["feature"] = utils.to_numpy(hidden_acts[:, feature_id])
token_df.sort_values("feature", ascending=False).head(20).style.background_gradient("coolwarm")

It's easy to misread evidence like the above, so it's useful to take some text and edit it and see how this changes the model's activations. Here's a hacky interactive tool to play around with some text.

In [None]:
model.cfg

In [None]:
s = "The 1899 Kentucky gubernatorial election was held on November 7, 1899. The Republican incumbent, William Bradley, was term-limited. The Democrats chose William Goebel. Republicans nominated William Taylor. Taylor won by a vote of 193,714 to 191,331. The vote was challenged on grounds of voter fraud, but the Board of Elections, though stocked with pro-Goebel members, certified the result. Democratic legislators began investigations, but before their committee could report, Goebel was shot by an unknown assassin (event pictured) on January 30, 1900. Democrats voided enough votes to swing the election to Goebel, Taylor was deposed, and Goebel was sworn into office on January 31. He died on February 3. The lieutenant governor of Kentucky, J. C. W. Beckham, became governor, and battled Taylor in court. Beckham won on appeal, and Taylor fled to Indiana, fearing arrest as an accomplice. The only persons convicted in connection with the killing were later pardoned; the assassin's identity remains a mystery"
t = model.to_tokens(s)
print(t)

In [None]:

starting_text = "Hero and I will head to Samantha and Mark's, then he and she will. Then I or you" # @param {type:"string"}
make_feature_vis_gradio(feature_id, starting_text)

# 4. `Supersymmetry`

## 4.1 - Notation

## 4.2 - Motivation

## 4.3 Formalism

# 5. `Next Steps`

# 6. `References`