# Truly Sparse SAE Training for Feature Overlap Analysis

This notebook trains a properly sparse autoencoder with stronger sparsity constraints.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm

class _TiedDecoder(nn.Module):
    """Bias-free linear decoder that reuses an encoder's weight, transposed.

    The module owns no parameters: every call reads the encoder's current
    weight, so encoder and decoder cannot drift apart during training or
    when the model is moved between devices.
    """

    def __init__(self, encoder):
        super().__init__()
        # Bypass nn.Module.__setattr__ so the encoder is only referenced,
        # not registered as a child (no duplicate state_dict entries).
        object.__setattr__(self, "_encoder", encoder)

    @property
    def weight(self):
        """Decoder weight of shape (input_dim, hidden_dim): a view of the encoder weight."""
        return self._encoder.weight.t()

    def forward(self, code):
        return nn.functional.linear(code, self.weight)


class SparseSAE(nn.Module):
    """Sparse autoencoder with tied weights, an L1 penalty and top-k activation.

    Args:
        input_dim: Size of the last dimension of the inputs.
        hidden_dim: Number of dictionary features (encoder outputs).
        sparsity_coeff: Multiplier of the L1 term in :meth:`loss`.
        top_k: Maximum number of non-zero features kept per sample while the
            module is in training mode. ``None`` disables the constraint.

    Note:
        The decoder is truly tied to the encoder (``decoder.weight`` is
        ``encoder.weight.t()``), so the only trainable parameters are
        ``encoder.weight`` and ``encoder.bias``; the state dict has no
        ``decoder.weight`` entry.
    """

    def __init__(self, input_dim, hidden_dim=2048, sparsity_coeff=0.1, top_k=50):
        super().__init__()
        self.encoder = nn.Linear(input_dim, hidden_dim)
        # Tie weights. Wrapping ``encoder.weight.t()`` in a fresh nn.Parameter
        # would create a second, independent leaf parameter that the optimizer
        # updates separately and that stops sharing memory once the module is
        # copied to another device; decoding through the encoder weight itself
        # keeps the tie exact.
        self.decoder = _TiedDecoder(self.encoder)
        self.sparsity_coeff = sparsity_coeff
        self.top_k = top_k  # Only keep top k features active

        # Small initial weights so few units start strongly active.
        nn.init.xavier_uniform_(self.encoder.weight, gain=0.1)
        nn.init.zeros_(self.encoder.bias)

    def encode(self, x):
        """Return the non-negative feature activations for ``x``.

        In training mode (and when ``top_k`` is not ``None``) only the
        ``top_k`` largest activations along the last dimension are kept and
        the rest are zeroed. In eval mode the plain ReLU output is returned.
        """
        h = torch.relu(self.encoder(x))

        if self.training and self.top_k is not None:
            n_features = h.shape[-1]
            # torch.topk raises when k exceeds the dimension size, so clamp.
            k = min(self.top_k, n_features)
            if k < n_features:
                topk_vals, topk_idx = torch.topk(h, k, dim=-1)
                # Scatter the kept values into zeros; gradients flow only
                # through the selected activations.
                h = torch.zeros_like(h).scatter(-1, topk_idx, topk_vals)

        return h

    def forward(self, x):
        """Return ``(reconstruction, code)`` for ``x``."""
        code = self.encode(x)
        recon = self.decoder(code)
        return recon, code

    def loss(self, x, beta_warm=1.0):
        """Compute the training loss and a dict of scalar diagnostics.

        Args:
            x: Input batch; the last dimension must be ``input_dim``.
            beta_warm: Extra multiplier on the L1 term, used for warm-up.

        Returns:
            ``(total_loss, metrics)`` where ``total_loss`` is
            ``mse + beta_warm * sparsity_coeff * l1`` and ``metrics`` holds
            Python floats under the keys ``'total'``, ``'recon'``, ``'l1'``,
            ``'l0'`` and ``'active_features'``.
        """
        recon, code = self(x)

        recon_loss = nn.functional.mse_loss(recon, x)
        l1_loss = code.abs().mean()

        # L0 statistics are reported only: they are not differentiable and
        # are not part of the optimized loss.
        active = (code > 0).float()

        total_loss = recon_loss + beta_warm * self.sparsity_coeff * l1_loss

        return total_loss, {
            'total': total_loss.item(),
            'recon': recon_loss.item(),
            'l1': l1_loss.item(),
            'l0': active.mean().item(),
            'active_features': active.sum(dim=-1).mean().item()
        }

# Cell banner: list the configuration choices of the class defined above.
print("\n".join((
    "SparseSAE class defined with:",
    "- Stronger sparsity coefficient (0.1 vs 0.01)",
    "- Top-k activation constraint",
    "- Better initialization for sparsity",
)))

