## Introduction

This notebook provides reference implementations for [JumpReLU SAEs](https://arxiv.org/abs/2407.14435) in JAX and PyTorch, expanding on the pseudo-code provided in the paper. Specifically, we include:

* Implementations of the `jumprelu` and `step` functions with custom backward passes;
* Implementations of the SAE forward pass and L0-based loss function;
* Training loop implementations that optionally constrain the rows of the decoder matrix to unit norm.

We don't implement some features used in the training setup described in the paper (e.g. learning rate and sparsity coefficient $\lambda$ warm-up), which are reasonably easy to add if desired.
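As an illustration of how such warm-ups could be bolted on, here is a minimal sketch assuming simple linear ramps from zero. The helper names and the warm-up lengths (`LR_WARMUP_STEPS`, `LAMBDA_WARMUP_STEPS`) are hypothetical placeholders, not the schedule or values used in the paper.

```python
def linear_warmup(step: int, warmup_steps: int, final_value: float) -> float:
    """Linearly ramps a value from 0 to `final_value` over `warmup_steps` steps.

    Hypothetical helper: the linear shape is an assumption, not the paper's
    exact schedule.
    """
    if warmup_steps <= 0:
        return final_value
    return final_value * min(step / warmup_steps, 1.0)


# Hypothetical warm-up lengths (placeholders, not the paper's settings).
LR_WARMUP_STEPS = 1_000
LAMBDA_WARMUP_STEPS = 10_000


def hyperparams_at(step, learning_rate, sparsity_coefficient):
    """Returns the (learning rate, sparsity coefficient) to use at `step`."""
    lr = linear_warmup(step, LR_WARMUP_STEPS, learning_rate)
    lam = linear_warmup(step, LAMBDA_WARMUP_STEPS, sparsity_coefficient)
    return lr, lam
```

In the JAX training loop, `lam` could be passed as `sparsity_coefficient` to `update` at every step; it is not one of the `static_argnums`, so changing its value should not force recompilation. For the learning rate, optax accepts a schedule directly, e.g. `optax.adam(optax.linear_schedule(init_value=0.0, end_value=LEARNING_RATE, transition_steps=LR_WARMUP_STEPS), b1=ADAM_B1)`.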

The notebook also provides comprehensive tests to check that the JAX and PyTorch implementations are consistent. This includes an end-to-end training test where we train SAEs (on synthetic data, using identical initialisations) with both the JAX and PyTorch implementations and check that we get similar parameters after three steps. You may find these tests useful for checking other implementations of JumpReLU SAEs for consistency against these reference implementations.

You should be able to run this notebook on a CPU runtime.

## Setup

This section imports the modules needed by the rest of the notebook, sets some constants (mainly hyperparameters), generates synthetic data, and initialises SAE parameters that we use later for testing.

In [None]:
# @title Imports
import dataclasses
import functools
import itertools

import chex
import numpy as np
import jax
import jax.numpy as jnp
import optax
import plotly.express as px
import torch
from torch import nn

jax.config.update("jax_enable_x64", True)

In [None]:
# @title Hyperparameters and constants

NUM_STEPS = 3
BATCH_SIZE = 1024
ACTIVATIONS_SIZE = 16
SAE_WIDTH = 128
THRESHOLD_INIT = 0.001
# We use a higher bandwidth than in the paper to ensure a non-zero gradient
# to the threshold at every step (since we'll only be taking three steps)
BANDWIDTH = 0.1
FIX_DECODER_NORMS = True
LEARNING_RATE = 0.001  # Note this is not the learning rate in the paper
ADAM_B1 = 0.0
DATA_SEED = 9328302
PARAMS_SEED = 24396

In [None]:
# @title Create some synthetic data for testing

rng = np.random.default_rng(DATA_SEED)
dataset = rng.normal(
    size=(NUM_STEPS, BATCH_SIZE, ACTIVATIONS_SIZE)
) / np.sqrt(ACTIVATIONS_SIZE)

In [None]:
# @title Choose random SAE weights for testing

# We choose an initialisation that is useful for testing. Specifically, this
# means we initialise the biases and threshold to non-zero values, and we don't
# set the encoder weights to the transpose of the decoder (since in general
# they won't be transposes of each other during training).
rng = np.random.default_rng(PARAMS_SEED)
W_dec = (rng.uniform(size=(SAE_WIDTH, ACTIVATIONS_SIZE)) - 0.5)
W_dec /= np.linalg.norm(W_dec, axis=-1, keepdims=True)
W_enc = (rng.uniform(size=(ACTIVATIONS_SIZE, SAE_WIDTH)) - 0.5)
b_enc = (rng.uniform(size=(SAE_WIDTH,)) - 0.5) * 0.1
b_dec = (rng.uniform(size=(ACTIVATIONS_SIZE,)) - 0.5) * 0.1
threshold = 0.15 * (rng.uniform(size=(SAE_WIDTH,))) * 0.1

## JAX implementation

Although this is a reference implementation, it can easily be made performant (even for wide SAEs on multiple devices) by applying Megatron sharding to the parameters before training (and sharding batches along the data axis). As explained in the [Gemma Scope](https://storage.googleapis.com/gemma-scope/gemma-scope-report.pdf) report, the main challenge in training SAEs efficiently is designing a dataloader with enough throughput that it does not become the bottleneck (which is out of scope for this notebook).
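A minimal sketch of what "Megatron sharding" could look like for these parameters, assuming a 2-D device mesh with a `"data"` axis for batches and a `"model"` axis for the SAE width. The mesh layout, the `PARAM_SPECS` table and the `shard_params`/`shard_batch` helpers are illustrative assumptions, not part of the reference implementation. The idea is to split the width (dictionary) dimension: `W_enc` column-wise, `W_dec` row-wise, and `b_enc`/`log_threshold` alongside them, while replicating `b_dec`.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed mesh: one "data" row, with every available device on the "model" axis.
mesh = Mesh(np.array(jax.devices()).reshape(1, -1), axis_names=("data", "model"))

# Megatron-style layout: shard the SAE width across the "model" axis.
PARAM_SPECS = {
    "W_enc": P(None, "model"),    # (activations_size, sae_width)
    "b_enc": P("model"),          # (sae_width,)
    "W_dec": P("model", None),    # (sae_width, activations_size)
    "b_dec": P(),                 # (activations_size,): fully replicated
    "log_threshold": P("model"),  # (sae_width,)
}
BATCH_SPEC = P("data", None)  # (batch, activations_size): split along batch


def shard_params(params_dict):
    """Places each named parameter on the mesh according to PARAM_SPECS."""
    return {
        name: jax.device_put(value, NamedSharding(mesh, PARAM_SPECS[name]))
        for name, value in params_dict.items()
    }


def shard_batch(batch):
    """Places a batch of activations on the mesh, split along the data axis."""
    return jax.device_put(batch, NamedSharding(mesh, BATCH_SPEC))
```

Under these assumptions, one could build sharded initial parameters with `Params(**shard_params(dataclasses.asdict(params_init)))` and call `shard_batch(batch)` before each `update`; `jax.jit` would then propagate the input shardings through the computation. The SAE width must be divisible by the number of devices on the `"model"` axis.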

In [None]:
# @title Defining the `Params` dataclass

@chex.dataclass
class Params:
    W_enc: jax.Array
    b_enc: jax.Array
    W_dec: jax.Array
    b_dec: jax.Array
    log_threshold: jax.Array

params_init = Params(
    W_enc=W_enc,
    b_enc=b_enc,
    W_dec=W_dec,
    b_dec=b_dec,
    log_threshold=np.log(threshold),
)

In [None]:
# @title STEs, forward pass and loss function (taken from the paper)

def rectangle(x):
  return ((x > -0.5) & (x < 0.5)).astype(x.dtype)


### Implementation of step function with custom backward

@jax.custom_vjp
def step(x, threshold):
  return (x > threshold).astype(x.dtype)


def step_fwd(x, threshold):
  out = step(x, threshold)
  cache = x, threshold  # Saved for use in the backward pass
  return out, cache


def step_bwd(cache, output_grad):
  x, threshold = cache
  x_grad = 0.0 * output_grad  # We don't apply STE to x input
  threshold_grad = jnp.sum(
      -(1.0 / BANDWIDTH) * rectangle((x - threshold) / BANDWIDTH) * output_grad,
      axis=0,
  )
  return x_grad, threshold_grad


step.defvjp(step_fwd, step_bwd)


### Implementation of JumpReLU with custom backward for threshold

@jax.custom_vjp
def jumprelu(x, threshold):
  return x * (x > threshold)


def jumprelu_fwd(x, threshold):
  out = jumprelu(x, threshold)
  cache = x, threshold  # Saved for use in the backward pass
  return out, cache


def jumprelu_bwd(cache, output_grad):
  x, threshold = cache
  x_grad = (x > threshold) * output_grad  # We don't apply STE to x input
  threshold_grad = jnp.sum(
      -(threshold / BANDWIDTH)
      * rectangle((x - threshold) / BANDWIDTH)
      * output_grad,
      axis=0,
  )
  return x_grad, threshold_grad


jumprelu.defvjp(jumprelu_fwd, jumprelu_bwd)


### Implementation of JumpReLU SAE forward pass and loss functions

def sae(params, x, use_pre_enc_bias):
  # Optionally, apply pre-encoder bias
  if use_pre_enc_bias:
    x = x - params.b_dec

  pre_activations = x @ params.W_enc + params.b_enc
  threshold = jnp.exp(params.log_threshold)
  feature_magnitudes = jumprelu(pre_activations, threshold)

  # Decoder
  x_reconstructed = feature_magnitudes @ params.W_dec + params.b_dec

  # Also return pre_activations, needed to compute sparsity loss
  return x_reconstructed, pre_activations


### Implementation of JumpReLU loss

def loss(params, x, sparsity_coefficient, use_pre_enc_bias):
  x_reconstructed, pre_activations = sae(params, x, use_pre_enc_bias)

  # Compute per-example reconstruction loss
  reconstruction_error = x - x_reconstructed
  reconstruction_loss = jnp.sum(reconstruction_error**2, axis=-1)

  # Compute per-example sparsity loss
  threshold = jnp.exp(params.log_threshold)
  l0 = jnp.sum(step(pre_activations, threshold), axis=-1)
  sparsity_loss = sparsity_coefficient * l0

  # Return the batch-wise mean total loss
  return jnp.mean(reconstruction_loss + sparsity_loss, axis=0)
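
Why is the rectangle-kernel pseudo-derivative in `step_bwd` a sensible gradient for the L0 term? Averaged over a batch, $-\frac{1}{\varepsilon}\cdot\frac{1}{N}\sum_i \mathrm{rect}\big((x_i-\theta)/\varepsilon\big)$ is exactly $-\hat p(\theta)$, where $\hat p$ is a rectangle-kernel density estimate of the pre-activation density at the threshold. Since $\frac{d}{d\theta}\mathbb{E}[H(x-\theta)] = -p(\theta)$, it estimates the gradient of the expected L0. Similarly, the jump in JumpReLU has height $\theta$, so $\frac{d}{d\theta}\mathbb{E}[x\,H(x-\theta)] = -\theta\,p(\theta)$, which matches the extra `threshold` factor in `jumprelu_bwd`. The NumPy sketch below checks this numerically, assuming standard-normal pre-activations as a stand-in for real data.

```python
import math
import numpy as np


def rectangle_np(x):
    # Same kernel as `rectangle` above: 1 on (-0.5, 0.5), 0 elsewhere.
    return ((x > -0.5) & (x < 0.5)).astype(x.dtype)


def step_threshold_pseudo_grad(x, threshold, bandwidth):
    # Mirrors step_bwd with output_grad = 1, averaged (not summed) over the batch.
    return np.mean(-(1.0 / bandwidth) * rectangle_np((x - threshold) / bandwidth))


def jumprelu_threshold_pseudo_grad(x, threshold, bandwidth):
    # Mirrors jumprelu_bwd with output_grad = 1, averaged over the batch.
    return np.mean(
        -(threshold / bandwidth) * rectangle_np((x - threshold) / bandwidth)
    )


def normal_pdf(t):
    return math.exp(-0.5 * t * t) / math.sqrt(2.0 * math.pi)


rng = np.random.default_rng(0)
x = rng.standard_normal(1_000_000)  # assumed stand-in for pre-activations
theta, bandwidth = 0.5, 0.1

step_estimate = step_threshold_pseudo_grad(x, theta, bandwidth)
step_exact = -normal_pdf(theta)  # d/dθ P(x > θ) for x ~ N(0, 1)
jump_estimate = jumprelu_threshold_pseudo_grad(x, theta, bandwidth)
jump_exact = -theta * normal_pdf(theta)  # d/dθ E[x H(x - θ)]
print(f"step: {step_estimate:.4f} vs {step_exact:.4f}")
print(f"jumprelu: {jump_estimate:.4f} vs {jump_exact:.4f}")
```

The bandwidth trades bias against variance: with few pre-activations near the threshold, a narrow kernel often sees no samples at all and returns a zero gradient. This is consistent with the comment above `BANDWIDTH`, which explains that this notebook uses a higher bandwidth than the paper for its three-step tests.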

In [None]:
# @title Training loop

optimizer = optax.adam(LEARNING_RATE, b1=ADAM_B1)

def remove_parallel_component(x, v):
    """Returns x with component parallel to v projected away."""
    v_normalised = v / (jnp.linalg.norm(v, axis=-1, keepdims=True) + 1e-6)
    parallel_component = jnp.einsum("...d,...d->...", x, v_normalised)
    return x - parallel_component[..., None] * v_normalised

@functools.partial(jax.jit, donate_argnums=(0, 1), static_argnums=(4, 5))
def update(
    params,
    opt_state,
    x,
    sparsity_coefficient,
    use_pre_enc_bias,
    fix_decoder_norms,
):
    # This is a standard JAX training loop, except that we optionally
    # constrain the rows of W_dec to unit norm.
    grads = jax.grad(loss)(params, x, sparsity_coefficient, use_pre_enc_bias)
    if fix_decoder_norms:
        grads.W_dec = remove_parallel_component(grads.W_dec, params.W_dec)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    if fix_decoder_norms:
        W_dec_norms = jnp.linalg.norm(params.W_dec, axis=-1, keepdims=True)
        params.W_dec = params.W_dec / W_dec_norms
    return params, opt_state

def train_jax(
    dataset_iterator,
    sparsity_coefficient,
    use_pre_enc_bias,
    fix_decoder_norms,
):
    params = params_init
    opt_state = optimizer.init(params)
    for batch in dataset_iterator:
        params, opt_state = update(
            params,
            opt_state,
            batch,
            sparsity_coefficient,
            use_pre_enc_bias,
            fix_decoder_norms,
        )
    return params
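
When `fix_decoder_norms` is set, `remove_parallel_component` projects each row of the `W_dec` gradient onto the subspace orthogonal to the corresponding decoder row (the tangent space of the unit sphere at that row), and the rows are renormalised after the optimiser step. The NumPy mirror below (a sketch using random stand-in data) checks that the projected gradient is orthogonal to each row, up to the `1e-6` stabiliser, and that a small plain-gradient step along it changes the row norms only to second order.

```python
import numpy as np


def remove_parallel_component_np(x, v):
    """NumPy mirror of `remove_parallel_component`: removes from each row of x
    its component along the matching row of v."""
    v_normalised = v / (np.linalg.norm(v, axis=-1, keepdims=True) + 1e-6)
    parallel_component = np.einsum("...d,...d->...", x, v_normalised)
    return x - parallel_component[..., None] * v_normalised


rng = np.random.default_rng(0)
W = rng.normal(size=(8, 4))
W /= np.linalg.norm(W, axis=-1, keepdims=True)  # unit-norm decoder rows
G = rng.normal(size=(8, 4))                      # stand-in gradient

G_tangent = remove_parallel_component_np(G, W)
row_dots = np.sum(G_tangent * W, axis=-1)  # ~0, up to the 1e-6 stabiliser

# A small step along the tangent direction leaves the norms ~1 (second order),
# so the renormalisation that follows is only a small correction.
lr = 1e-3
W_new = W - lr * G_tangent
norm_after_step = np.linalg.norm(W_new, axis=-1)
W_new /= norm_after_step[:, None]
```

With Adam, the per-coordinate rescaling of the update generally does not preserve this orthogonality, so the explicit renormalisation after `optimizer.update` is still needed to keep the rows exactly unit norm.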

### Testing the JAX implementation

In this section we check various aspects of this JAX implementation against saved outputs from a previous run (to detect any regressions). If the cell below runs without errors, the outputs of the implementation above match a previous golden run.
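
The regression tests compare only a strided sample of each array, `x.ravel()[::stride]`, against hard-coded values. For example, the synthetic dataset has `NUM_STEPS * BATCH_SIZE * ACTIVATIONS_SIZE = 3 * 1024 * 16 = 49152` entries, so a stride of 13001 keeps exactly four of them. If you want comparable golden values for another implementation, a hypothetical helper like `golden_sample` below applies the same sampling.

```python
import numpy as np


def golden_sample(x, stride):
    """Flattens x and keeps every `stride`-th entry, as the regression tests do."""
    return np.asarray(x).ravel()[::stride]


indices = golden_sample(np.arange(3 * 1024 * 16), 13001)
print(indices.tolist())  # [0, 13001, 26002, 39003]
```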

In [None]:
# @title Regression tests

# First we do a sanity check to make sure the dataset hasn't changed (e.g.
# because of changed random seeds)
np.testing.assert_allclose(
    dataset.ravel()[::13001],
    np.array([-0.26036463,  0.07950339, -0.10517059,  0.11135695]),
)

# Then let's check that the SAE's reconstruction and pre-activation outputs
# match previously calculated values (on a small sample of entries)
sae_initial_reconstruction, sae_initial_pre_activations = sae(
    params_init, dataset[0], use_pre_enc_bias=True
)
np.testing.assert_allclose(
    sae_initial_reconstruction.ravel()[::5001],
    np.array([-0.34825447, -0.12089294, -0.20722952,  1.4071178]),
    atol=1e-5,
)
np.testing.assert_allclose(
    sae_initial_pre_activations.ravel()[::40001],
    np.array([ 0.6627157 , -0.2485388 ,  0.52095324, -0.01966003]),
    atol=1e-5,
)

# We now check the loss and its gradients match previously calculated values
loss_init, loss_grad = jax.value_and_grad(loss)(
    params_init,
    dataset[0],
    sparsity_coefficient=1.2,
    use_pre_enc_bias=True,
)
np.testing.assert_allclose(loss_init, 81.98928, atol=1e-5)
np.testing.assert_allclose(
    loss_grad.b_dec[::7],
    np.array([-1.7614671 , -2.9526186 ,  0.99276465]),
    atol=1e-5,
)
np.testing.assert_allclose(
    loss_grad.b_enc[::37],
    np.array([0.29240608, 0.9655949 , 0.20811056, 1.051967]),
    atol=1e-5,
)
np.testing.assert_allclose(
    loss_grad.log_threshold[::37],
    np.array([-0.01634991, -0.00306892, -0.00966095, -0.0034949]),
    atol=1e-5,
)
np.testing.assert_allclose(
    loss_grad.W_dec.ravel()[::601],
    np.array([0.02478151, -0.01372844,  0.00353589,  0.10577872]),
    atol=1e-5,
)
np.testing.assert_allclose(
    loss_grad.W_enc.ravel()[::601],
    np.array([0.02030803,  0.02437474, -0.01442294,  0.01866631]),
    atol=1e-5,
)

# Finally, we train for three steps (using fixed decoder norms) and check that
# the resulting params match a previous run
params_jax_trained: Params = train_jax(
    iter(dataset),
    sparsity_coefficient=1.2,
    use_pre_enc_bias=True,
    fix_decoder_norms=True,
)
np.testing.assert_allclose(
    params_jax_trained.b_dec[::7],
    np.array([0.04524906, -0.03720585, -0.03994906]),
    atol=1e-5,
)
np.testing.assert_allclose(
    params_jax_trained.b_enc[::37],
    np.array([0.01241532, -0.03432115, -0.03897703,  0.01650386]),
    atol=1e-5,
)
np.testing.assert_allclose(
    params_jax_trained.log_threshold[::37],
    np.array([-4.5912695, -6.2672825, -5.0748363, -6.345015]),
    atol=1e-5,
)
np.testing.assert_allclose(
    params_jax_trained.W_dec.ravel()[::601],
    np.array([-0.05179337, -0.08356037,  0.02901143, -0.36985442]),
    atol=1e-5,
)
np.testing.assert_allclose(
    params_jax_trained.W_enc.ravel()[::601],
    np.array([0.04418736, -0.1515689 ,  0.14993134, -0.12296465]),
    atol=1e-5,
)

## PyTorch implementation

In this section we translate the JAX implementation defined in the previous section into PyTorch. We'll then check carefully that the PyTorch implementation is consistent with the JAX one, the key test being that training over multiple steps with either implementation (using synthetic data and identical initialisation) yields the same parameters (up to numerical tolerance).

In [None]:
# @title STEs, forward pass and loss function

def rectangle_pt(x):
    return ((x > -0.5) & (x < 0.5)).to(x)


class Step(torch.autograd.Function):
    @staticmethod
    def forward(x, threshold):
        return (x > threshold).to(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, threshold = inputs
        del output
        ctx.save_for_backward(x, threshold)

    @staticmethod
    def backward(ctx, grad_output):
        x, threshold = ctx.saved_tensors
        x_grad = 0.0 * grad_output  # We don't apply STE to x input
        threshold_grad = torch.sum(
            -(1.0 / BANDWIDTH)
            * rectangle_pt((x - threshold) / BANDWIDTH)
            * grad_output,
            dim=0,
        )
        return x_grad, threshold_grad


class JumpReLU(torch.autograd.Function):
    @staticmethod
    def forward(x, threshold):
        return x * (x > threshold).to(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, threshold = inputs
        del output
        ctx.save_for_backward(x, threshold)

    @staticmethod
    def backward(ctx, grad_output):
        x, threshold = ctx.saved_tensors
        x_grad = (x > threshold) * grad_output  # We don't apply STE to x input
        threshold_grad = torch.sum(
            -(threshold / BANDWIDTH)
            * rectangle_pt((x - threshold) / BANDWIDTH)
            * grad_output,
            dim=0,
        )
        return x_grad, threshold_grad


class Sae(nn.Module):
    def __init__(self, sae_width, activations_size, use_pre_enc_bias):
        super().__init__()
        self.use_pre_enc_bias = use_pre_enc_bias
        # Initialise from the shared test parameters chosen in the Setup section
        # so that the PyTorch and JAX SAEs start from identical weights (hence
        # sae_width and activations_size are not used here).
        self.W_enc = nn.Parameter(torch.tensor(W_enc))
        self.b_enc = nn.Parameter(torch.tensor(b_enc))
        self.W_dec = nn.Parameter(torch.tensor(W_dec))
        self.b_dec = nn.Parameter(torch.tensor(b_dec))
        self.log_threshold = nn.Parameter(
            torch.tensor(np.log(threshold))
        )

    def forward(self, x):
        if self.use_pre_enc_bias:
            x = x - self.b_dec

        pre_activations = x @ self.W_enc + self.b_enc
        threshold = torch.exp(self.log_threshold)
        feature_magnitudes = JumpReLU.apply(pre_activations, threshold)
        x_reconstructed = feature_magnitudes @ self.W_dec + self.b_dec
        return x_reconstructed, pre_activations


def loss_fn_pt(sae, x, sparsity_coefficient, use_pre_enc_bias):
    x_reconstructed, pre_activations = sae(x)

    # Compute per-example reconstruction loss
    reconstruction_error = x - x_reconstructed
    reconstruction_loss = torch.sum(reconstruction_error**2, dim=-1)

    # Compute per-example sparsity loss
    threshold = torch.exp(sae.log_threshold)
    l0 = torch.sum(Step.apply(pre_activations, threshold), dim=-1)
    sparsity_loss = sparsity_coefficient * l0

    # Return the batch-wise mean total loss
    return torch.mean(reconstruction_loss + sparsity_loss, dim=0)
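
The `backward` methods above (like `step_bwd` and `jumprelu_bwd` in the JAX section) sum the threshold gradient over `dim=0`/`axis=0` only, which matches the 2-D `(batch, width)` pre-activations used throughout this notebook. If you feed activations with extra leading dimensions, e.g. `(batch, seq, width)`, the threshold gradient needs to be summed over every leading axis instead. The framework-independent NumPy sketch below does that, assuming `threshold` has shape `(width,)` and broadcasts against the last axis; it could also serve as a reference when testing other implementations.

```python
import numpy as np


def rectangle_np(x):
    # Same kernel as `rectangle`/`rectangle_pt`: 1 on (-0.5, 0.5), 0 elsewhere.
    return ((x > -0.5) & (x < 0.5)).astype(x.dtype)


def step_bwd_np(x, threshold, grad_output, bandwidth):
    """(x_grad, threshold_grad) for step(x, threshold), summing the threshold
    gradient over all leading axes (assumes threshold has shape (width,))."""
    leading_axes = tuple(range(x.ndim - 1))
    x_grad = 0.0 * grad_output  # no STE for the x input
    threshold_grad = np.sum(
        -(1.0 / bandwidth) * rectangle_np((x - threshold) / bandwidth) * grad_output,
        axis=leading_axes,
    )
    return x_grad, threshold_grad


def jumprelu_bwd_np(x, threshold, grad_output, bandwidth):
    """(x_grad, threshold_grad) for jumprelu(x, threshold), with the same
    all-leading-axes reduction as step_bwd_np."""
    leading_axes = tuple(range(x.ndim - 1))
    x_grad = (x > threshold) * grad_output  # no STE for the x input
    threshold_grad = np.sum(
        -(threshold / bandwidth)
        * rectangle_np((x - threshold) / bandwidth)
        * grad_output,
        axis=leading_axes,
    )
    return x_grad, threshold_grad
```

For 2-D inputs, `leading_axes` is `(0,)`, so these reduce to exactly the formulas used in the notebook; for higher-rank inputs they agree with flattening all leading axes into one batch axis first.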

In [None]:
# @title Training loop

def remove_parallel_component_pt(x, v):
    """Returns x with component parallel to v projected away (in PyTorch)."""
    v_normalised = v / (torch.norm(v, dim=-1, keepdim=True) + 1e-6)
    parallel_component = torch.einsum("...d,...d->...", x, v_normalised)
    return x - parallel_component[..., None] * v_normalised

def train_pt(
    dataset_iterator,
    sparsity_coefficient,
    use_pre_enc_bias,
    fix_decoder_norms,
):
    sae = Sae(SAE_WIDTH, ACTIVATIONS_SIZE, use_pre_enc_bias)
    optimizer = torch.optim.Adam(
        sae.parameters(), lr=LEARNING_RATE, betas=(ADAM_B1, 0.999)
    )
    for batch in dataset_iterator:
        optimizer.zero_grad()
        loss_pt = loss_fn_pt(
            sae, torch.tensor(batch), sparsity_coefficient, use_pre_enc_bias
        )
        loss_pt.backward()
        if fix_decoder_norms:
            sae.W_dec.grad = remove_parallel_component_pt(
                sae.W_dec.grad, sae.W_dec.data
            )
        optimizer.step()
        if fix_decoder_norms:
            sae.W_dec.data = sae.W_dec.data / torch.norm(
                sae.W_dec.data, dim=-1, keepdim=True
            )
    return sae
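
The end-to-end comparison relies on `optax.adam(LEARNING_RATE, b1=ADAM_B1)` and `torch.optim.Adam(..., betas=(ADAM_B1, 0.999))` applying the same update: both use the standard bias-corrected Adam formula with default `b2 = 0.999` and `eps = 1e-8`. The NumPy sketch below writes out that formula. With `ADAM_B1 = 0`, the first-step update reduces to $-\mathrm{lr}\cdot g/(|g|+\epsilon) \approx -\mathrm{lr}\cdot\mathrm{sign}(g)$, so each parameter with a non-negligible gradient moves by roughly `LEARNING_RATE` per step. This is worth keeping in mind when reading the 0.001 "params have evolved" threshold in the end-to-end test.

```python
import numpy as np


def adam_update(grad, m, v, step, lr, b1, b2=0.999, eps=1e-8):
    """One bias-corrected Adam step (standard formula); returns (update, m, v)."""
    m = b1 * m + (1.0 - b1) * grad
    v = b2 * v + (1.0 - b2) * grad**2
    m_hat = m / (1.0 - b1**step)
    v_hat = v / (1.0 - b2**step)
    return -lr * m_hat / (np.sqrt(v_hat) + eps), m, v


grad = np.array([0.5, -2.0, 1e-2, 3.0])
update, m, v = adam_update(
    grad, np.zeros(4), np.zeros(4), step=1, lr=1e-3, b1=0.0
)
print(update)  # each entry is approximately -1e-3 * sign(grad)
```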

### Testing consistency with the JAX implementation

We start with an end-to-end test: if we train with both implementations on identical (synthetic) data with identical initial parameters, do we get the same parameters at the end of training?

In [None]:
# @title End-to-end test

# Parameter combinations that we'll loop over
sparsity_coefficients = [0.0, 0.01, 0.1]  # Arbitrarily chosen
use_pre_enc_bias_l = [True, False]
fix_decoder_norms_l = [True, False]


for sparsity_coefficient, use_pre_enc_bias, fix_decoder_norms in itertools.product(
    sparsity_coefficients, use_pre_enc_bias_l, fix_decoder_norms_l
):
    print(
        f"Testing {sparsity_coefficient=}, {use_pre_enc_bias=}, "
        f"{fix_decoder_norms=}... ",
        end="",
        flush=True,
    )

    # Train using the JAX implementation
    params_jax_trained = train_jax(
        iter(dataset),
        sparsity_coefficient=sparsity_coefficient,
        use_pre_enc_bias=use_pre_enc_bias,
        fix_decoder_norms=fix_decoder_norms,
    )

    # Train using the PyTorch implementation
    sae_pt_trained = train_pt(
        iter(dataset),
        sparsity_coefficient=sparsity_coefficient,
        use_pre_enc_bias=use_pre_enc_bias,
        fix_decoder_norms=fix_decoder_norms,
    )

    # First we want to make sure the params have actually evolved, otherwise
    # this test isn't meaningful!
    chex.assert_trees_all_close(
        jax.tree.map(
            lambda x, y: np.mean(np.abs(x - y)) > 0.001,
            params_init,
            params_jax_trained,
        ),
        jax.tree.map(lambda _: True, params_init),
    )

    # Now we check whether the parameters obtained using either implementation
    # are close
    chex.assert_trees_all_close(
        dataclasses.asdict(params_jax_trained),
        jax.tree.map(lambda x: x.numpy(), dict(sae_pt_trained.state_dict())),
    )

    print("OK.")

Testing sparsity_coefficient=0.0, use_pre_enc_bias=True, fix_decoder_norms=True... OK.
Testing sparsity_coefficient=0.0, use_pre_enc_bias=True, fix_decoder_norms=False... OK.
Testing sparsity_coefficient=0.0, use_pre_enc_bias=False, fix_decoder_norms=True... OK.
Testing sparsity_coefficient=0.0, use_pre_enc_bias=False, fix_decoder_norms=False... OK.
Testing sparsity_coefficient=0.01, use_pre_enc_bias=True, fix_decoder_norms=True... OK.
Testing sparsity_coefficient=0.01, use_pre_enc_bias=True, fix_decoder_norms=False... OK.
Testing sparsity_coefficient=0.01, use_pre_enc_bias=False, fix_decoder_norms=True... OK.
Testing sparsity_coefficient=0.01, use_pre_enc_bias=False, fix_decoder_norms=False... OK.
Testing sparsity_coefficient=0.1, use_pre_enc_bias=True, fix_decoder_norms=True... OK.
Testing sparsity_coefficient=0.1, use_pre_enc_bias=True, fix_decoder_norms=False... OK.
Testing sparsity_coefficient=0.1, use_pre_enc_bias=False, fix_decoder_norms=True... OK.
Testing sparsity_coefficient=

We can also run the following more granular tests, which may be helpful, e.g., if we don't see consistency between implementations in end-to-end training:

* Whether the outputs of the SAE (reconstructions and pre-activations) are close;
* Whether the calculated loss is close;
* Whether the gradients of the loss are close.

Running the following cell, we find that these agree up to the default tolerances of `np.testing.assert_allclose`:

In [None]:
# @title Additional tests

# First we check that the SAE outputs are the same (using the first batch of
# the dataset and initial parameters)
sae_reconstruction_jax, sae_pre_activations_jax = sae(
    params_init, dataset[0], use_pre_enc_bias=True
)
sae_pt = Sae(SAE_WIDTH, ACTIVATIONS_SIZE, use_pre_enc_bias=True)
sae_reconstruction_pt, sae_pre_activations_pt = sae_pt(
    torch.tensor(dataset[0])
)
np.testing.assert_allclose(
    sae_reconstruction_jax, sae_reconstruction_pt.detach().numpy()
)
np.testing.assert_allclose(
    sae_pre_activations_jax, sae_pre_activations_pt.detach().numpy()
)

# Now we check that the losses and loss gradients match too
loss_jax, loss_grad_jax = jax.value_and_grad(loss)(
    params_init,
    dataset[0],
    sparsity_coefficient=1.2,
    use_pre_enc_bias=True,
)
loss_pt = loss_fn_pt(
    sae_pt,
    torch.tensor(dataset[0]),
    sparsity_coefficient=1.2,
    use_pre_enc_bias=True,
)
loss_pt.backward()

# Check loss is close
np.testing.assert_allclose(loss_jax, loss_pt.detach().numpy())

# Check gradients are close
for field in dataclasses.fields(params_init):
    np.testing.assert_allclose(
        getattr(loss_grad_jax, field.name),
        getattr(sae_pt, field.name).grad.numpy(),
        err_msg=f"grad for {field.name} does not match",
    )
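
When one of these assertions fails, `np.testing.assert_allclose` reports only the first mismatching field. A small hypothetical helper like the one below (not part of the reference implementation) summarises the largest absolute difference for every field at once, which can make it quicker to see whether, say, only `log_threshold` diverges.

```python
import numpy as np


def max_abs_diffs(tree_a, tree_b):
    """Returns {name: max |a - b|} for two dicts of same-shaped arrays."""
    assert tree_a.keys() == tree_b.keys(), "parameter names differ"
    return {
        name: float(np.max(np.abs(np.asarray(tree_a[name]) - np.asarray(tree_b[name]))))
        for name in tree_a
    }


def report_mismatches(tree_a, tree_b, atol=1e-6):
    """Returns the sorted names whose max absolute difference exceeds atol."""
    diffs = max_abs_diffs(tree_a, tree_b)
    return sorted(name for name, d in diffs.items() if d > atol)
```

For example, the gradients above could be compared with `max_abs_diffs(dataclasses.asdict(loss_grad_jax), {f.name: getattr(sae_pt, f.name).grad.numpy() for f in dataclasses.fields(params_init)})`.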