In [22]:
import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Literal
import einops
import numpy as np
import torch as t
from jaxtyping import Float
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from tqdm.auto import tqdm

# Put the chapter's `exercises` directory on sys.path so the local modules below can be imported.
chapter = r"chapter1_transformer_interp"
repo_prefix = os.getcwd().split(chapter)[0]
exercises_dir = Path(f"{repo_prefix}/{chapter}/exercises").resolve()
section_dir = exercises_dir / "part31_superposition_and_saes"
exercises_path = str(exercises_dir)
if exercises_path not in sys.path:
    sys.path.append(exercises_path)

import part31_superposition_and_saes.utils as utils
import part31_superposition_and_saes.tests as tests
from plotly_utils import line, imshow

# Device preference order: Apple MPS first, then CUDA, otherwise CPU.
if t.backends.mps.is_available():
    device_name = "mps"
elif t.cuda.is_available():
    device_name = "cuda"
else:
    device_name = "cpu"
device = t.device(device_name)

# Fixed seed so that weight initialisation and sampled batches are repeatable.
t.manual_seed(2)
MAIN = __name__ == "__main__"

In [23]:
def linear_lr(step, steps):
    """Linear decay schedule: the scale is 1 at step 0 and falls linearly, reaching 0 at `step == steps`."""
    fraction_done = step / steps
    return 1 - fraction_done

def constant_lr(*_):
    """Constant schedule: accepts (and ignores) any positional arguments and always returns 1.0."""
    scale = 1.0
    return scale

def cosine_decay_lr(step, steps):
    """
    Cosine decay schedule: the scale follows a quarter cosine wave, from 1.0 at step 0 down to
    (numerically) 0 at the final step `steps - 1`.

    Args:
        step:  index of the current optimization step (0-based)
        steps: total number of optimization steps

    Returns:
        Multiplicative learning-rate scale for this step.
    """
    # With a single step (or fewer) there is nothing to decay over, and the `steps - 1` denominator
    # below would be zero, so just use the full learning rate.
    if steps <= 1:
        return 1.0
    return np.cos(0.5 * np.pi * step / (steps - 1))


@dataclass
class Config:
    # Hyperparameters for the toy `Model`.
    #
    # We optimize n_inst models in a single training loop to let us sweep over sparsity or importance
    # curves efficiently. You should treat the number of instances `n_inst` like a batch dimension,
    # but one which is built into our training setup. Ignore the latter 3 arguments for now, they'll
    # return in later exercises.

    # Number of independent model instances; leading dimension of the model's `W` and `b_final`.
    n_inst: int
    # Number of input features per instance (last dimension of `W` and `b_final`).
    n_features: int = 5
    # Size of the hidden bottleneck (middle dimension of `W`).
    d_hidden: int = 2
    # The three fields below are not read by any code visible in this notebook section.
    n_correlated_pairs: int = 0
    n_anticorrelated_pairs: int = 0
    feat_mag_distn: Literal["unif", "jump"] = "unif"


