In [3]:
import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Literal
import einops
import numpy as np
import torch as t
from jaxtyping import Float
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from tqdm.auto import tqdm

# Make sure exercises are in the path: locate the `<chapter>/exercises` directory relative to the
# current working directory and add it to `sys.path` so the local packages below can be imported.
chapter = r"chapter1_transformer_interp"
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = exercises_dir / "part31_superposition_and_saes"
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

# These imports rely on `exercises_dir` having been added to `sys.path` above.
import part31_superposition_and_saes.utils as utils
import part31_superposition_and_saes.tests as tests
from plotly_utils import line, imshow

# Device preference order: Apple MPS, then CUDA, then CPU.
device = t.device(
    "mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu"
)
# Seed torch's RNG so weight initialisation and generated batches are repeatable.
t.manual_seed(2)
# True when this code runs as the top-level module (as opposed to being imported).
MAIN = __name__ == "__main__"

In [4]:
def linear_lr(step, steps):
    """
    Linearly decaying learning-rate multiplier.

    Args:
        step: index of the current optimisation step.
        steps: total number of optimisation steps.

    Returns:
        `1 - step / steps`, i.e. 1 at step 0, falling linearly towards 0.
    """
    fraction_done = step / steps
    return 1 - fraction_done

def constant_lr(*_ignored):
    """
    Constant learning-rate multiplier.

    Accepts (and ignores) any positional arguments so it can be called with the same
    `(step, steps)` arguments as the other schedules. Always returns 1.0.
    """
    multiplier = 1.0
    return multiplier

def cosine_decay_lr(step, steps):
    """
    Cosine learning-rate multiplier: 1 at step 0, decaying to (numerically) 0 at the last step.

    Args:
        step: index of the current optimisation step (0-based).
        steps: total number of optimisation steps.

    Returns:
        `cos(pi/2 * step / (steps - 1))`, or 1.0 when `steps <= 1`.
    """
    # With at most one step there is nothing to decay over, and the schedule's denominator
    # `steps - 1` would be zero (ZeroDivisionError), so keep the base learning rate.
    if steps <= 1:
        return 1.0
    return np.cos(0.5 * np.pi * step / (steps - 1))


@dataclass
class Config:
    """
    Shape and data-generation settings for the toy models defined in this notebook.

    `n_inst` has no default and must always be supplied; every other field has a default.
    """
    # We optimize n_inst models in a single training loop to let us sweep over sparsity or importance
    # curves efficiently. You should treat the number of instances `n_inst` like a batch dimension,
    # but one which is built into our training setup. Ignore the last 3 fields for now; they are not
    # read by the model code in this section and return in later exercises.

    # Number of independent model instances trained in parallel.
    n_inst: int
    # Number of input/output features per instance.
    n_features: int = 5
    # Width of the hidden (bottleneck) layer per instance.
    d_hidden: int = 2
    # The three fields below are not used by the code visible in this section.
    n_correlated_pairs: int = 0
    n_anticorrelated_pairs: int = 0
    feat_mag_distn: Literal["unif", "jump"] = "unif"


