In [None]:
import torch
from pathlib import Path
from tqdm import tqdm
import torch.nn.functional as F

from mech_interp import ToyModel
from mech_interp.data_generators import SyntheticSparseDataGenerator
from mech_interp.sparse_autoencoder import TopKSparseAutoencoder
from mech_interp.script_utils import create_uniform_sparsity, create_importance
from mech_interp.geometric_median import geometric_median
from mech_interp.visualizations import plot_feature_directions

In [None]:

# Checkpoint of the pre-trained toy model and the settings used to rebuild it.
model_path = 'toy_model_for_sae.pth'
feature_dim = 5
hidden_dim = 2
sparsity = 0.9
importance_decay = 0.9

# Seed torch's global RNG so a fresh "Restart & Run All" reproduces the same
# run. NOTE(review): assumes the data generator and the SAE initialisation draw
# from torch's global RNG -- confirm against their implementations.
seed = 0
torch.manual_seed(seed)

# Prefer the GPU when one is available; everything below is placed on `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
torch.set_float32_matmul_precision('high')

In [None]:
# Sparsity and importance settings derived from the scalar config values;
# `uniform_sparsity` feeds the data generators and `importance` the plots.
uniform_sparsity = create_uniform_sparsity(feature_dim, sparsity)
importance = create_importance(feature_dim, importance_decay)

model = ToyModel(feature_dim, hidden_dim)
# map_location remaps the checkpoint's tensors onto the device selected above,
# so a checkpoint saved from a CUDA run still loads on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=device))
# The toy model is only used to produce activations; freeze its parameters.
model.requires_grad_(False)
model.to(device);


# Setup SAE

In [None]:

# Draw one large batch of synthetic inputs and run it through the toy model;
# the second value returned by the model is used as the activations.
num_samples = 10_000
init_generator = SyntheticSparseDataGenerator(
    batch_size=num_samples,
    sparsity=uniform_sparsity,
    device=device,
)
sample_data = init_generator.generate_batch()
_, activations = model(sample_data)

# Following the paper, the SAE bias starts at the geometric median of the
# collected activations.
initial_sae_bias = geometric_median(activations)

In [None]:
# Build the top-k SAE over the toy model's hidden activations (k=1), starting
# from the bias computed above, then move it to the working device.
sae = TopKSparseAutoencoder(
    activations_dim=hidden_dim,
    feature_dim=feature_dim,
    initial_bias=initial_sae_bias,
    k=1,
)
sae = sae.to(device)

In [None]:
# Baseline plot: the SAE's feature directions before any training step.
sae_directions = sae.get_feature_directions()
sae_directions = sae_directions.cpu()
fig = plot_feature_directions(
    sae_directions,
    ["SAE"],
    importance,
    eps=3e-2,
)
fig.show()


# Train SAE

In [None]:
# Optimiser settings (passed to Adam in the training cell).
learning_rate = 0.01
betas = (0.0, 0.999)

# Training-loop schedule: total steps, samples per step, and how often the
# SAE feature directions are plotted and recorded.
batch_size = 4096
iterations = 20000
plot_interval = 2500

In [None]:
# Start the snapshot history with the toy model's own feature directions,
# labelled 'Target'; the training cell appends SAE snapshots to both lists.
model_feature_directions = model.get_feature_directions().cpu()
labels = ['Target']
feature_directions = [model_feature_directions]

In [None]:
data_generator = SyntheticSparseDataGenerator(
    batch_size=batch_size,
    sparsity=uniform_sparsity,
    device=device,
)

optimizer = torch.optim.Adam(sae.parameters(), lr=learning_rate, betas=betas)


def _snapshot_sae_directions(label):
    """Plot the SAE's current feature directions and record them under `label`.

    The directions are detached and cloned before being stored: `Tensor.cpu()`
    returns the same tensor when it already lives on the CPU, so without the
    clone a stored snapshot could alias weights that the optimizer later
    updates in place (whether it does depends on what
    `get_feature_directions()` returns -- the copy is safe either way).

    Returns the stored CPU copy.
    """
    directions = sae.get_feature_directions().detach().cpu().clone()
    fig = plot_feature_directions(directions, ["SAE"], importance, eps=3e-2)
    fig.show()
    feature_directions.append(directions)
    labels.append(label)
    return directions


progress_bar = tqdm(range(iterations))
for step in progress_bar:
    # Periodically plot and record the directions (includes step 0, i.e. the
    # untrained SAE).
    if step % plot_interval == 0:
        sae_directions = _snapshot_sae_directions(f'SAE@{step}')

    optimizer.zero_grad(set_to_none=True)

    # Collect activations from the toy model. Only the SAE is optimised, so no
    # autograd graph is needed for this forward pass.
    with torch.no_grad():
        batch = data_generator.generate_batch()
        _, activations = model(batch)

    reconstruction, features = sae(activations)

    # The only training objective is the MSE between activations and their
    # reconstruction.
    reconstruction_loss = F.mse_loss(activations, reconstruction)
    loss = reconstruction_loss

    loss.backward()
    optimizer.step()

    progress_bar.set_postfix(loss=f"{loss.item():.3f}", reconstruction_loss=f"{reconstruction_loss.item():.3f}")

# Final snapshot after the last optimisation step.
sae_directions = _snapshot_sae_directions(f'SAE@{iterations}')

# Visualize

In [None]:
# Stack the recorded snapshots (target first, then the SAE over training) into
# one tensor for a combined plot. The result goes into a separate name so that
# `feature_directions` stays a list and this cell can be re-run safely.
stacked_directions = torch.stack(feature_directions, dim=0)
fig = plot_feature_directions(stacked_directions, labels, importance, eps=3e-2)
fig.show()