In [None]:
# Training with better sparsity
def train_sparse_sae(activations, device='cuda',
                     hidden_dim=2048,
                     sparsity_coeff=0.1,  # 10x stronger
                     top_k=30,  # Only 30 active features max
                     epochs=200,  # More epochs
                     lr=0.0005,  # Lower learning rate
                     warmup_epochs=50,
                     log_every=40):
    """Train a ``SparseSAE`` on a matrix of activations with full-batch Adam.

    Args:
        activations: 2-D array-like or tensor of shape (n_samples, input_dim).
            It is converted to a detached float32 tensor on ``device``.
        device: Target device. If a CUDA device is requested but CUDA is not
            available, training falls back to the CPU with a printed notice.
        hidden_dim: Number of SAE features.
        sparsity_coeff: L1 coefficient passed to ``SparseSAE``.
        top_k: Top-k limit passed to ``SparseSAE``.
        epochs: Number of full-batch optimization steps.
        lr: Initial Adam learning rate (cosine-annealed over ``epochs``).
        warmup_epochs: Number of epochs over which the L1 multiplier ramps
            linearly from 0 to 1. Values <= 0 disable the warm-up.
        log_every: Print progress every this many epochs; 0 or ``None``
            silences the per-epoch log.

    Returns:
        ``(sae, losses)``: the trained model (left in training mode) and a
        list with one metrics dict per epoch.

    Raises:
        ValueError: If ``activations`` is not 2-D.
    """
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        print("CUDA not available; falling back to CPU.")
        device = 'cpu'

    # Detach so repeated backward passes never reach into whatever graph
    # produced the activations; float32 matches the model's parameter dtype.
    acts_tensor = torch.as_tensor(activations, dtype=torch.float32).detach().to(device)
    if acts_tensor.dim() != 2:
        raise ValueError(
            f"activations must be 2-D (n_samples, input_dim), got shape {tuple(acts_tensor.shape)}"
        )

    input_dim = acts_tensor.shape[1]
    sae = SparseSAE(input_dim, hidden_dim, sparsity_coeff, top_k).to(device)
    sae.train()

    optimizer = optim.Adam(sae.parameters(), lr=lr)
    # max(1, ...) keeps the scheduler well-defined when epochs == 0.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, max(1, epochs))

    print(f"Training Sparse SAE: {input_dim} → {hidden_dim} features")
    print(f"Sparsity: L1 coefficient={sparsity_coeff}, Top-k={top_k}")
    print("="*60)

    losses = []
    metrics = None
    for epoch in range(epochs):
        # Warm up sparsity penalty: linear ramp, then held at 1.0.
        if warmup_epochs and warmup_epochs > 0:
            beta_warm = min(1.0, epoch / warmup_epochs)
        else:
            beta_warm = 1.0

        optimizer.zero_grad()
        loss, metrics = sae.loss(acts_tensor, beta_warm)
        loss.backward()

        # Gradient clipping for stability
        torch.nn.utils.clip_grad_norm_(sae.parameters(), 1.0)

        optimizer.step()
        scheduler.step()

        losses.append(metrics)

        if log_every and epoch % log_every == 0:
            print(f"Epoch {epoch:3d}: Loss={metrics['total']:.4f} "
                  f"(R:{metrics['recon']:.4f} L1:{metrics['l1']:.4f}) "
                  f"Active={metrics['active_features']:.1f}")

    print("="*60)
    # With epochs == 0 there is nothing to report.
    if metrics is not None:
        print(f"✓ Final active features per sample: {metrics['active_features']:.1f}")
        print(f"✓ Final L0 (fraction active): {metrics['l0']:.4f}")

    return sae, losses

# Definition only: no training runs in this cell. train_sparse_sae is meant to
# be called later, once the activations have been loaded.
print("Training function ready. Will enforce true sparsity.")

## Key Changes for Proper Sparsity:

1. **Sparsity coefficient: 0.1** (was 0.01) - 10x stronger L1 penalty
2. **Top-k constraint: 30** - At most 30 features (out of 2048) can be active per sample; the top-k mask is applied only while the model is in training mode
3. **Warm-up schedule** - Linearly ramp the L1 sparsity penalty from 0 to full strength over the first 50 epochs
4. **Better initialization** - Xavier with gain=0.1 for sparser start
5. **200 epochs** - More training for features to specialize

Expected results:
- ~10-30 active features per sentence (not 140!), measured with the top-k mask active
- Clearer separation between conditional-specific and quantifier-specific features
- More meaningful overlap percentage