class Model(nn.Module):
    """
    Toy model of superposition: `cfg.n_inst` independent instances, each compressing
    `cfg.n_features` input features into `cfg.d_hidden` dimensions and reconstructing them.

    For a single instance the map is x -> ReLU(W.T @ W @ x + b_final).
    """

    W: Float[Tensor, "inst d_hidden feats"]
    b_final: Float[Tensor, "inst feats"]

    def __init__(
        self,
        cfg: Config,
        feature_probability: float | Tensor = 0.01,
        importance: float | Tensor = 1.0,
        device=device,
    ):
        """
        Args:
            cfg:                 model hyperparameters
            feature_probability: probability that each feature is non-zero in a generated batch;
                                 a Python scalar or a tensor broadcastable to (n_inst, n_features)
            importance:          per-feature weight in the loss; a Python scalar or a tensor
                                 broadcastable to (n_inst, n_features)
            device:              device that the parameters and the two tensors above are moved to
        """
        super().__init__()
        self.cfg = cfg

        # Accept any plain Python number: an int such as `importance=1` has no `.to` method and
        # previously crashed below. Scalars always become float tensors.
        if isinstance(feature_probability, (int, float)):
            feature_probability = t.tensor(float(feature_probability))
        self.feature_probability = feature_probability.to(device).broadcast_to(
            (cfg.n_inst, cfg.n_features)
        )
        if isinstance(importance, (int, float)):
            importance = t.tensor(float(importance))
        self.importance = importance.to(device).broadcast_to((cfg.n_inst, cfg.n_features))

        self.W = nn.Parameter(
            nn.init.xavier_normal_(t.empty((cfg.n_inst, cfg.d_hidden, cfg.n_features)))
        )
        self.b_final = nn.Parameter(t.zeros((cfg.n_inst, cfg.n_features)))
        self.to(device)

    def forward(
        self,
        features: Float[Tensor, "... inst feats"],
    ) -> Float[Tensor, "... inst feats"]:
        """
        Projects the features down to the hidden space with W, maps them back up with W.T, then
        adds the bias and applies a ReLU.
        """
        hidden = einops.einsum(
            self.W, features, "inst d_hidden feats, ... inst feats -> ... inst d_hidden"
        )
        # Contracting W over d_hidden is the same as multiplying by W.T; no explicit transpose needed.
        reconstructed = einops.einsum(
            self.W, hidden, "inst d_hidden feats, ... inst d_hidden -> ... inst feats"
        )
        return F.relu(reconstructed + self.b_final)

    def generate_batch(self, batch_size) -> Float[Tensor, "batch inst feats"]:
        """
        Generates a batch of data.

        Each entry is independently non-zero with probability `self.feature_probability`; non-zero
        entries are drawn uniformly from [0, 1). The correlated / anticorrelated / `feat_mag_distn`
        config options are not used here.
        """
        shape = (batch_size, self.cfg.n_inst, self.cfg.n_features)
        # Keep this sampling order (magnitudes first, then the presence draw) so that seeded runs
        # stay reproducible.
        magnitudes = t.rand(shape, device=self.W.device)
        presence_draw = t.rand(shape, device=self.W.device)

        return t.where(presence_draw < self.feature_probability, magnitudes, 0.0)

    def calculate_loss(
        self,
        out: Float[Tensor, "batch inst feats"],
        batch: Float[Tensor, "batch inst feats"],
    ) -> Float[Tensor, ""]:
        """
        Calculates the loss for a given batch (as a scalar tensor), using this loss described in the
        Toy Models of Superposition paper:

            https://transformer-circuits.pub/2022/toy_model/index.html#demonstrating-setup-loss

        The importance-weighted squared error is averaged over the batch and feature dimensions,
        then summed over instances.

        Note, `self.importance` is guaranteed to broadcast with the shape of `out` and `batch`.
        """
        weighted_sq_error = self.importance * (out - batch) ** 2
        per_instance = einops.reduce(weighted_sq_error, "batch inst feats -> inst", "mean")
        return per_instance.sum()

    def optimize(
        self,
        batch_size: int = 1024,
        steps: int = 10_000,
        log_freq: int = 50,
        lr: float = 1e-3,
        lr_scale: Callable[[int, int], float] = constant_lr,
    ):
        """
        Optimizes the model using the given hyperparameters.

        Args:
            batch_size: number of samples generated per optimization step
            steps:      number of optimization steps
            log_freq:   number of steps between progress-bar updates
            lr:         base learning rate
            lr_scale:   function (step, steps) -> multiplier applied to `lr` at each step
        """
        optimizer = t.optim.Adam(list(self.parameters()), lr=lr)

        progress_bar = tqdm(range(steps))

        for step in progress_bar:
            # Update learning rate
            step_lr = lr * lr_scale(step, steps)
            for group in optimizer.param_groups:
                group["lr"] = step_lr

            # Optimize
            optimizer.zero_grad()
            batch = self.generate_batch(batch_size)
            out = self(batch)
            loss = self.calculate_loss(out, batch)
            loss.backward()
            optimizer.step()

            # Display progress bar (the loss is summed over instances, so report the per-instance mean)
            if step % log_freq == 0 or (step + 1 == steps):
                progress_bar.set_postfix(loss=loss.item() / self.cfg.n_inst, lr=step_lr)


In [30]:
@dataclass
class SAEConfig:
    # Hyperparameters for the `SAE` (sparse autoencoder) class.

    # Number of independent SAE instances (leading dimension of every SAE parameter).
    n_inst: int
    # Size of the SAE's input; `SAE.__init__` asserts it equals the model's `d_hidden`.
    d_in: int
    # Number of SAE latents.
    d_sae: int
    # Weight of the L1 sparsity term in the SAE's total loss.
    l1_coeff: float = 0.2
    # Small constant added to norms before dividing, to avoid division by zero.
    weight_normalize_eps: float = 1e-8
    # If True, no separate decoder parameter is created and W_dec is the transpose of W_enc.
    tied_weights: bool = False
    # Not read by any code visible in this notebook section.
    architecture: Literal["standard", "gated"] = "standard"


