In [3]:
import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Literal
import einops
import numpy as np
import torch as t
from jaxtyping import Float
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from tqdm.auto import tqdm

# Make sure exercises are in the path
# The exercises directory is found by cutting the current working directory at the first
# occurrence of the chapter name and re-appending "<chapter>/exercises". If the chapter name does
# not occur in the working directory, the whole working directory is used as the prefix.
chapter = r"chapter1_transformer_interp"
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = exercises_dir / "part31_superposition_and_saes"
# Only append once so that re-running this cell does not keep growing sys.path.
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

# Project-local modules; these imports rely on the sys.path entry added above.
import part31_superposition_and_saes.utils as utils
import part31_superposition_and_saes.tests as tests
from plotly_utils import line, imshow

# Device preference order: Apple MPS first, then CUDA, falling back to CPU.
device = t.device(
    "mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu"
)

# Flag recording whether this code is executing as the `__main__` module.
MAIN = __name__ == "__main__"

In [25]:
def linear_lr(step, steps):
    """
    Linear learning-rate decay multiplier.

    Returns 1 at `step == 0` and falls linearly towards 0 as `step` approaches `steps`.
    """
    fraction_done = step / steps
    return 1 - fraction_done

def constant_lr(*_):
    """
    Constant learning-rate multiplier.

    Accepts (and ignores) any positional arguments so it can be used interchangeably with the
    other `(step, steps)` schedules; the multiplier is always 1.0.
    """
    scale = 1.0
    return scale

def cosine_decay_lr(step, steps):
    """
    Cosine learning-rate decay multiplier.

    Follows a quarter cosine wave: 1 at `step == 0`, decaying to (numerically) 0 at the final
    step `step == steps - 1`.

    Args:
        step: index of the current optimization step (0-based).
        steps: total number of optimization steps.

    Returns:
        The multiplier to apply to the base learning rate at this step.
    """
    # With a single step (or none) there is nothing to decay over, and the schedule's denominator
    # `steps - 1` would be zero; use the undecayed multiplier instead of dividing by zero.
    if steps <= 1:
        return 1.0
    return np.cos(0.5 * np.pi * step / (steps - 1))


@dataclass
class Config:
    """
    Hyperparameters shared by all model instances trained in a single loop.

    `n_inst` models are optimized together in one training loop, which lets us sweep over
    sparsity or importance curves efficiently. Treat `n_inst` like a batch dimension, but one
    which is built into the training setup. The last three fields can be ignored for now; they
    return in later exercises.
    """

    # Number of model instances optimized side by side.
    n_inst: int
    # Number of input features per instance.
    n_features: int = 5
    # Width of the hidden (bottleneck) layer per instance.
    d_hidden: int = 2
    # Number of feature pairs generated as correlated.
    n_correlated_pairs: int = 0
    # Number of feature pairs generated as anti-correlated.
    n_anticorrelated_pairs: int = 0
    # Feature-magnitude distribution selector.
    feat_mag_distn: Literal["unif", "jump"] = "unif"


