In [1]:
import torch
import pickle
import os
import numpy as np
import itertools
from models import TopKAutoEncoder
from utils import MCC
from IPython.display import clear_output
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from tqdm.notebook import tqdm
import json

In [2]:
# Select the compute device. The experiment was configured for the second GPU
# ("cuda:1"), but that ordinal does not exist on single-GPU machines and would fail
# at the first `.to(device)`; fall back to "cuda:0" there, and to the CPU when CUDA
# is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

In [3]:
# Single source of truth for the experiment's hyperparameters; edit values here
# rather than further down the notebook.
config = dict(
    # Data / model shape
    k=3,                    # sparsity parameter k (passed to TopKAutoEncoder; part of the data filename)
    activation_dim=16,      # dimension of each input vector and of its reconstruction
    dict_size=32,           # number of latent features (width of the autoencoder's latent layer)
    N=50_000,               # number of synthetic data points
    # Optimisation
    steps=30_000,           # training steps per model
    lr=1e-4,                # Adam learning rate
    l1_coeff=1e-2,          # weight of the L1 penalty on latent activations in the total loss
    # Seeds and bookkeeping
    data_seed=42,           # seed identifying which synthetic dataset file to load
    num_models=5,           # how many models to train, each from a different seed
    base_train_seed=42,     # model i is seeded with base_train_seed + i
    log_interval=100,       # record metrics and redraw plots every this many steps
)

In [4]:
# Load the pre-generated synthetic dataset (keys 'A', 'S', 'X'). The filename encodes
# the generation parameters, so it must match the values in `config`.
# NOTE: pickle.load can execute arbitrary code — only load files you generated yourself.
data_path = os.path.join(
    "data",
    f"synthetic_data_n{config['activation_dim']}_m{config['dict_size']}"
    f"_N{config['N']}_k{config['k']}_seed{config['data_seed']}.pkl",
)
with open(data_path, "rb") as f:
    data = pickle.load(f)
A_true, S_true, X = data['A'], data['S'], data['X']

# Sanity-check the loaded data against the config before any training starts.
assert A_true.shape[0] == config['dict_size'], f"Expected {config['dict_size']} rows in A_true, got {A_true.shape[0]}"
# X is fed straight into models built with `activation_dim` inputs, so its last axis must match.
assert np.shape(X)[-1] == config['activation_dim'], (
    f"Expected X to have {config['activation_dim']} features, got {np.shape(X)[-1]}"
)

In [5]:
# Initialize `num_models` independent models and their Adam optimizers. Each model is
# created right after seeding torch with its own seed, so runs are reproducible and
# the models differ only in their initialization.
models = []
optimizers = []

for i in range(config['num_models']):
    # Model i is seeded with base_train_seed + i.
    model_seed = config['base_train_seed'] + i
    torch.manual_seed(model_seed)

    model = TopKAutoEncoder(activation_dim=config['activation_dim'], dict_size=config['dict_size'], k=config['k']).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])

    models.append(model)
    optimizers.append(optimizer)

# Convert the data to a float32 tensor on the target device once, outside the training loop.
x = torch.tensor(X, dtype=torch.float32, device=device)

# Tracking lists: the per-model lists are indexed [model][log point]; the others are indexed by log point.
model_mcc_lists = [[] for _ in range(config['num_models'])]  # MCC between each model and ground truth
model_total_loss_lists = [[] for _ in range(config['num_models'])]  # Total loss for each model
model_l2_loss_lists = [[] for _ in range(config['num_models'])]  # L2 (reconstruction) loss for each model
model_l1_loss_lists = [[] for _ in range(config['num_models'])]  # L1 loss for each model
model_l0_loss_lists = [[] for _ in range(config['num_models'])]  # L0 (active latents) for each model
pairwise_mcc_means = []  # Average MCC among models
pairwise_mcc_lists = []  # Store pairwise MCCs
steps_list = []  # Keep track of steps for plotting