class SAE(nn.Module):
    """
    Sparse autoencoder trained on the hidden activations of a frozen toy `Model`.

    `cfg.n_inst` independent autoencoders are trained in parallel. Each one encodes a `d_in`-dim
    hidden vector into `d_sae` non-negative latents and decodes it back.
    """

    W_enc: Float[Tensor, "inst d_in d_sae"]
    _W_dec: Float[Tensor, "inst d_sae d_in"] | None
    b_enc: Float[Tensor, "inst d_sae"]
    b_dec: Float[Tensor, "inst d_in"]

    def __init__(self, cfg: SAEConfig, model: Model) -> None:
        """
        Args:
            cfg:   autoencoder hyperparameters
            model: toy model whose hidden activations are reconstructed; its parameters are frozen
                   in place (`requires_grad_(False)`)
        """
        super().__init__()

        assert cfg.d_in == model.cfg.d_hidden, "Model's hidden dim doesn't match SAE input dim"
        self.cfg = cfg
        self.model = model.requires_grad_(False)

        self.W_enc = nn.Parameter(
            nn.init.kaiming_uniform_(t.empty((cfg.n_inst, cfg.d_in, cfg.d_sae)))
        )
        # With tied weights there is no separate decoder parameter; see the `W_dec` property.
        self._W_dec = (
            None
            if self.cfg.tied_weights
            else nn.Parameter(nn.init.kaiming_uniform_(t.empty((cfg.n_inst, cfg.d_sae, cfg.d_in))))
        )
        self.b_enc = nn.Parameter(t.zeros(cfg.n_inst, cfg.d_sae))
        self.b_dec = nn.Parameter(t.zeros(cfg.n_inst, cfg.d_in))

        self.to(device)

    @property
    def W_dec(self) -> Float[Tensor, "inst d_sae d_in"]:
        """Decoder weights: the dedicated parameter, or the transposed encoder when weights are tied."""
        return self._W_dec if self._W_dec is not None else self.W_enc.transpose(-1, -2)

    @property
    def W_dec_normalized(self) -> Float[Tensor, "inst d_sae d_in"]:
        """Returns decoder weights, normalized over the autoencoder input dimension."""
        # The eps keeps an all-zero decoder row from producing NaN/inf (same guard as in
        # `resample_simple`).
        row_norms = self.W_dec.norm(dim=-1, keepdim=True)
        return self.W_dec / (row_norms + self.cfg.weight_normalize_eps)

    def generate_batch(self, batch_size: int) -> Float[Tensor, "batch inst d_in"]:
        """
        Generates a batch of hidden activations from our model.

        A batch of features is sampled from the model and projected into the hidden space with the
        model's `W` (no bias, no nonlinearity).
        """
        features = self.model.generate_batch(batch_size=batch_size)
        return einops.einsum(
            self.model.W, features, "inst d_hidden feats, ... inst feats -> ... inst d_hidden"
        )

    def forward(
        self, h: Float[Tensor, "batch inst d_in"]
    ) -> tuple[
        dict[str, Float[Tensor, "batch inst"]],
        Float[Tensor, ""],
        Float[Tensor, "batch inst d_sae"],
        Float[Tensor, "batch inst d_in"],
    ]:
        """
        Forward pass on the autoencoder.

        Args:
            h: hidden layer activations of model

        Returns:
            loss_dict: dict of different loss function term values, for every (batch elem, instance)
            loss: scalar total loss (summed over instances & averaged over batch dim)
            acts: autoencoder feature activations
            h_reconstructed: reconstructed autoencoder input
        """
        # Encode: subtract the decoder bias, project to latent space, add encoder bias, ReLU
        pre_acts = einops.einsum(
            self.W_enc, h - self.b_dec, "inst d_in d_sae, batch inst d_in -> batch inst d_sae"
        )
        acts = F.relu(self.b_enc + pre_acts)

        # Decode back to the input space
        h_reconstructed = self.b_dec + einops.einsum(
            self.W_dec, acts, "inst d_sae d_in, batch inst d_sae -> batch inst d_in"
        )

        # Per-(batch, instance) loss terms: mean squared error over d_in, and L1 norm of the latents
        L_reconstruction = (h - h_reconstructed).pow(2).mean(-1)
        L_sparsity = acts.abs().sum(-1)

        loss_dict = {
            "L_reconstruction": L_reconstruction,
            "L_sparsity": L_sparsity,
        }
        loss = (L_reconstruction + L_sparsity * self.cfg.l1_coeff).mean(0).sum()

        return loss_dict, loss, acts, h_reconstructed

    def optimize(
        self,
        batch_size: int = 1024,
        steps: int = 10_000,
        log_freq: int = 50,
        lr: float = 1e-3,
        lr_scale: Callable[[int, int], float] = constant_lr,
        resample_method: Literal["simple", "advanced", None] = None,
        resample_freq: int = 2500,
        resample_window: int = 500,
        resample_scale: float = 0.5,
    ) -> dict[str, list]:
        """
        Optimizes the autoencoder using the given hyperparameters.

        Args:
            batch_size:         size of batches we pass through model & train autoencoder on
            steps:              number of optimization steps
            log_freq:           number of optimization steps between logging
            lr:                 learning rate
            lr_scale:           learning rate scaling function
            resample_method:    method for resampling dead latents
            resample_freq:      number of optimization steps between resampling dead latents
            resample_window:    number of steps needed for us to classify a neuron as dead
            resample_scale:     scale factor for resampled neurons

        Returns:
            data_log:               dictionary containing data we'll use for visualization
        """
        assert resample_window <= resample_freq

        optimizer = t.optim.Adam(list(self.parameters()), lr=lr, betas=(0.0, 0.999))
        frac_active_list = []
        progress_bar = tqdm(range(steps))

        # Create lists to store data we'll eventually be plotting
        data_log = {"steps": [], "W_enc": [], "W_dec": [], "frac_active": []}

        for step in progress_bar:
            # Resample dead latents
            if (resample_method is not None) and ((step + 1) % resample_freq == 0):
                frac_active_in_window = t.stack(frac_active_list[-resample_window:], dim=0)
                if resample_method == "simple":
                    self.resample_simple(frac_active_in_window, resample_scale)
                elif resample_method == "advanced":
                    self.resample_advanced(frac_active_in_window, resample_scale, batch_size)

            # Update learning rate
            step_lr = lr * lr_scale(step, steps)
            for group in optimizer.param_groups:
                group["lr"] = step_lr

            # Get a batch of hidden activations from the model
            with t.inference_mode():
                h = self.generate_batch(batch_size)

            # Optimize
            loss_dict, loss, acts, _ = self.forward(h)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Normalize decoder weights by modifying them inplace (if not using tied weights)
            if not self.cfg.tied_weights:
                self.W_dec.data = self.W_dec_normalized

            # Calculate the mean sparsities over batch dim for each feature
            frac_active = (acts.abs() > 1e-8).float().mean(0)
            frac_active_list.append(frac_active)

            # Display progress bar, and append new values for plotting
            if step % log_freq == 0 or (step + 1 == steps):
                progress_bar.set_postfix(
                    lr=step_lr,
                    frac_active=frac_active.mean().item(),
                    **{k: v.mean(0).sum().item() for k, v in loss_dict.items()},  # type: ignore
                )
                data_log["W_enc"].append(self.W_enc.detach().cpu().clone())
                data_log["W_dec"].append(self.W_dec.detach().cpu().clone())
                data_log["frac_active"].append(frac_active.detach().cpu().clone())
                data_log["steps"].append(step)

        return data_log

    @t.no_grad()
    def resample_simple(
        self,
        frac_active_in_window: Float[Tensor, "window inst d_sae"],
        resample_scale: float,
    ) -> None:
        """
        Resamples dead latents, by modifying the model's weights and biases inplace.

        Resampling method is:
            - For each dead neuron, generate a random vector of size (d_in,), and normalize these vectors
            - Set new values of W_dec and W_enc to be these normalized vectors, at each dead neuron
            - Set b_enc to be zero, at each dead neuron
        """
        # Get a tensor of dead latents
        dead_latents_mask = (frac_active_in_window < 1e-8).all(dim=0)  # [instances d_sae]
        n_dead = int(dead_latents_mask.int().sum().item())

        # Get our random replacement values of shape [n_dead d_in], and scale them
        replacement_values = t.randn((n_dead, self.cfg.d_in), device=self.W_enc.device)
        replacement_values_normed = replacement_values / (
            replacement_values.norm(dim=-1, keepdim=True) + self.cfg.weight_normalize_eps
        )

        # Change the corresponding values in W_enc, W_dec, and b_enc
        self.W_enc.data.transpose(-1, -2)[dead_latents_mask] = resample_scale * replacement_values_normed
        self.W_dec.data[dead_latents_mask] = replacement_values_normed
        self.b_enc.data[dead_latents_mask] = 0.0

    @t.no_grad()
    def resample_advanced(
        self,
        frac_active_in_window: Float[Tensor, "window inst d_sae"],
        resample_scale: float,
        batch_size: int,
    ) -> None:
        """
        Resamples latents that have been dead for 'dead_feature_window' steps, according to `frac_active`.

        Resampling method is:
            - Compute the L2 reconstruction loss produced from the hidden state vectors `h`
            - Randomly choose values of `h` with probability proportional to their reconstruction loss
            - Set new values of W_dec and W_enc to be these (centered and normalized) vectors, at each dead neuron
            - Set b_enc to be zero, at each dead neuron

        As in `resample_simple`, the new encoder directions are additionally multiplied by
        `resample_scale`. Instances with no dead latents, or with (near-)zero reconstruction loss
        on the sampled batch, are left untouched.
        """
        # Fresh batch of hidden states and their per-(batch, instance) reconstruction loss
        h = self.generate_batch(batch_size)
        l2_loss = self.forward(h)[0]["L_reconstruction"]  # [batch inst]
        h_centered = h - self.b_dec  # [batch inst d_in]

        # A latent is dead if it never fired on any step of the window
        dead_latents_mask = (frac_active_in_window < 1e-8).all(dim=0)  # [inst d_sae]

        for inst in range(self.cfg.n_inst):
            dead_idx = t.nonzero(dead_latents_mask[inst]).squeeze(-1)
            n_dead = dead_idx.numel()
            if n_dead == 0:
                continue  # nothing to resample for this instance

            inst_loss = l2_loss[:, inst]  # [batch]
            if inst_loss.max() < 1e-6:
                continue  # reconstruction is already (near) perfect; no meaningful distribution to sample

            # Pick one batch element per dead latent, with probability proportional to its loss
            sampler = Categorical(probs=inst_loss / inst_loss.sum())
            chosen = sampler.sample((n_dead,))

            # Centered hidden vectors for the chosen elements, normalized to unit length
            replacement_values = h_centered[chosen, inst]  # [n_dead d_in]
            replacement_values_normed = replacement_values / (
                replacement_values.norm(dim=-1, keepdim=True) + self.cfg.weight_normalize_eps
            )

            # Write the decoder first and the encoder last: with tied weights both writes hit the
            # same storage, and this order keeps the scaled encoder values.
            self.W_dec.data[inst, dead_idx, :] = replacement_values_normed
            self.W_enc.data[inst, :, dead_idx] = resample_scale * replacement_values_normed.T
            self.b_enc.data[inst, dead_idx] = 0.0


