In [None]:
import torch
from tqdm import tqdm

from mech_interp import ParallelToyModel
from mech_interp.visualizations import plot_feature_analysis
from mech_interp.data_generators import SyntheticSparseDataGenerator, create_sparsity_range
from mech_interp.importance import importance_decay_by_ratio
from mech_interp.loss import weighted_mse_loss

In [None]:
# ==== Parameters ====
# Every tunable knob for this run lives in this cell; later cells only read these names.

# Model shape: number of input features and the hidden width passed to ParallelToyModel.
feature_dim = 80
hidden_dim = 20

# Ratio handed to importance_decay_by_ratio to build the per-feature loss weights.
importance_decay = 0.9

# Sparsity sweep endpoints handed to create_sparsity_range (one value per model).
min_sparsity = 0.0
max_sparsity = 0.999

# Training setup: how many models train in parallel, batch size, step count, optimiser LR.
num_models = 10
batch_size = 1024
num_steps = 10000
learning_rate = 0.001

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

In [None]:
# ==== Set up ====

# Seed torch's global RNGs (CPU and every CUDA device) before anything random happens.
# Re-running this cell then rebuilds an identical model. The synthetic batches are also
# reproducible, assuming SyntheticSparseDataGenerator samples from torch's global RNG
# (TODO confirm).
seed = 0
torch.manual_seed(seed)

# Per-model sparsity values swept between min_sparsity and max_sparsity.
sparsity = create_sparsity_range(min_sparsity, max_sparsity, num_models, feature_dim)
# Per-feature loss weights, moved to the same device as the model.
importance = importance_decay_by_ratio(feature_dim, importance_decay).to(device)
model = ParallelToyModel(num_models, feature_dim, hidden_dim).to(device)
data_generator = SyntheticSparseDataGenerator(batch_size=batch_size, sparsity=sparsity, device=device)

In [None]:
# ==== Training ====
# NOTE: re-running this cell keeps training the *current* weights with a fresh
# optimizer (Adam moments reset). Re-run the set-up cell first for a clean run.

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# loss.item() copies the loss to the host and forces a device sync, so only
# fetch it every `log_every` steps (plus the final step) for the progress bar.
log_every = 100

progress_bar = tqdm(range(num_steps))

for step in progress_bar:
    optimizer.zero_grad(set_to_none=True)

    batch = data_generator.generate_batch()

    # The model's second output is not needed for the reconstruction loss.
    output, _ = model(batch)

    # The model is trained to reconstruct its own input, weighted per feature by `importance`.
    loss = weighted_mse_loss(output, batch, importance)
    loss.backward()
    optimizer.step()

    if step % log_every == 0 or step == num_steps - 1:
        # refresh=False leaves redraw timing to tqdm's own rate limiting instead of
        # forcing a redraw on every call.
        progress_bar.set_postfix(loss=loss.item(), refresh=False)

In [None]:
# Derive the checkpoint name from feature_dim instead of hard-coding it, so changing the
# config doesn't silently overwrite or reuse a checkpoint from a different model size.
torch.save(model.state_dict(), f"toy_models_{feature_dim}.pth")

In [None]:
# map_location remaps stored tensors onto the current device, so a checkpoint written
# during a CUDA run still loads on a CPU-only machine. weights_only=True restricts
# unpickling to tensors and plain containers, which is all a state_dict contains.
model.load_state_dict(
    torch.load(f"toy_models_{feature_dim}.pth", map_location=device, weights_only=True)
)

In [None]:
# ==== Visualize Results ====

# Learned parameters pulled from the trained ParallelToyModel for plotting.
weights = model.get_feature_directions()
bias = model.get_bias()

# Label each model by the sparsity of its feature 0. This presumably represents the
# whole model, i.e. create_sparsity_range gives all of a model's features the same
# value (TODO confirm).
sparcity_list = []
for model_idx in range(num_models):
    sparcity_list.append(sparsity[model_idx][0].item())

labels = [f'Sparsity = {level:.3f}' for level in sparcity_list]

fig = plot_feature_analysis(weights, bias, labels=labels)
fig.show()