class Model(nn.Module):
    """
    A bank of `cfg.n_inst` independent toy autoencoders that are trained side by side.

    Our linear map (for a single instance) is x -> ReLU(W.T @ W @ x + b_final): features are
    projected down to `d_hidden` dimensions with `W` and reconstructed with the same weights.

    Attributes:
        W: projection weights, shape (n_inst, d_hidden, n_features).
        b_final: output bias, shape (n_inst, n_features).
        feature_probability: probability that each feature is non-zero, broadcast to
            (n_inst, n_features). Stored as a plain attribute, not as a registered buffer.
        importance: weight of each feature's squared error in the loss, broadcast to
            (n_inst, n_features). Stored as a plain attribute, not as a registered buffer.
    """

    W: Float[Tensor, "inst d_hidden feats"]
    b_final: Float[Tensor, "inst feats"]

    def __init__(
        self,
        cfg: Config,
        feature_probability: float | Tensor = 0.01,
        importance: float | Tensor = 1.0,
        device=device,
    ):
        """
        Args:
            cfg: shapes and data-generation settings for all instances.
            feature_probability: scalar, or tensor broadcastable to (n_inst, n_features).
            importance: scalar, or tensor broadcastable to (n_inst, n_features).
            device: device the parameters and the two tensors above are moved to.
        """
        super(Model, self).__init__()
        self.cfg = cfg

        # Accept any plain Python number (e.g. `importance=1`), not only floats: an int would
        # otherwise skip the conversion and fail on `.to(device)` below.
        if isinstance(feature_probability, (int, float)):
            feature_probability = t.tensor(float(feature_probability))
        self.feature_probability = feature_probability.to(device).broadcast_to(
            (cfg.n_inst, cfg.n_features)
        )
        if isinstance(importance, (int, float)):
            importance = t.tensor(float(importance))
        self.importance = importance.to(device).broadcast_to((cfg.n_inst, cfg.n_features))

        self.W = nn.Parameter(
            nn.init.xavier_normal_(t.empty((cfg.n_inst, cfg.d_hidden, cfg.n_features)))
        )
        self.b_final = nn.Parameter(t.zeros((cfg.n_inst, cfg.n_features)))
        self.to(device)

    def forward(
        self,
        features: Float[Tensor, "... inst feats"],
    ) -> Float[Tensor, "... inst feats"]:
        """
        Reconstructs `features` through the bottleneck: ReLU(W.T @ W @ x + b_final) per instance.
        """
        # Project down to the hidden space: h = W @ x.
        hidden = einops.einsum(
            self.W, features, "inst d_hidden feats, ... inst feats -> ... inst d_hidden"
        )
        # Map back up with the same weights (W.T @ h). Contracting over d_hidden directly means
        # no explicit transpose of W is needed.
        reconstruction = einops.einsum(
            self.W, hidden, "inst d_hidden feats, ... inst d_hidden -> ... inst feats"
        )
        return F.relu(reconstruction + self.b_final)

    def generate_correlated_features(
        self, batch_size: int, n_correlated_pairs: int
    ) -> Float[Tensor, "batch inst 2*n_correlated_pairs"]:
        """
        Generates a batch of correlated features. For each pair `batch[i, j, [2k, 2k+1]]`, one of
        them is non-zero if and only if the other is non-zero.

        This solution works by creating a boolean mask of shape [batch inst n_correlated_pairs]
        which represents whether the feature set is present, then repeating that mask across feature
        pairs.
        """
        # A single probability per instance is required: every feature must share it.
        assert t.all((self.feature_probability == self.feature_probability[:, [0]]))
        p = self.feature_probability[:, [0]]  # shape (n_inst, 1)

        feat_mag = t.rand(
            (batch_size, self.cfg.n_inst, 2 * n_correlated_pairs), device=self.W.device
        )
        feat_set_seeds = t.rand(
            (batch_size, self.cfg.n_inst, n_correlated_pairs), device=self.W.device
        )
        feat_set_is_present = feat_set_seeds <= p
        # Duplicate each pair-level flag so both members of a pair share the same presence.
        feat_is_present = einops.repeat(
            feat_set_is_present,
            "batch instances features -> batch instances (features pair)",
            pair=2,
        )
        return t.where(feat_is_present, feat_mag, 0.0)

    def generate_anticorrelated_features(
        self, batch_size: int, n_anticorrelated_pairs: int
    ) -> Float[Tensor, "batch inst 2*n_anticorrelated_pairs"]:
        """
        Generates a batch of anti-correlated features. For each pair `batch[i, j, [2k, 2k+1]]`, each
        of them can only be non-zero if the other one is zero.

        There are at least 2 possible ways you could do this:
            (1) Exactly one of batch[i, j, [2k, 2k+1]] is present with probability 2p, and in this
                event we choose which of these two is present randomly.
            (2) batch[i, j, 2k] is present with probability p, and batch[i, j, 2k+1] is present with
                probability p / (1 - p) if and only if batch[i, j, 2k] is NOT present (so that its
                overall probability is (1 - p) * p / (1 - p) = p).

        This solution uses (2), but both are valid.
        """
        # A single probability per instance is required: every feature must share it.
        assert t.all((self.feature_probability == self.feature_probability[:, [0]]))
        p = self.feature_probability[:, [0]]  # shape (n_inst, 1)

        # p / (1 - p) must be a valid probability, which requires p <= 0.5.
        assert p.max().item() <= 0.5, "For anticorrelated features, must have 2p < 1"

        feat_mag = t.rand(
            (batch_size, self.cfg.n_inst, 2 * n_anticorrelated_pairs), device=self.W.device
        )
        even_feat_seeds, odd_feat_seeds = t.rand(
            (2, batch_size, self.cfg.n_inst, n_anticorrelated_pairs),
            device=self.W.device,
        )
        even_feat_is_present = even_feat_seeds <= p
        # The odd member may only fire when the even member did not.
        odd_feat_is_present = (even_feat_seeds > p) & (odd_feat_seeds <= p / (1 - p))
        # Interleave the even/odd masks so that features 2k and 2k+1 form a pair.
        feat_is_present = einops.rearrange(
            t.stack([even_feat_is_present, odd_feat_is_present], dim=0),
            "pair batch instances features -> batch instances (features pair)",
        )
        return t.where(feat_is_present, feat_mag, 0.0)

    def generate_uncorrelated_features(self, batch_size: int, n_uncorrelated: int) -> Tensor:
        """
        Generates a batch of uncorrelated features, shape (batch_size, n_inst, n_uncorrelated).

        When every feature is uncorrelated, the full per-feature probability table is used;
        otherwise all features must share one probability per instance.
        """
        if n_uncorrelated == self.cfg.n_features:
            p = self.feature_probability
        else:
            assert t.all((self.feature_probability == self.feature_probability[:, [0]]))
            p = self.feature_probability[:, [0]]  # shape (n_inst, 1)

        feat_mag = t.rand((batch_size, self.cfg.n_inst, n_uncorrelated), device=self.W.device)
        feat_seeds = t.rand((batch_size, self.cfg.n_inst, n_uncorrelated), device=self.W.device)
        return t.where(feat_seeds <= p, feat_mag, 0.0)

    def generate_batch(self, batch_size: int) -> Float[Tensor, "batch inst feats"]:
        """
        Generates a batch of data, with optional correlated & anticorrelated features.

        Features are laid out along the last dimension as: correlated pairs, then anticorrelated
        pairs, then the remaining uncorrelated features.
        """
        n_corr_pairs = self.cfg.n_correlated_pairs
        n_anti_pairs = self.cfg.n_anticorrelated_pairs
        n_uncorr = self.cfg.n_features - 2 * n_corr_pairs - 2 * n_anti_pairs
        # Without this check an over-full config would silently yield a batch wider than
        # `n_features`, which only surfaces later as a shape error in `forward`.
        assert n_uncorr >= 0, (
            f"Paired features ({2 * (n_corr_pairs + n_anti_pairs)}) exceed "
            f"n_features ({self.cfg.n_features})"
        )

        data = []
        if n_corr_pairs > 0:
            data.append(self.generate_correlated_features(batch_size, n_corr_pairs))
        if n_anti_pairs > 0:
            data.append(self.generate_anticorrelated_features(batch_size, n_anti_pairs))
        if n_uncorr > 0:
            data.append(self.generate_uncorrelated_features(batch_size, n_uncorr))
        batch = t.cat(data, dim=-1)
        return batch

    def calculate_loss(
        self,
        out: Float[Tensor, "batch inst feats"],
        batch: Float[Tensor, "batch inst feats"],
    ) -> Float[Tensor, ""]:
        """
        Calculates the loss for a given batch (as a scalar tensor), using this loss described in the
        Toy Models of Superposition paper:

            https://transformer-circuits.pub/2022/toy_model/index.html#demonstrating-setup-loss

        Note, `self.importance` is guaranteed to broadcast with the shape of `out` and `batch`.
        """
        # Importance-weighted squared error, averaged over batch and features for each instance
        # and then summed over instances.
        error = self.importance * ((out - batch) ** 2)
        loss = einops.reduce(error, "batch inst feats -> inst", "mean").sum()
        return loss

    def optimize(
        self,
        batch_size: int = 1024,
        steps: int = 10_000,
        log_freq: int = 50,
        lr: float = 1e-3,
        lr_scale: Callable[[int, int], float] = constant_lr,
    ):
        """
        Optimizes the model using the given hyperparameters.

        Args:
            batch_size: number of freshly generated samples per step.
            steps: number of optimization steps.
            log_freq: the progress bar is refreshed every `log_freq` steps (and on the last one).
            lr: base learning rate for Adam.
            lr_scale: schedule `(step, steps) -> multiplier` applied to `lr` at every step.
        """
        optimizer = t.optim.Adam(list(self.parameters()), lr=lr)

        progress_bar = tqdm(range(steps))

        for step in progress_bar:
            # Update learning rate
            step_lr = lr * lr_scale(step, steps)
            for group in optimizer.param_groups:
                group["lr"] = step_lr

            # Optimize
            optimizer.zero_grad()
            batch = self.generate_batch(batch_size)
            out = self(batch)
            loss = self.calculate_loss(out, batch)
            loss.backward()
            optimizer.step()

            # Display progress bar (the loss is summed over instances, so report the mean)
            if step % log_freq == 0 or (step + 1 == steps):
                progress_bar.set_postfix(loss=loss.item() / self.cfg.n_inst, lr=step_lr)