In [None]:
# Run the SAE checks from the `tests` module, in the same order as the methods were written.
sae_checks = (
    tests.test_sae_init,
    tests.test_sae_W_dec_normalized,
    tests.test_sae_generate_batch,
    tests.test_sae_forward,
)
for check in sae_checks:
    check(SAE)

In [None]:
# The SAE's input is the model's hidden state, and it gets one latent per model feature.
d_hidden = 2
d_in = d_hidden
n_features = 5
d_sae = n_features
n_inst = 8

# Train the toy model that the SAE will be fitted to
cfg = Config(n_inst=n_inst, n_features=n_features, d_hidden=d_hidden)
model = Model(cfg=cfg, device=device)
model.optimize(steps=10_000)

sae_cfg = SAEConfig(n_inst=n_inst, d_in=d_in, d_sae=d_sae)
sae = SAE(cfg=sae_cfg, model=model)

# Plot the model's feature directions, then a random batch of hidden states for each instance
h = sae.generate_batch(500)
utils.plot_features_in_2d(model.W, title="Base model")
hidden_by_instance = einops.rearrange(h, "batch inst d_in -> inst d_in batch")
utils.plot_features_in_2d(
    hidden_by_instance,
    title="Hidden state representation of a random batch of data",
)

In [None]:
# Train the SAE (no resampling) and animate how its weights evolve over the logged steps
data_log = sae.optimize(steps=25_000)

