In [None]:
import torch
from tqdm import tqdm

from mech_interp import ToyModel
from mech_interp.visualizations import plot_feature_directions
from mech_interp.data_generators import SyntheticSparseDataGenerator
from mech_interp.script_utils import create_uniform_sparsity, create_importance, weighted_mse_loss

In [None]:
# ==== Parameters ====

# Fixed seed so that "Restart Kernel -> Run All" reproduces the same run.
# torch.manual_seed seeds torch's default generators; presumably the model
# initialisation and the synthetic data generator draw from them -- TODO confirm.
seed = 0
torch.manual_seed(seed)

# Model dimensions: both are passed to ToyModel; feature_dim is also passed to
# the sparsity and importance helpers.
feature_dim = 5
hidden_dim = 2

# Arguments for create_importance / create_uniform_sparsity respectively.
importance_decay = 0.9
sparsity = 0.9

# Optimisation settings used by the data generator and the training loop.
batch_size = 1024
num_steps = 10_000
learning_rate = 1e-3

In [None]:
# Run on the GPU when torch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# ==== Set up ====

# Sparsity specification handed to the data generator below.
uniform_sparsity = create_uniform_sparsity(feature_dim, sparsity)

# Importance weights consumed by the loss; kept on the training device.
importance = create_importance(feature_dim, importance_decay)
importance = importance.to(device)

# The model under study, moved to the training device.
model = ToyModel(feature_dim, hidden_dim)
model = model.to(device)

# Source of training batches.
data_generator = SyntheticSparseDataGenerator(
    batch_size=batch_size,
    sparsity=uniform_sparsity,
    device=device,
)

In [None]:
# ==== Training ====
#
# Each step draws a fresh batch from the data generator, runs it through the
# model, and minimises the importance-weighted MSE between the model output
# and the batch itself.

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Refresh the displayed loss only every `log_every` steps: loss.item() copies
# the scalar to the host, which forces a device synchronisation on every call,
# and set_postfix redraws the progress bar.
log_every = 100

progress_bar = tqdm(range(num_steps))

for step in progress_bar:
    optimizer.zero_grad(set_to_none=True)

    batch = data_generator.generate_batch()

    # The model returns a pair; only the first element is needed for the loss.
    output, _ = model(batch)

    loss = weighted_mse_loss(output, batch, importance)
    loss.backward()
    optimizer.step()

    # Always report the last step so the bar ends on the final loss value.
    if step % log_every == 0 or step == num_steps - 1:
        progress_bar.set_postfix(loss=loss.item())

In [None]:
# ==== Plot Feature Directions ====

# Detach before moving to the CPU in case the returned tensor is still
# attached to the autograd graph; plotting never needs gradients.
feature_directions = model.get_feature_directions().detach().cpu()
labels = [f'Sparsity = {sparsity:.3f}']

# `importance` was moved to `device` during set-up. Hand the plotting helper a
# CPU copy so every tensor it receives lives on the same device as
# `feature_directions` (a CUDA tensor cannot be converted to NumPy directly).
fig = plot_feature_directions(feature_directions, labels, importance.cpu(), eps=3e-2)
fig.show()

In [None]:
# Write the trained model's state dict to a file in the working directory.
torch.save(
    obj=model.state_dict(),
    f="toy_model_for_sae.pth",
)