class Model(nn.Module):
    """
    Toy model of superposition, holding `cfg.n_inst` independent instances that are trained
    together in one loop.

    For a single instance the map is x -> ReLU(W.T @ W @ x + b_final): the features are projected
    into `d_hidden` dimensions by `W` and reconstructed with the same weights.

    Attributes:
        cfg: the `Config` the model was built from.
        feature_probability: tensor broadcast to (n_inst, n_features); the probability that each
            feature is nonzero in a generated batch.
        importance: tensor broadcast to (n_inst, n_features); the per-feature weight applied to
            the squared error in the loss.
    """

    W: Float[Tensor, "inst d_hidden feats"]
    b_final: Float[Tensor, "inst feats"]

    def __init__(
        self,
        cfg: Config,
        feature_probability: float | Tensor = 0.01,
        importance: float | Tensor = 1.0,
        device=device,
    ):
        """
        Args:
            cfg: model dimensions (`n_inst`, `n_features`, `d_hidden`).
            feature_probability: scalar or tensor broadcastable to (n_inst, n_features).
            importance: scalar or tensor broadcastable to (n_inst, n_features).
            device: device on which the parameters and the two tensors above are placed.
        """
        super(Model, self).__init__()
        self.cfg = cfg

        inst_feat_shape = (cfg.n_inst, cfg.n_features)

        # Convert every non-tensor value (float, int, ...) to a floating-point tensor. Checking
        # only for `float` would let an int such as `importance=1` through and then fail on
        # `.to(device)`.
        if not isinstance(feature_probability, Tensor):
            feature_probability = t.tensor(feature_probability, dtype=t.get_default_dtype())
        self.feature_probability = feature_probability.to(device).broadcast_to(inst_feat_shape)

        if not isinstance(importance, Tensor):
            importance = t.tensor(importance, dtype=t.get_default_dtype())
        self.importance = importance.to(device).broadcast_to(inst_feat_shape)

        self.W = nn.Parameter(
            nn.init.xavier_normal_(t.empty((cfg.n_inst, cfg.d_hidden, cfg.n_features)))
        )
        self.b_final = nn.Parameter(t.zeros((cfg.n_inst, cfg.n_features)))
        self.to(device)

    def forward(
        self,
        features: Float[Tensor, "... inst feats"],
    ) -> Float[Tensor, "... inst feats"]:
        """
        Computes ReLU(W.T @ W @ x + b_final) separately for every instance.

        Args:
            features: input of shape (..., inst, feats).

        Returns:
            The reconstruction, with the same shape as `features`.
        """
        # Project the features down into the hidden space with W ...
        hidden = einops.einsum(
            self.W, features, "inst d_hidden feats, ... inst feats -> ... inst d_hidden"
        )
        # ... and map back with the same weights; contracting over `d_hidden` applies W.T
        # without needing an explicit transpose.
        recon = einops.einsum(
            self.W, hidden, "inst d_hidden feats, ... inst d_hidden -> ... inst feats"
        )
        return F.relu(recon + self.b_final)

    def generate_batch(self, batch_size) -> Float[Tensor, "batch inst feats"]:
        """
        Generates a batch of data.

        Each entry is drawn uniformly from [0, 1) and kept with probability
        `self.feature_probability` (per instance and feature); otherwise it is set to 0.

        Args:
            batch_size: number of samples to draw for every instance.

        Returns:
            Tensor of shape (batch_size, n_inst, n_features) on the same device as `self.W`.
        """
        data_size = (batch_size, self.cfg.n_inst, self.cfg.n_features)
        # Draw the values first and the presence mask second, so the RNG is consumed in a
        # fixed order.
        values = t.rand(data_size, device=self.W.device)
        presence = t.rand(data_size, device=self.W.device)

        return t.where(presence < self.feature_probability, values, 0.0)

    def calculate_loss(
        self,
        out: Float[Tensor, "batch inst feats"],
        batch: Float[Tensor, "batch inst feats"],
    ) -> Float[Tensor, ""]:
        """
        Calculates the loss for a given batch (as a scalar tensor), using this loss described in the
        Toy Models of Superposition paper:

            https://transformer-circuits.pub/2022/toy_model/index.html#demonstrating-setup-loss

        The importance-weighted squared error is averaged over the batch and feature dimensions
        and then summed over instances.

        Note, `self.importance` is guaranteed to broadcast with the shape of `out` and `batch`.
        """
        weighted_sq_error = self.importance * ((out - batch) ** 2)
        per_instance = einops.reduce(weighted_sq_error, "batch inst feats -> inst", "mean")
        return per_instance.sum()

    def optimize(
        self,
        batch_size: int = 1024,
        steps: int = 10_000,
        log_freq: int = 50,
        lr: float = 1e-3,
        lr_scale: Callable[[int, int], float] = constant_lr,
    ):
        """
        Optimizes the model using the given hyperparameters.

        Args:
            batch_size: number of samples generated per step.
            steps: number of optimisation steps.
            log_freq: the progress bar's loss/lr readout is refreshed every `log_freq` steps
                (and on the final step).
            lr: base learning rate for Adam.
            lr_scale: schedule called as `lr_scale(step, steps)`; its result multiplies `lr`.
        """
        optimizer = t.optim.Adam(list(self.parameters()), lr=lr)

        progress_bar = tqdm(range(steps))

        for step in progress_bar:
            # Update learning rate
            step_lr = lr * lr_scale(step, steps)
            for group in optimizer.param_groups:
                group["lr"] = step_lr

            # Optimize on a freshly generated batch
            optimizer.zero_grad()
            batch = self.generate_batch(batch_size)
            out = self(batch)
            loss = self.calculate_loss(out, batch)
            loss.backward()
            optimizer.step()

            # Display progress bar; the loss is summed over instances, so report the per-instance mean
            if step % log_freq == 0 or (step + 1 == steps):
                progress_bar.set_postfix(loss=loss.item() / self.cfg.n_inst, lr=step_lr)