encoder_frames = t.stack(data_log["W_enc"])
# Swap the last two axes of the logged decoder weights so they have the encoder's layout
decoder_frames = t.stack(data_log["W_dec"]).transpose(-1, -2)

utils.animate_features_in_2d(
    {"Encoder weights": encoder_frames, "Decoder weights": decoder_frames},
    steps=data_log["steps"],
    filename="animation-training.html",
    title="SAE on toy model",
)

In [None]:
# Plot the fraction of the batch on which each latent fired, across the logged training steps
frac_active_history = t.stack(data_log["frac_active"])
utils.frac_active_line_plot(
    frac_active=frac_active_history,
    title="Probability of sae features being active during training",
    avg_window=10,
)

In [None]:
# Run the `tests` module's check for the SAE class (presumably exercising `SAE.resample_simple`).
tests.test_resample_simple(SAE)

In [None]:
# Fresh SAE on the same model, this time resampling dead latents with the simple method
sae_cfg = SAEConfig(n_inst=n_inst, d_in=d_in, d_sae=d_sae)
sae = SAE(cfg=sae_cfg, model=model)

data_log = sae.optimize(steps=25_000, resample_method="simple")

encoder_frames = t.stack(data_log["W_enc"])
# Swap the last two axes of the logged decoder weights so they have the encoder's layout
decoder_frames = t.stack(data_log["W_dec"]).transpose(-1, -2)