In [None]:
# Sanity-check the data generator: 4 features = one correlated pair + one anticorrelated pair.
cfg = Config(n_inst=30, n_features=4, d_hidden=2, n_correlated_pairs=1, n_anticorrelated_pairs=1)


# One feature probability per instance, log-spaced from 10**-0.5 (~0.316) down to 10**-1 (0.1).
feature_probability = 10 ** -t.linspace(0.5, 1, cfg.n_inst).to(device)

# `[:, None]` gives shape (n_inst, 1) so the value is shared by every feature of an instance.
model = Model(cfg=cfg, device=device, feature_probability=feature_probability[:, None])

# Generate a batch of 4 features: first 2 are correlated, second 2 are anticorrelated
batch = model.generate_batch(batch_size=100_000)
# Split the feature dimension into four (batch, inst) tensors.
corr0, corr1, anticorr0, anticorr1 = batch.unbind(dim=-1)

assert ((corr0 != 0) == (corr1 != 0)).all(), "Correlated features should be active together"
# `.mean(0)` averages over the batch, giving the empirical activation frequency per instance.
assert (
    ((corr0 != 0).float().mean(0) - feature_probability).abs().mean() < 0.002
), "Each correlated feature should be active with probability `feature_probability`"

assert (
    (anticorr0 != 0) & (anticorr1 != 0)
).int().sum().item() == 0, "Anticorrelated features should never be active together"
assert (
    ((anticorr0 != 0).float().mean(0) - feature_probability).abs().mean() < 0.002
), "Each anticorrelated feature should be active with probability `feature_probability`"

