In [1]:
import warnings
# Suppress every warning category for the remainder of the kernel session.
warnings.filterwarnings('ignore')

import os
# Move the working directory two levels up from wherever the kernel was started.
# NOTE(review): the path is relative, so re-running this cell climbs two more levels
# each time; run it exactly once per kernel session.
os.chdir("../..")

In [2]:
from dataclasses import dataclass
from typing import Callable, Literal

import einops
import numpy as np
import torch
from tqdm import tqdm
from torch import nn, Tensor
from torch.nn import functional as F

from mech_interp import utils

In [3]:
# Choose the compute backend in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f"Using device: {device}")

Using device: mps


# 1. Toy model setup

Models may or may not have a privileged basis. A privileged basis refers to a special or **preferred coordinate system or direction** in the representation space of a model — usually one defined by the model’s architecture, such as individual neurons in a hidden layer.

Models without a privileged basis are elegant, and can be an interesting analogue for certain neural network representations which don't have a privileged basis – word embeddings, or the transformer residual stream. But of primary interest is the understanding of neural network representations that have neurons which do impose a privileged basis, such as transformer MLP layers or convolutional network neurons.

The simplest toy model with a privileged basis is a non-privileged basis model with an activation function, which allows for the representation of hidden layers with neurons, such as the transformer MLP layer. Based on the previous model, it can be represented by adding a ReLU to the hidden layer:

\begin{align*}
h &= \text{ReLU}(Wx) \\
x' &= \text{ReLU}(W^T h + b)
\end{align*}


In [4]:
@dataclass
class ToyModelConfig:
    """Shape and data-generation settings shared by the toy models in this notebook."""
    # `n_inst` models are optimized at once in a single training loop, so they can act as a batch dimension
    # for learnable params: each weight/bias will have `n_inst` stacked instances along the zeroth dimension
    # allowing for efficient analysis of sparsity and importance curves
    n_inst: int
    # number of input (and reconstructed output) features per instance
    n_features: int = 5
    # width of the hidden layer, i.e. the number of neurons per instance
    d_hidden: int = 2
    # NOTE(review): the three fields below are not read by any code visible in this notebook;
    # presumably reserved for correlated / anticorrelated feature generation and for choosing
    # the feature-magnitude distribution — confirm before relying on them
    n_correlated_pairs: int = 0
    n_anticorrelated_pairs: int = 0
    feat_mag_distn: Literal['unif', 'normal'] = 'unif'

In [5]:
def linear_lr(step, steps):
    """Linearly decaying multiplier: 1.0 at step 0, falling by `1 / steps` per step."""
    fraction_done = step / steps
    return 1 - fraction_done


def constant_lr(*_):
    """Flat schedule: ignores every argument and always yields a multiplier of 1.0."""
    return 1.0


def cosine_decay_lr(step, steps):
    """
    Quarter-cosine learning-rate multiplier.

    Follows `cos(pi/2 * step / (steps - 1))`, i.e. 1.0 at step 0 and 0.0 at the
    final step (`step == steps - 1`).

    Args:
        step: index of the current optimisation step (0-based).
        steps: total number of optimisation steps.

    Returns:
        The multiplier to apply to the base learning rate.
    """
    if steps <= 1:
        # A one-step schedule has nothing to decay over, and `steps - 1` would be a
        # zero denominator; keep the base learning rate unchanged.
        return 1.0
    return np.cos(0.5 * np.pi * step / (steps - 1))

In [6]:
class ToyModel(nn.Module):
    """
    A toy model for demonstrating the setup of the paper "Toy Models of Superposition".

    `cfg.n_inst` independent copies are trained at once: every parameter carries the
    instance index as its leading dimension.

    Important concepts:
        `feature_probability`: used to generate training data. Default is `None`, which means
            p = 1 (no sparsity).
        `importance`: used in the loss function. Default is `None`, which results in uniform
            importance across all features.

    Both settings may be given as a Python number, a (nested) list or a Tensor, as long as
    the value broadcasts to `(cfg.n_inst, cfg.n_features)`.
    """
    W: Tensor
    b_final: Tensor

    def __init__(
        self,
        cfg: ToyModelConfig,
        feature_probability: Tensor = None,
        importance: Tensor = None,
        device=device,
    ):
        """
        Args:
            cfg: shape configuration (`n_inst`, `n_features`, `d_hidden`).
            feature_probability: probability that a feature is non-zero in a sample;
                `None` means 1.0 (every feature always active).
            importance: per-feature weight in the loss; `None` means 1.0 everywhere.
            device: device on which parameters and settings are placed.
        """
        super().__init__()

        self.cfg = cfg
        self.feature_probability = self._expand_setting(feature_probability, cfg, device)
        self.importance = self._expand_setting(importance, cfg, device)

        # The same matrix W is used for the down-projection and (transposed) the up-projection
        self.W = nn.Parameter(nn.init.xavier_normal_(torch.empty((cfg.n_inst, cfg.d_hidden, cfg.n_features))))
        self.b_final = nn.Parameter(torch.zeros((cfg.n_inst, cfg.n_features)))
        self.to(device)

    @staticmethod
    def _expand_setting(value, cfg: ToyModelConfig, device) -> Tensor:
        """Turn a per-feature setting into an `(n_inst, n_features)` tensor on `device`."""
        if value is None:
            # Documented default: probability 1 / uniform importance.
            # (`torch.tensor(None)` would raise instead of producing a default.)
            value = 1.0
        if not isinstance(value, Tensor):
            value = torch.tensor(value)
        return value.to(device).broadcast_to((cfg.n_inst, cfg.n_features))

    def generate_batch(self, batch_size: int) -> Tensor:
        """
        Generates a batch of training data with the specified batch size.

        `feat_mag`: random magnitudes for each feature, drawn uniformly from [0, 1)
        `feat_seeds`: random thresholds to decide which features are activated or zeroed out.

        Returns:
            Tensor of shape `(batch_size, n_inst, n_features)` in which each feature is
            either its random magnitude (with probability `feature_probability`) or 0.
        """
        batch_shape = (batch_size, self.cfg.n_inst, self.cfg.n_features)
        feat_mag = torch.rand(batch_shape, device=self.W.device)
        feat_seeds = torch.rand(batch_shape, device=self.W.device)
        # `feature_probability` is (n_inst, n_features) and broadcasts over the batch dimension
        return torch.where(feat_seeds <= self.feature_probability, feat_mag, 0.0)

    def forward(self, features: Tensor) -> Tensor:
        """
        Computes `h = ReLU(W x)` followed by `x' = ReLU(W^T h + b_final)`, independently
        for every instance.

        Args:
            features: tensor of shape `(..., n_inst, n_features)`.

        Returns:
            Reconstruction with the same shape as `features`.
        """
        h = F.relu(einops.einsum(features, self.W, "... inst feats, inst hidden feats -> ... inst hidden"))
        out = einops.einsum(h, self.W, "... inst hidden, inst hidden feats -> ... inst feats")
        return F.relu(out + self.b_final)

    def calculate_loss(self, out: Tensor, batch: Tensor) -> Tensor:
        """
        Calculates the loss for a given batch using the loss described in:

            https://transformer-circuits.pub/2022/toy_model/index.html#demonstrating-setup-loss

        The formula used is:

        .. math::

            L = \\frac{1}{BF} \\sum_{x} \\sum_{i} I_i (x_i - x'_i)^2

        Where:
        - B is the batch size
        - F is the number of features
        - :math:`x_i` are the inputs and :math:`x'_i` are the outputs
        - :math:`I_i` is the importance of feature i
        - :math:`\\sum_{i}` is the sum over all features
        - :math:`\\sum_{x}` is the sum over all elements in the batch

        The per-instance losses are then summed into a single scalar, so each instance
        receives gradients from its own loss only.
        """
        error = self.importance * ((batch - out) ** 2)
        loss = einops.reduce(error, "batch inst feats -> inst", "mean").sum()
        return loss

    def optimize(
        self,
        batch_size: int = 1024,
        steps: int = 5_000,
        log_freq: int = 50,
        lr: float = 1e-3,
        lr_scale: Callable[[int, int], float] = constant_lr,
    ):
        """
        Trains all instances with Adam on freshly generated batches.

        Args:
            batch_size: number of samples generated per step.
            steps: number of optimisation steps.
            log_freq: the progress-bar postfix is refreshed every `log_freq` steps
                (and on the last step).
            lr: base learning rate.
            lr_scale: schedule `(step, steps) -> multiplier` applied to `lr` each step.
        """
        optimizer = torch.optim.Adam(list(self.parameters()), lr=lr)

        progress_bar = tqdm(range(steps))

        for step in progress_bar:
            # update learning rate
            step_lr = lr * lr_scale(step, steps)
            for group in optimizer.param_groups:
                group["lr"] = step_lr

            # optimize
            optimizer.zero_grad()
            batch = self.generate_batch(batch_size)
            out = self(batch)
            loss = self.calculate_loss(out, batch)
            loss.backward()
            optimizer.step()

            # display progress bar (loss is reported as the average over instances)
            if step % log_freq == 0 or (step + 1 == steps):
                progress_bar.set_postfix(loss=loss.item() / self.cfg.n_inst, lr=step_lr)

In [7]:
# 7 instances trained in parallel, each with 10 features and 5 hidden neurons
cfg = ToyModelConfig(n_inst=7, n_features=10, d_hidden=5)

# Geometrically decaying importance: 0.75, 0.75^2, ..., one value per feature
importance = 0.75 ** torch.arange(1, 1 + cfg.n_features)
print("Importance check: ", importance)

# One feature probability per instance (7 values for the 7 instances), from dense to sparse
feature_probability = torch.tensor([0.75, 0.35, 0.15, 0.1, 0.06, 0.02, 0.01])
print("Feature probability check: ", feature_probability)

Importance check:  tensor([0.7500, 0.5625, 0.4219, 0.3164, 0.2373, 0.1780, 0.1335, 0.1001, 0.0751,
        0.0563])
Feature probability check:  tensor([0.7500, 0.3500, 0.1500, 0.1000, 0.0600, 0.0200, 0.0100])


In [8]:
# Importance varies along the feature axis and feature probability along the instance
# axis, so each 1-D tensor gets a singleton dimension on the other axis; the model then
# broadcasts both to (n_inst, n_features).
model = ToyModel(
    cfg=cfg,
    feature_probability=feature_probability.unsqueeze(1),
    importance=importance.unsqueeze(0),
    device=device,
)

model.optimize(steps=10_000)

100%|██████████| 10000/10000 [00:23<00:00, 428.88it/s, loss=0.00212, lr=0.001]


**Monosemanticity / Polysemanticity and how this changes with increasing sparsity**

- **top row plots**: 
  - columns are hidden dimensions (neurons), rows are features

- **bottom row plots**:
  - columns are hidden dimensions (neurons), rows are the exposures of the features to that particular neuron
  - each feature is colored differently based on its interference with other features
    - dark blue: the feature is orthogonal to all other features
    - lighter colors: the sum of squared dot products with other features is large

1. **Low sparsity ($\text{feature prob} \approx 1.0$)**:
  - pure monosemanticity: each feature is represented faithfully by a single neuron or not at all (no superposition)
  - leftmost heatmap: each feature has a corresponding neuron which detects that particular feature, and no other
  - leftmost bar chart: each neuron has just one feature exposed to it

2. **Medium sparsity ($\text{feature prob} \approx 0.5$)**:
  - middle heatmaps: some neurons are monosemantic, but others are polysemantic
  - middle bar charts: some neurons start to become polysemantic, with exposures to multiple features

3. **High sparsity ($\text{feature prob} \approx 0.0$)**:
  - pure polysemanticity: all neurons are polysemantic and all features are represented in some capacity
  - features cannot all be orthogonal to each other, given there are more features than neurons
  - rightmost heatmap: all features are represented by all neurons, but the features are not orthogonal to each other
  - rightmost bar chart: all neurons are polysemantic, with exposures to all features

**Conclusion**: while the features are always monosemantic, there are *neuron-level phase changes from monosemanticity to polysemanticity* as sparsity increases.

In [9]:
# Subplot titles list the feature probabilities (1 - S) in the order they were passed to the model
utils.plot_features_in_Nd(
    model.W,
    title=f"Neuron model: {cfg.n_features=}, {cfg.d_hidden=}, I<sub>i</sub> = 0.75<sup>i</sup>",
    subplot_titles=[f"1 - S = {prob:.2f}" for prob in feature_probability.squeeze().tolist()],
    neuron_plot=True,
    width=1000,
    height=600,
)

# 2. Computation in superposition

The model above does not benefit from the ReLU hidden layer, because its underlying operation is a linear function, since it is trying to reconstruct its own input (i.e. the identity). This may lead to odd behavior, such as the model learning biases which shift all the neurons into a positive regime so they behave linearly. On the other hand, the MLP layer in a transformer performs nonlinear computations on information. 

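The model above is designed for **bottleneck (or representational) superposition**, where more features than there are hidden dimensions have to be compressed into, and then reconstructed from, a low-dimensional hidden layer.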

To illustrate this, the model below is designed for **neuron (or computational) superposition**. It takes an input $x$ and outputs $|x|$, the absolute value of $x$, equivalent to the nonlinear function $\text{ReLU}(x)+\text{ReLU}(-x)$. Also, the data $x$ is now sampled from the range $[-1, 1]$ rather than $[0, 1]$, because calculating the absolute value of $x \in [0, 1]$ would be equivalent to the identity function, which is trivial.

Note that the loss function now computes the importance-weighted $L_2$ error between $|x|$ and $x'$, instead of computing the error between the input $x$ and output $x'$.

In [10]:
class NeuronComputationModel(ToyModel):
    """
    Toy model of computation in superposition: it is trained to output `|x|`.

    Unlike `ToyModel`, the input and output projections are separate parameters
    (`W1` and `W2`) and the inputs are sampled from [-1, 1).

    `feature_probability` and `importance` may each be `None` (meaning 1.0), a Python
    number, a (nested) list or a Tensor that broadcasts to `(cfg.n_inst, cfg.n_features)`.
    """
    W1: Tensor
    W2: Tensor
    b_final: Tensor

    def __init__(
        self,
        cfg: ToyModelConfig,
        feature_probability: Tensor = None,
        importance: Tensor = None,
        device=device,
    ):
        """
        Args:
            cfg: shape configuration (`n_inst`, `n_features`, `d_hidden`).
            feature_probability: probability that a feature is non-zero in a sample;
                `None` means 1.0 (every feature always active).
            importance: per-feature weight in the loss; `None` means 1.0 everywhere.
            device: device on which parameters and settings are placed.
        """
        # Deliberately bypass ToyModel.__init__ (it would create the tied weight `W`,
        # which this model does not use) and initialise nn.Module directly.
        super(ToyModel, self).__init__()

        # `None` selects the neutral default; `torch.tensor(None)` would raise instead.
        if feature_probability is None:
            feature_probability = 1.0
        if importance is None:
            importance = 1.0

        if not isinstance(feature_probability, Tensor):
            feature_probability = torch.tensor(feature_probability)

        if not isinstance(importance, Tensor):
            importance = torch.tensor(importance)

        self.cfg = cfg
        self.feature_probability = feature_probability.to(device).broadcast_to((cfg.n_inst, cfg.n_features))
        self.importance = importance.to(device).broadcast_to((cfg.n_inst, cfg.n_features))

        # W1 maps features -> hidden, W2 maps hidden -> features; they are not tied
        self.W1 = nn.Parameter(nn.init.kaiming_uniform_(torch.empty((cfg.n_inst, cfg.d_hidden, cfg.n_features))))
        self.W2 = nn.Parameter(nn.init.kaiming_uniform_(torch.empty((cfg.n_inst, cfg.n_features, cfg.d_hidden))))
        self.b_final = nn.Parameter(torch.zeros((cfg.n_inst, cfg.n_features)))
        self.to(device)

    def generate_batch(self, batch_size: int) -> Tensor:
        """
        Generates `(batch_size, n_inst, n_features)` training inputs.

        Each feature is, with probability `feature_probability`, a value drawn uniformly
        from [-1, 1), and 0 otherwise.
        """
        batch_shape = (batch_size, self.cfg.n_inst, self.cfg.n_features)
        # rescale uniform [0, 1) samples to [-1, 1) so that |x| differs from the identity
        feat_mag = 2 * torch.rand(batch_shape, device=self.W1.device) - 1
        feat_seed = torch.rand(batch_shape, device=self.W1.device)
        return torch.where(feat_seed <= self.feature_probability, feat_mag, 0.0)

    def forward(self, features: Tensor) -> Tensor:
        """
        Computes `h = ReLU(W1 x)` followed by `x' = ReLU(W2 h + b_final)`, independently
        for every instance.

        Args:
            features: tensor of shape `(..., n_inst, n_features)`.

        Returns:
            Tensor with the same shape as `features`.
        """
        h = F.relu(einops.einsum(features, self.W1, "... inst feats, inst hidden feats -> ... inst hidden"))
        out = einops.einsum(h, self.W2, "... inst hidden, inst feats hidden -> ... inst feats")
        return F.relu(out + self.b_final)

    def calculate_loss(self, out: Tensor, batch: Tensor) -> Tensor:
        """
        Importance-weighted squared error between the target `|batch|` and `out`,
        averaged over batch and features for each instance, then summed over instances.
        """
        error = self.importance * ((batch.abs() - out) ** 2)
        loss = einops.reduce(error, "batch inst feats -> inst", "mean").sum()
        return loss

In [11]:
# 7 instances trained in parallel, each with 100 features and 40 hidden neurons
cfg = ToyModelConfig(n_inst=7, n_features=100, d_hidden=40)

# Geometrically decaying importance: 0.8, 0.8^2, ..., one value per feature
importance = 0.8 ** torch.arange(1, 1 + cfg.n_features)
# only the shape is printed here, rather than all 100 values
print("Importance check: ", importance.shape)

# One feature probability per instance (7 values for the 7 instances), from dense to very sparse
feature_probability = torch.tensor([1.0, 0.3, 0.1, 0.03, 0.01, 0.003, 0.001])
print("Feature probability check: ", feature_probability)

Importance check:  torch.Size([100])
Feature probability check:  tensor([1.0000, 0.3000, 0.1000, 0.0300, 0.0100, 0.0030, 0.0010])


In [12]:
# Importance gets a leading singleton (instance) axis and feature probability a trailing
# singleton (feature) axis; the model broadcasts both to (n_inst, n_features).
model = NeuronComputationModel(
    cfg=cfg,
    feature_probability=feature_probability.unsqueeze(1),
    importance=importance.unsqueeze(0),
    device=device,
)
model.optimize()

100%|██████████| 5000/5000 [00:20<00:00, 246.44it/s, loss=7.73e-5, lr=0.001] 


Patterns below are similar: with very low sparsity, most/all neurons are monosemantic, but more polysemantic neurons start to appear as sparsity increases until all neurons are polysemantic.

Another interesting observation: in the monosemantic (or mostly monosemantic) cases, for any given feature there will be some neurons which have positive exposures to that feature and others with negative exposure. This is because some neurons are representing the value $\text{ReLU}(x_i)$ and others are representing the value of $\text{ReLU}(-x_i)$, both required to compute the absolute value.

In [13]:
# The title states the importance schedule used for this experiment, which is
# 0.8 ** i (see the configuration cell above the training run), not 0.75 ** i.
utils.plot_features_in_Nd(
    model.W1,
    height=800,
    width=1600,
    title=f"Neuron computation model: n_features = {cfg.n_features}, d_hidden = {cfg.d_hidden}, I<sub>i</sub> = 0.8<sup>i</sup>",
    subplot_titles=[f"1 - S = {i:.3f}" for i in feature_probability.squeeze()],
    neuron_plot=True,
)

To further confirm that this is happening, the values in the bar chart can be colored discretely by feature, rather than continuously by the polysemanticity of that feature. A feature probability of 50% will be used for this visualisation, which is high enough to make sure each neuron is monosemantic. The input weights $W_1$ form pairs of antipodal neurons (i.e. ones with positive / negative exposures to that feature direction), but both of these neurons have positive output weights $W_2$ for that feature.

In [14]:
# 6 instances trained in parallel, each with 20 features and 10 hidden neurons
cfg = ToyModelConfig(n_inst=6, n_features=20, d_hidden=10)

# Geometrically decaying importance: 0.8, 0.8^2, ..., one value per feature
importance = 0.8 ** torch.arange(1, 1 + cfg.n_features)
# A single scalar this time: the model broadcasts it to every instance and feature
feature_probability = 0.5

In [15]:
# `feature_probability` is a plain float here and is passed through unchanged; only the
# importance vector needs a leading singleton (instance) axis before broadcasting.
model = NeuronComputationModel(
    cfg=cfg,
    feature_probability=feature_probability,
    importance=importance.unsqueeze(0),
    device=device,
)
model.optimize()

100%|██████████| 5000/5000 [00:14<00:00, 345.02it/s, loss=0.00664, lr=0.001]


In [16]:
# One legend entry per feature, labelled with that feature's importance
utils.plot_features_in_Nd_discrete(
    W1=model.W1,
    W2=model.W2,
    legend_names=[
        f"I<sub>{idx}</sub> = {importance.squeeze()[idx].item():.3f}"
        for idx in range(cfg.n_features)
    ],
    title="Neuron computation model (colored discretely, by feature)",
)

# 3. The asymmetric superposition motif

In Anthropic's `Toy models of superposition` paper, the authors describe this topic at length.

When sparsity is increased and superposition is introduced, neurons no longer necessarily compute either $\text{ReLU}(x_i)$ or $\text{ReLU}(-x_i)$ for a single feature $i$. Instead, **asymmetric superposition** may occur, where a single neuron detects two different features $i$ and $j$, and stores these features with different magnitudes. Assume $W_1$ and $W_2$ have flipped magnitudes, so that $W_1$ has a large weight for feature $i$, while $W_2$ has a large weight for feature $j$.

When $i$ is present and $j$ is not, there's no problem, because the output for feature $i$ is `large * small` (correct size) and for $j$ is `small * small` (near zero). But when $j$ is present and $i$ is not, the output for feature $j$ is `small * large` (correct size) and for $i$ is `large * large` (much larger than it should be). In particular, this is bad when the sign of output for $i$ is positive. The model fixes this by repurposing another neuron to correct for the case when $j$ is present and $i$ is not, by taking advantage of the fact that the model has a ReLU at the very end, so it doesn't matter if output for a feature is very large and negative, because the loss will be truncated at zero.

In [17]:
# 6 instances trained in parallel, each with 10 features and 10 hidden neurons
cfg = ToyModelConfig(n_inst=6, n_features=10, d_hidden=10)

# Geometrically decaying importance: 0.8, 0.8^2, ..., one value per feature
importance = 0.8 ** torch.arange(1, 1 + cfg.n_features)
feature_probability = 0.35  # slightly lower feature probability, to encourage a small degree of superposition

In [18]:
# `feature_probability` is a plain float here and is passed through unchanged; only the
# importance vector needs a leading singleton (instance) axis before broadcasting.
model = NeuronComputationModel(
    cfg=cfg,
    feature_probability=feature_probability,
    importance=importance.unsqueeze(0),
    device=device,
)
model.optimize()

100%|██████████| 5000/5000 [00:14<00:00, 355.51it/s, loss=0.00795, lr=0.001]


In [19]:
# One legend entry per feature, labelled with that feature's importance
utils.plot_features_in_Nd_discrete(
    W1=model.W1,
    W2=model.W2,
    legend_names=[
        f"I<sub>{idx}</sub> = {importance.squeeze()[idx].item():.3f}"
        for idx in range(cfg.n_features)
    ],
    title="Neuron computation model (colored discretely, by feature)",
)

# Sources

1. [Ground truth - Toy models of superposition & Sparse Autoencoders](https://arena-chapter1-transformer-interp.streamlit.app/[1.3.1]_Toy_Models_of_Superposition_&_SAEs)
2. [Toy models of superposition, by Chris Olah, Dario Amodei, et al.](https://transformer-circuits.pub/2022/toy_model/index.html)