utils.animate_features_in_2d(
    {"Encoder weights": encoder_frames, "Decoder weights": decoder_frames},
    steps=data_log["steps"],
    filename="animation-resampling.html",
    title="SAE on toy model with simple resampling",
)

In [None]:
# Reconstruct the earlier batch `h` with the trained SAE; the reconstruction is the last element
# of the tuple returned by the forward pass.
with t.inference_mode():
    *_, h_r = sae(h)

layout = "batch inst d_in -> inst d_in batch"
utils.animate_features_in_2d(
    {
        "h": einops.rearrange(h, layout),
        "h<sub>r</sub>": einops.rearrange(h_r, layout),
    },
    filename="animation-reconstructions.html",
    title="Hidden state vs reconstructions",
)

In [None]:
# Same experiment with a wider SAE: 5 more latents than the model has features.
sae = SAE(cfg=SAEConfig(n_inst=n_inst, d_in=d_in, d_sae=d_sae + 5), model=model)

data_log = sae.optimize(steps=25_000, resample_method="simple")

# Distinct filename and title, so this run does not overwrite the animation saved by the
# standard-width resampling experiment.
utils.animate_features_in_2d(
    {
        "Encoder weights": t.stack(data_log["W_enc"]),
        "Decoder weights": t.stack(data_log["W_dec"]).transpose(-1, -2),
    },
    steps=data_log["steps"],
    filename="animation-resampling-wide.html",
    title="Wider SAE on toy model with simple resampling",
)

In [None]:
# New (untrained) toy model with 4 features; its weights are set by hand below
cfg = Config(n_inst=8, n_features=4, d_hidden=2)
model = Model(cfg=cfg, device=device, feature_probability=0.025)

# Replace the model's weights with a custom-chosen non-uniform set of features:
# four unevenly spaced directions on the unit circle, with a random angular offset per instance
circle_fractions = t.tensor([0.0, 0.25, 0.55, 0.70])
angles = 2 * t.pi * circle_fractions
angles = angles + t.rand((cfg.n_inst, 1))
unit_vectors = t.stack([t.cos(angles), t.sin(angles)], dim=1)
model.W.data = unit_vectors.to(device)

utils.plot_features_in_2d(
    model.W,
    title=f"Superposition: {cfg.n_features} features in 2D space (non-uniform)",
    subplot_titles=[f"Instance #{i + 1}" for i in range(cfg.n_inst)],
)

In [None]:
# Train an SAE on the hand-built non-uniform model. Instance count and input size come from the
# model's current `cfg` rather than from globals left over from an earlier cell.
# NOTE(review): `d_sae` is still the earlier global (5) while this model has 4 features - confirm
# the extra latent is intended.
sae = SAE(cfg=SAEConfig(n_inst=cfg.n_inst, d_in=cfg.d_hidden, d_sae=d_sae), model=model)

data_log = sae.optimize(steps=25_000, resample_method="simple")

# Distinct filename and title, so this run does not overwrite the animations saved by the earlier
# resampling experiments.
utils.animate_features_in_2d(
    {
        "Encoder weights": t.stack(data_log["W_enc"]),
        "Decoder weights": t.stack(data_log["W_dec"]).transpose(-1, -2),
    },
    steps=data_log["steps"],
    filename="animation-resampling-nonuniform.html",
    title="SAE on non-uniform toy model with simple resampling",
)