In [None]:
# Generate a batch of 4 features: first 2 are correlated, second 2 are anticorrelated
# NOTE: reuses the `model` built in the previous cell, so that cell must be run first.
batch = model.generate_batch(batch_size=1)
# Split the feature dimension into chunks of 2: (correlated pair, anticorrelated pair).
correlated_feature_batch, anticorrelated_feature_batch = batch.split(2, dim=-1)

# Plot correlated features
utils.plot_correlated_features(
    correlated_feature_batch, title="Correlated feature pairs: should always co-occur"
)
# Plot anticorrelated features
utils.plot_correlated_features(
    anticorrelated_feature_batch, title="Anti-correlated feature pairs: should never co-occur"
)

In [None]:
# NOTE(review): both pair counts are 0 here, so every feature is generated as uncorrelated, yet
# the plot title and the two-colour grouping below describe correlated feature sets. Confirm
# whether this is a deliberate uncorrelated baseline or whether n_correlated_pairs should be 2.
cfg = Config(n_inst=5, n_features=4, d_hidden=2, n_correlated_pairs=0,n_anticorrelated_pairs=0)

# All same importance, very low feature probabilities (ranging from 5% down to 0.25%)
importance = t.ones(cfg.n_features, dtype=t.float, device=device)
feature_probability = 400 ** -t.linspace(0.5, 1, cfg.n_inst)

# `importance[None, :]` is (1, n_features) and `feature_probability[:, None]` is (n_inst, 1);
# the model broadcasts both to (n_inst, n_features).
model = Model(
    cfg=cfg,
    device=device,
    importance=importance[None, :],
    feature_probability=feature_probability[:, None],
)
model.optimize(steps=10_000)

# One subplot per instance, titled with that instance's feature probability.
utils.plot_features_in_2d(
    model.W,
    colors=["blue"] * 2 + ["limegreen"] * 2,
    title="Correlated feature sets are represented in local orthogonal bases",
    subplot_titles=[f"1 - S = {i:.3f}" for i in feature_probability],
)