# Load LLM

In [1]:
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

# small model
# model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# big model
model_name = "mistralai/Mistral-7B-Instruct-v0.2"


# Device setup for gpu (I've ran this on colab)
device = torch.device("gpu" if torch.backends.mps.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
)

model.to(device)
model.eval()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/596 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]



special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Downloading (incomplete total...): 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

Loading weights:   0%|          | 0/291 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): MistralRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): MistralRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): MistralRMSNorm((4096,)

## Preparation for Transcoder training
- Get the $MLP_{input}^{l}$ and $MLP_{out}^{l}$
  - Since transcoder models

    $ z^{l} = W_{enc}MLP_{input}^{l} + b_{enc} $

    $ \hat{MLP_{out}^{l}} = W_{dec}MLP_{out}^{l} + b_{dec} $

  

In [5]:
layer_idx = 15 # Choose random layer

mlp_inputs = []
mlp_outputs = []

def get_activations_hook(module, input, output):
    # input is a tuple, we want the first element
    mlp_inputs.append(input[0].detach().cpu())
    mlp_outputs.append(output.detach().cpu())

# Register hook on the specific MLP layer
# For Mistral, the MLP is usually model.model.layers[i].mlp
hook_handle = model.model.layers[layer_idx].mlp.register_forward_hook(get_activations_hook)

# Run some dummy text through the model to collect data
texts = ["The quick brown fox jumps over the lazy dog.", "Mechanistic interpretability is fascinating."]
for text in texts:
    inputs = tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        model(**inputs)

# Clean up the hook
hook_handle.remove()

# Concatenate all collected tokens: shape will be (total_tokens, hidden_size)
X_in = torch.cat(mlp_inputs, dim=1).squeeze(0)
X_out = torch.cat(mlp_outputs, dim=1).squeeze(0)

# Reshape to flatten batch and sequence length
X_in = X_in.view(-1, X_in.shape[-1]).to(device).float()
X_out = X_out.view(-1, X_out.shape[-1]).to(device).float()

print(f"Collected {X_in.shape[0]} tokens of activation data.")

Collected 21 tokens of activation data.


## Define Transcoder Model

$ z^{l} = W_{enc}MLP_{input}^{l} + b_{enc} $

$ \hat{MLP_{out}^{l}} = W_{dec}MLP_{out}^{l} + b_{dec} $


  Just one layer for each encoder and decoder

In [3]:
import torch.nn as nn

class Transcoder(nn.Module):
    def __init__(self, input_dim, dict_size):
        super().__init__()
        self.input_dim = input_dim
        self.dict_size = dict_size

        # Encoder: maps MLP input to sparse features
        self.W_enc = nn.Linear(input_dim, dict_size, bias=False)
        self.b_enc = nn.Parameter(torch.zeros(dict_size))

        # Decoder: maps sparse features to MLP output
        self.W_dec = nn.Linear(dict_size, input_dim, bias=False)
        self.b_dec = nn.Parameter(torch.zeros(input_dim))

        # Tie decoder weights to unit norm (standard practice in dictionary learning I guess)
        # Reference:
        with torch.no_grad():
            self.W_dec.weight.data = F.normalize(self.W_dec.weight.data, dim=0)

    def encode(self, x_in):
        # Shift by decoder bias before encoding
        shifted_x = x_in - self.b_dec
        # Calculate pre-activations and apply ReLU for sparsity
        pre_acts = self.W_enc(shifted_x) + self.b_enc
        return F.relu(pre_acts)

    def decode(self, f):
        return self.W_dec(f) + self.b_dec

    def forward(self, x_in):
        f = self.encode(x_in)
        x_out_pred = self.decode(f)
        return x_out_pred, f

## Train Transcoder

### Loss
- $L_{rec}$=Reconstruction
  - $||MLP_{out}^{l} - \hat{MLP_{out}^{l}}||_{2}^{2}$
- $L_{spars}$ = L1
  - $||\hat{MLP_{out}^{l}}||$
- $\textbf{min } L = L_{rec} + \lambda * L_{spas}$
### Configs
- epochs 100
- batch size 256
- l1_coefficient 1e-3

In [6]:
hidden_size = model.config.hidden_size # 4096 for Mistral-7B
expansion_factor = 4 # How much larger the sparse dictionary is
dict_size = hidden_size * expansion_factor

transcoder = Transcoder(input_dim=hidden_size, dict_size=dict_size).to(device)
optimizer = torch.optim.Adam(transcoder.parameters(), lr=1e-3)

epochs = 100
batch_size = 256
l1_coefficient = 1e-3 # Adjust based on desired sparsity

transcoder.train()
for epoch in range(epochs):
    permutation = torch.randperm(X_in.size()[0])
    epoch_loss = 0

    for i in range(0, X_in.size()[0], batch_size):
        indices = permutation[i:i+batch_size]
        batch_x_in, batch_x_out = X_in[indices], X_out[indices]

        optimizer.zero_grad()

        # Forward pass through transcoder
        x_out_pred, f = transcoder(batch_x_in)

        # Calculate Loss
        mse_loss = F.mse_loss(x_out_pred, batch_x_out)
        l1_loss = f.norm(p=1, dim=-1).mean()
        loss = mse_loss + l1_coefficient * l1_loss

        loss.backward()

        # Normalize decoder weights after gradient step (Dictionary learning standard)
        with torch.no_grad():
            transcoder.W_dec.weight.data = F.normalize(transcoder.W_dec.weight.data, dim=0)

        optimizer.step()
        epoch_loss += loss.item()

    if epoch % 10 == 0:
        print(f"Epoch {epoch} | Loss: {epoch_loss / (X_in.size()[0]/batch_size):.4f}")

Epoch 0 | Loss: 94.6644
Epoch 10 | Loss: 0.0212
Epoch 20 | Loss: 0.0094
Epoch 30 | Loss: 0.0096
Epoch 40 | Loss: 0.0097
Epoch 50 | Loss: 0.0097
Epoch 60 | Loss: 0.0097
Epoch 70 | Loss: 0.0096
Epoch 80 | Loss: 0.0096
Epoch 90 | Loss: 0.0095


## Replacing Transcoder to LLM
Right now we are replacing layer_index = 15 for this example.

In [9]:
class PatchedMLP(nn.Module):
    def __init__(self, transcoder):
        super().__init__()
        self.transcoder = transcoder

    def forward(self, x):
        # 1. Save the original dtype the LLM is using (BFloat16)
        original_dtype = x.dtype

        # 2. Find out what dtype the transcoder is using (Float32)
        transcoder_dtype = next(self.transcoder.parameters()).dtype

        # 3. Cast the input to match the transcoder
        x_cast = x.to(transcoder_dtype)

        # 4. Run the transcoder
        x_out_pred, _ = self.transcoder(x_cast)

        # 5. Cast the output back to the LLM's original dtype before returning
        return x_out_pred.to(original_dtype)

# Now, patch it in exactly as you did before:
transcoder.eval()
model.model.layers[layer_idx].mlp = PatchedMLP(transcoder)

# Your generate function should now run perfectly!
test_inputs = tokenizer("Mechanistic interpretability is", return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model.generate(**test_inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Mechanistic interpretability is a crucial aspect of the work in the fields of the fields of,, the
usa


## Visualize Activatiosn
Here, I just vibe coded for visualizing if certain feature=42 is activating or not.

In [13]:
import torch
from IPython.display import display, HTML

def visualize_feature_activation(text, feature_idx, model, transcoder, tokenizer, layer_idx, device="mps"):
    """
    Visualizes transcoder feature activations with an interactive Jupyter UI.
    Includes proper hook cleanup, mixed-precision handling, and improved CSS.
    """
    inputs = tokenizer(text, return_tensors="pt").to(device)

    # Safely decode tokens
    tokens = [tokenizer.decode([t]) for t in inputs["input_ids"][0]]

    # 1. Safely register and remove the hook using try/finally
    mlp_inputs = []
    def hook_fn(module, input, output):
        mlp_inputs.append(input[0].detach())

    target_module = model.model.layers[layer_idx].mlp
    hook = target_module.register_forward_hook(hook_fn)

    try:
        with torch.no_grad():
            model(**inputs)
    finally:
        # Guarantee the hook is removed even if the model throws an error
        hook.remove()

    # 2. Handle mixed precision (cast LLM activations to Transcoder's dtype)
    X_in = mlp_inputs[0].squeeze(0)
    transcoder_dtype = next(transcoder.parameters()).dtype
    X_in_cast = X_in.to(transcoder_dtype)

    # 3. Get Transcoder features
    with torch.no_grad():
        f = transcoder.encode(X_in_cast)

    feature_acts = f[:, feature_idx].cpu().numpy()

    # 4. Normalize
    max_act = feature_acts.max()
    normalized_acts = feature_acts / max_act if max_act > 0 else feature_acts

    # 5. Build a clean, interactive HTML/CSS UI
    html_string = f"""
    <div style="font-family: monospace; max-width: 800px; padding: 15px; border: 1px solid #ddd; border-radius: 8px; background-color: #fafafa;">
        <h3 style="margin-top: 0; color: #333;">Feature {feature_idx} Activations</h3>
        <p style="font-size: 0.9em; color: #666; margin-bottom: 15px;">Max Activation: <strong>{max_act:.4f}</strong></p>
        <div style="display: flex; flex-wrap: wrap; gap: 4px; line-height: 1.5;">
    """

    for token, act, norm_act in zip(tokens, feature_acts, normalized_acts):
        # Clean special HTML characters
        clean_token = token.replace('<', '&lt;').replace('>', '&gt;')

        # Replace SentencePiece '_' with a visible HTML space for clean rendering
        clean_token = clean_token.replace(' ', '&nbsp;')
        if not clean_token.strip('&nbsp;'):
            clean_token = '&nbsp;' # Handle fully empty/space tokens

        # Styling logic based on activation strength
        if act > 0:
            # Scale alpha up to 0.9 for vivid colors
            bg_color = f"rgba(255, 50, 50, {norm_act * 0.9})"
            # Use white text on dark red backgrounds for readability
            text_color = "white" if norm_act > 0.5 else "black"
            border = "1px solid rgba(255, 50, 50, 0.4)"
        else:
            # Grayed out style for zero activation
            bg_color = "#f0f0f0"
            text_color = "#aaa"
            border = "1px solid #e0e0e0"

        # Create individual token spans with hover popups
        span = f"""
        <span style="
            background-color: {bg_color};
            color: {text_color};
            border: {border};
            padding: 2px 6px;
            border-radius: 4px;
            cursor: help;
            transition: transform 0.1s;
        " title="Activation: {act:.4f}"
        onmouseover="this.style.transform='scale(1.1)'"
        onmouseout="this.style.transform='scale(1)'"
        >
            {clean_token}
        </span>
        """
        html_string += span

    html_string += """
        </div>
    </div>
    """

    # Render in Jupyter
    display(HTML(html_string))

# Run it exactly as before!
sample_text = "The capital of France is Paris, and the capital of Italy is Rome."
visualize_feature_activation(sample_text, feature_idx=42, model=model, transcoder=transcoder, tokenizer=tokenizer, layer_idx=15, device=device)