# One colour per model, kept consistent across every subplot.
colors = ['blue', 'red', 'green', 'purple', 'orange', 'cyan', 'magenta', 'brown', 'pink', 'grey']
if config['num_models'] > len(colors):
    # Repeat the palette as needed. cycle() is infinite, so it must be bounded with
    # islice — list(cycle(...)) would never return.
    colors = list(itertools.islice(itertools.cycle(colors), config['num_models']))

In [6]:
# Training loop. Each iteration gives every model one optimisation step on the full
# dataset `x`. Every `log_interval` steps (and on the final step), metrics for all
# models are recorded and a live 2x2 dashboard is redrawn.


def _plot_training_progress(step, total_steps, steps_list, model_mcc_lists,
                            pairwise_mcc_means, model_total_loss_lists,
                            model_l2_loss_lists, model_l1_loss_lists,
                            model_l0_loss_lists, colors, component_model=0):
    """Build and show the 2x2 training-progress figure.

    Panels: per-model MCC vs ground truth (top-left), mean pairwise MCC between
    models (top-right), per-model total loss (bottom-left), and the L2/L1/L0
    component losses of a single model (bottom-right) to avoid clutter.

    Args:
        step: current training step, shown in the title.
        total_steps: total number of training steps, shown in the title.
        steps_list: logged step indices, the shared x-axis of every trace.
        model_mcc_lists, model_total_loss_lists, model_l2_loss_lists,
        model_l1_loss_lists, model_l0_loss_lists: per-model logged values,
            indexed [model][log point].
        pairwise_mcc_means: mean pairwise MCC at each log point.
        colors: one colour per model, indexed by model number.
        component_model: index of the model whose component losses are plotted
            (default: the first model).
    """
    num_models = len(model_mcc_lists)
    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=(
            "MCC vs Ground Truth",
            "Mean Pairwise MCC",
            "Total Loss",
            "Component Losses"
        ),
        vertical_spacing=0.12,
        horizontal_spacing=0.08
    )

    # 1. MCC vs ground truth (top-left); these traces carry each model's legend entry.
    for i in range(num_models):
        fig.add_trace(
            go.Scatter(
                x=steps_list,
                y=model_mcc_lists[i],
                mode='lines',
                name=f"Model {i+1} MCC",
                line=dict(color=colors[i]),
                legendgroup=f"model_{i}"
            ),
            row=1, col=1
        )

    # 2. Mean pairwise MCC (top-right).
    fig.add_trace(
        go.Scatter(
            x=steps_list,
            y=pairwise_mcc_means,
            mode='lines',
            name="Mean Pairwise MCC",
            line=dict(color='black', width=3),
            legendgroup="pairwise"
        ),
        row=1, col=2
    )

    # 3. Total loss (bottom-left); legend hidden because the MCC traces share its group.
    for i in range(num_models):
        fig.add_trace(
            go.Scatter(
                x=steps_list,
                y=model_total_loss_lists[i],
                mode='lines',
                name=f"Model {i+1} Total Loss",
                line=dict(color=colors[i]),
                legendgroup=f"model_{i}",
                showlegend=False
            ),
            row=2, col=1
        )

    # 4. Component losses (bottom-right) for `component_model` only.
    m = component_model
    for label, values, color in (
        ("L2 Loss", model_l2_loss_lists[m], 'blue'),
        ("L1 Loss", model_l1_loss_lists[m], 'red'),
        ("L0 Loss", model_l0_loss_lists[m], 'green'),
    ):
        fig.add_trace(
            go.Scatter(
                x=steps_list,
                y=values,
                mode='lines',
                name=f"{label} (Model {m+1})",
                line=dict(color=color)
            ),
            row=2, col=2
        )

    fig.update_layout(
        title_text=f"Training Progress (Step {step}/{total_steps})",
        height=800,
        width=1200,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="center",
            x=0.5
        )
    )

    # Every subplot shares "Steps" on the x-axis.
    for row in (1, 2):
        for col in (1, 2):
            fig.update_xaxes(title_text="Steps", row=row, col=col)
    fig.update_yaxes(title_text="MCC", row=1, col=1)
    fig.update_yaxes(title_text="MCC", row=1, col=2)
    fig.update_yaxes(title_text="Total Loss", row=2, col=1)
    fig.update_yaxes(title_text="Loss Value", row=2, col=2)

    fig.show()