In [None]:
class NeuronModel(Model):
    """
    Variant of `Model` with a ReLU applied to the hidden layer as well as the output:
    x -> ReLU(W.T @ ReLU(W @ x) + b_final) for each instance.
    """

    def forward(
        self,
        features: Float[Tensor, "... instances features"]
    ) -> Float[Tensor, "... instances features"]:
        """
        Args:
            features: input of shape (..., instances, features).

        Returns:
            The reconstruction, with the same shape as `features`.
        """
        hidden = einops.einsum(
            self.W, features, "inst d_hidden feats, ... inst feats -> ... inst d_hidden"
        )
        # The hidden-layer nonlinearity is the only difference from `Model.forward`.
        hidden = F.relu(hidden)
        # Contracting over `d_hidden` applies W.T without needing an explicit transpose.
        recon = einops.einsum(
            self.W, hidden, "inst d_hidden feats, ... inst d_hidden -> ... inst feats"
        )
        return F.relu(recon + self.b_final)


# Run the check for `NeuronModel` provided by the `part31_superposition_and_saes.tests` module.
tests.test_neuron_model(NeuronModel)

In [None]:
# Train 7 instances of the neuron model, one per feature probability, and plot the learned weights.
cfg = Config(n_inst=7, n_features=10, d_hidden=5)

# Importance decays geometrically across features: 0.75^i for i = 1..n_features.
importance = 0.75 ** t.arange(1, 1 + cfg.n_features)
# One feature probability per instance (7 values, matching n_inst=7).
feature_probability = t.tensor([0.75, 0.35, 0.15, 0.1, 0.06, 0.02, 0.01])

model = NeuronModel(
    cfg=cfg,
    device=device,
    # Shape (1, n_features): broadcast over instances inside the constructor.
    importance=importance[None, :],
    # Shape (n_inst, 1): broadcast over features inside the constructor.
    feature_probability=feature_probability[:, None],
)
model.optimize(steps=10_000)

utils.plot_features_in_Nd(
    model.W,
    height=600,
    width=1000,
    # One subplot title per instance, showing that instance's feature probability.
    subplot_titles=[f"1 - S = {i:.2f}" for i in feature_probability.squeeze()],
    title=f"Neuron model: {cfg.n_features=}, {cfg.d_hidden=}, I<sub>i</sub> = 0.75<sup>i</sup>",
    neuron_plot=True,
)

In [None]:
class NeuronComputationModel(Model):
    """
    Toy model with separate input and output weights, trained to compute the absolute value of
    its input: x -> ReLU(W2 @ ReLU(W1 @ x) + b_final) for each instance, with target |x|.

    Training reuses `Model.optimize`; the data generation and loss are overridden here.
    """

    W1: Float[Tensor, "inst d_hidden feats"]
    W2: Float[Tensor, "inst feats d_hidden"]
    b_final: Float[Tensor, "inst feats"]

    def __init__(
        self,
        cfg: Config,
        feature_probability: float | Tensor = 1.0,
        importance: float | Tensor = 1.0,
        device=device,
    ):
        """
        Args:
            cfg: model dimensions (`n_inst`, `n_features`, `d_hidden`).
            feature_probability: scalar or tensor broadcastable to (n_inst, n_features).
            importance: scalar or tensor broadcastable to (n_inst, n_features).
            device: device on which the parameters and the two tensors above are placed.
        """
        # Deliberately `super(Model, self)`: this skips `Model.__init__` and runs the next
        # initialiser in the MRO (`nn.Module`), so the tied weight `W` is never created.
        super(Model, self).__init__()
        self.cfg = cfg

        inst_feat_shape = (cfg.n_inst, cfg.n_features)

        # Convert every non-tensor value (float, int, ...) to a floating-point tensor. Checking
        # only for `float` would let an int such as `importance=1` through and then fail on
        # `.to(device)`.
        if not isinstance(feature_probability, Tensor):
            feature_probability = t.tensor(feature_probability, dtype=t.get_default_dtype())
        self.feature_probability = feature_probability.to(device).broadcast_to(inst_feat_shape)

        if not isinstance(importance, Tensor):
            importance = t.tensor(importance, dtype=t.get_default_dtype())
        self.importance = importance.to(device).broadcast_to(inst_feat_shape)

        self.W1 = nn.Parameter(
            nn.init.kaiming_uniform_(t.empty((cfg.n_inst, cfg.d_hidden, cfg.n_features)))
        )
        self.W2 = nn.Parameter(
            nn.init.kaiming_uniform_(t.empty((cfg.n_inst, cfg.n_features, cfg.d_hidden)))
        )
        self.b_final = nn.Parameter(t.zeros((cfg.n_inst, cfg.n_features)))
        self.to(device)

    def forward(self, features: Float[Tensor, "... inst feats"]) -> Float[Tensor, "... inst feats"]:
        """
        Computes ReLU(W2 @ ReLU(W1 @ x) + b_final) separately for every instance.

        Args:
            features: input of shape (..., inst, feats).

        Returns:
            Tensor with the same shape as `features`.
        """
        hidden = F.relu(
            einops.einsum(
                self.W1, features, "inst d_hidden feats, ... inst feats -> ... inst d_hidden"
            )
        )
        pre_activation = einops.einsum(
            self.W2, hidden, "inst feats d_hidden, ... inst d_hidden -> ... inst feats"
        )
        return F.relu(pre_activation + self.b_final)

    def generate_batch(self, batch_size) -> Tensor:
        """
        Generates a batch of signed data.

        Each entry is drawn uniformly from [-1, 1) and kept with probability
        `self.feature_probability` (per instance and feature); otherwise it is set to 0.

        Args:
            batch_size: number of samples to draw for every instance.

        Returns:
            Tensor of shape (batch_size, n_inst, n_features) on the same device as `self.W1`.
        """
        data_size = (batch_size, self.cfg.n_inst, self.cfg.n_features)
        # Draw the values first and the presence mask second, so the RNG is consumed in a
        # fixed order.
        values = t.rand(data_size, device=self.W1.device) * 2 - 1
        presence = t.rand(data_size, device=self.W1.device)

        return t.where(presence < self.feature_probability, values, 0.0)

    def calculate_loss(
        self,
        out: Float[Tensor, "batch instances features"],
        batch: Float[Tensor, "batch instances features"],
    ) -> Float[Tensor, ""]:
        """
        Importance-weighted squared error between `out` and the target `|batch|`, averaged over
        the batch and feature dimensions and summed over instances (returned as a scalar tensor).
        """
        weighted_sq_error = self.importance * ((out - batch.abs()) ** 2)
        per_instance = einops.reduce(weighted_sq_error, "batch inst feats -> inst", "mean")
        return per_instance.sum()


