In [None]:
import torch
from tqdm import tqdm

from mech_interp import ToyModel
from mech_interp.visualizations import plot_feature_analysis
from mech_interp.data_generators import SyntheticClusteredDataGenerator 
from mech_interp.script_utils import create_importance, weighted_mse_loss

In [None]:
# ==== Parameters ====

# Settings forwarded to SyntheticClusteredDataGenerator (n_topics, alpha, beta_params).
n_topics = 10
alpha = 0.5
beta_params = (0.15, 1.0)

# Sizes handed to ToyModel; feature_dim is also the generator's n_features and the
# first argument of create_importance.
feature_dim, hidden_dim = 70, 25

# Second argument of create_importance.
importance_decay = 0.99  # TODO: Investigate how this parameter affects the model

# Optimisation settings used by the training loop.
learning_rate = 1e-3
num_steps = 10_000
batch_size = 1024

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

In [None]:
# ==== Set up ====

# Seed PyTorch's global RNG (CPU and CUDA) before anything random is built, so that a
# fresh Restart & Run All reproduces the same run.
# NOTE(review): this presumes ToyModel's initialisation and the data generator draw
# from torch's global RNG — confirm; seed any other RNG they use as well.
seed = 0
torch.manual_seed(seed)

# Per-feature weights consumed by weighted_mse_loss in the training loop.
importance = create_importance(feature_dim, importance_decay).to(device)

model = ToyModel(feature_dim, hidden_dim).to(device)

# Source of training batches; built directly on the target device.
data_generator = SyntheticClusteredDataGenerator(
    n_topics=n_topics,
    n_features=feature_dim,
    alpha=alpha,
    beta_params=beta_params,
    device=device,
)

In [None]:
# ==== Training ====

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
progress_bar = tqdm(range(num_steps))

for step in progress_bar:
    # Forward pass on a freshly generated batch; the model's second return value is unused.
    batch = data_generator.generate_batch(batch_size)
    output, _ = model(batch)

    # The model is trained to reconstruct its own input, weighted by `importance`.
    loss = weighted_mse_loss(
        output,
        batch,
        importance,
    )

    # Clear the previous step's gradients, then backpropagate and update.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Show the current loss next to the progress bar.
    progress_bar.set_postfix(loss=loss.item())

In [None]:
# Write the model's state dict to disk so the weights can be reloaded without retraining.
torch.save(
    obj=model.state_dict(),
    f="toy_model_for_sae_big.pth",
)

In [None]:
# Reload the saved weights into the model. map_location remaps the stored tensors onto
# the active device, so a checkpoint written on a GPU machine still loads on a CPU-only
# one (without it, torch.load tries to restore tensors to the device they were saved from).
model.load_state_dict(
    torch.load("toy_model_for_sae_big.pth", map_location=device)
)

In [None]:
# ==== Visualize Results ====

# Read the feature directions and bias from the model, in that order.
weights, bias = model.get_feature_directions(), model.get_bias()

# Build the analysis figure from those two quantities and display it.
fig = plot_feature_analysis(weights, bias)
fig.show()