for step in tqdm(range(config['steps'])):
    # One optimisation step for every model on the full dataset.
    for model, optimizer in zip(models, optimizers):
        optimizer.zero_grad()

        x_hat, f = model(x, output_features=True)
        rec_loss = torch.mean((x - x_hat) ** 2)          # reconstruction MSE
        l1_loss = torch.sum(torch.abs(f), dim=1).mean()  # per-sample L1 of latents, averaged
        total_loss = rec_loss + config['l1_coeff'] * l1_loss

        total_loss.backward()
        optimizer.step()

        # Optional: normalize decoder weights
        # model.normalize_decoder()

    # Log at regular intervals, and always on the final step so the end state is recorded.
    if step % config['log_interval'] == 0 or step == config['steps'] - 1:
        steps_list.append(step)

        # Transposed decoder weights as dictionary estimates; presumably
        # (dict_size, activation_dim) to match A_true's row layout — TODO confirm.
        A_estimates = [model.decoder.weight.detach().T for model in models]

        for i, (model, A_est) in enumerate(zip(models, A_estimates)):
            model_mcc_lists[i].append(MCC(A_true, A_est, dict_size=config['dict_size']))

            # Evaluation only: no_grad avoids building an autograd graph over the full dataset.
            with torch.no_grad():
                x_hat, f = model(x, output_features=True)
                rec_loss = torch.mean((x - x_hat) ** 2).item()
                l1_loss = torch.sum(torch.abs(f), dim=1).mean().item()
                l0_loss = torch.sum((f != 0).float(), dim=1).mean().item()  # active latents per sample

            model_total_loss_lists[i].append(rec_loss + config['l1_coeff'] * l1_loss)
            model_l2_loss_lists[i].append(rec_loss)
            model_l1_loss_lists[i].append(l1_loss)
            model_l0_loss_lists[i].append(l0_loss)

        # Agreement between independently seeded models (every unordered pair).
        pairwise_mccs = [
            MCC(A_i, A_j, dict_size=config['dict_size'])
            for A_i, A_j in itertools.combinations(A_estimates, 2)
        ]
        # With a single model there are no pairs; record NaN explicitly instead of
        # letting np.mean([]) emit a RuntimeWarning.
        mean_pairwise_mcc = np.mean(pairwise_mccs) if pairwise_mccs else float('nan')
        pairwise_mcc_means.append(mean_pairwise_mcc)
        pairwise_mcc_lists.append(pairwise_mccs)

        clear_output(wait=True)
        _plot_training_progress(
            step, config['steps'], steps_list, model_mcc_lists, pairwise_mcc_means,
            model_total_loss_lists, model_l2_loss_lists, model_l1_loss_lists,
            model_l0_loss_lists, colors,
        )
# Persist every logged metric, plus the config that produced it, so results can be
# analysed or re-plotted later without re-training. exist_ok avoids the
# check-then-create race of os.path.exists + os.makedirs.
os.makedirs("output", exist_ok=True)
result_dict = {
    "model_mcc_lists": model_mcc_lists,
    "model_total_loss_lists": model_total_loss_lists,
    "model_l2_loss_lists": model_l2_loss_lists,
    "model_l1_loss_lists": model_l1_loss_lists,
    "model_l0_loss_lists": model_l0_loss_lists,
    "pairwise_mcc_means": pairwise_mcc_means,
    "pairwise_mcc_lists": pairwise_mcc_lists,
    "steps_list": steps_list,
    "config": config,  # hyperparameters of this run, for reproducibility
}
# Use a dedicated handle name so the latent-features variable `f` is not shadowed.
with open("output/top_k_sae.pkl", "wb") as out_file:
    pickle.dump(result_dict, out_file)