# Run the check for `NeuronComputationModel` provided by the `part31_superposition_and_saes.tests` module.
tests.test_neuron_computation_model(NeuronComputationModel)

In [None]:
# Train 7 instances of the neuron computation model, one per feature probability, and plot the
# learned input weights W1.
cfg = Config(n_inst=7, n_features=100, d_hidden=40)

# Importance decays geometrically across features: 0.8^i for i = 1..n_features.
importance = 0.8 ** t.arange(1, 1 + cfg.n_features)
# One feature probability per instance (7 values, matching n_inst=7).
feature_probability = t.tensor([1.0, 0.3, 0.1, 0.03, 0.01, 0.003, 0.001])

model = NeuronComputationModel(
    cfg=cfg,
    device=device,
    # Shape (1, n_features): broadcast over instances inside the constructor.
    importance=importance[None, :],
    # Shape (n_inst, 1): broadcast over features inside the constructor.
    feature_probability=feature_probability[:, None],
)
model.optimize(steps=10_000)

utils.plot_features_in_Nd(
    model.W1,
    height=800,
    width=1400,
    neuron_plot=True,
    subplot_titles=[f"1 - S = {i:.3f}<br>" for i in feature_probability.squeeze()],
    # The title states the importance curve actually used above (0.8^i) and uses a well-formed
    # closing </sub> tag.
    title=f"Neuron computation model: {cfg.n_features=}, {cfg.d_hidden=}, I<sub>i</sub> = 0.8<sup>i</sup>",
)

In [None]:
# Train 5 instances of the neuron computation model with as many hidden neurons as features, then
# plot W1 and W2 coloured by feature.
cfg = Config(n_inst=5, n_features=10, d_hidden=10)

# Importance decays geometrically across features: 0.8^i for i = 1..n_features.
importance = 0.8 ** t.arange(1, 1 + cfg.n_features)
# A single scalar probability, shared by every instance and feature.
feature_probability = 0.5

model = NeuronComputationModel(
    cfg=cfg,
    device=device,
    # Shape (1, n_features): broadcast over instances inside the constructor.
    importance=importance[None, :],
    feature_probability=feature_probability,
)
model.optimize(steps=10_000)

utils.plot_features_in_Nd_discrete(
    W1=model.W1,
    W2=model.W2,
    title="Neuron computation model (colored discretely, by feature)",
    # One legend entry per feature, labelled with that feature's importance.
    legend_names=[
        f"I<sub>{i}</sub> = {importance.squeeze()[i]:.3f}" for i in range(cfg.n_features)
    ],
)