### TopK Sparse Autoencoder for Protein-Ligand Docking Analysis

**Goal**: Train a TopK Sparse Autoencoder on 30D VAE latent vectors to identify interpretable features distinguishing native-like poses (RMSD < 2 Å) from poor poses (RMSD ≥ 2 Å).

**Data**: ~6,000-7,000 poses per protein system, filtered to generations 0-7

In [None]:
# Imports and setup
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score, roc_auc_score
from scipy.stats import spearmanr
import pickle
import glob
import os
from pathlib import Path

from schrodinger.structure import StructureReader

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

# Matplotlib inline for Jupyter
%matplotlib inline
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# GPU setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")


In [None]:
# Data Loading Section
# Load pickle files containing latent vectors and pose quality metrics

def load_pickle_data(data_dir, max_gen=7):
    """
    Load latent vectors and pose-quality metadata from pickle files.

    Every ``.pkl`` file found under ``data_dir`` is unpickled and read as one
    pose. The latent vector is taken from the ``'z'`` key (or ``'latent'``),
    the RMSD from a scalar ``'ligand_rmsd'`` or from ``scores['lig_rmsd']`` /
    ``scores['ligand_rmsd']``, the energy from a scalar ``'energy'`` or
    ``scores['energy']``, and the generation from an integer ``'curr_gen'``.
    Files that cannot be read, lack a latent vector or RMSD, or whose latent
    vector does not flatten to 30 values are reported and skipped.

    Args:
        data_dir: Directory containing .pkl files (searched recursively)
        max_gen: Maximum generation number to include (default: 7)

    Returns:
        latents: numpy array of shape [N, 30] - latent vectors
        rmsd: numpy array of shape [N] - ligand RMSD values
        energy: numpy array of shape [N] - energy scores (0.0 where missing/NaN)
        gen: numpy array of shape [N] - generation numbers

        When no usable sample is found, correctly shaped empty arrays are
        returned ([0, 30] for latents) so the result can still be
        concatenated with data from other directories.
    """
    # Search for pickle files recursively in subdirectories
    pkl_files = sorted(glob.glob(os.path.join(data_dir, '**/*.pkl'), recursive=True))
    if not pkl_files:
        # Fallback to non-recursive search
        pkl_files = sorted(glob.glob(os.path.join(data_dir, '*.pkl')))
    print(f"Found {len(pkl_files)} pickle files")

    latents_list = []
    rmsd_list = []
    energy_list = []
    gen_list = []

    for file in pkl_files:
        try:
            # NOTE(security): unpickling can execute arbitrary code; only point
            # this loader at files from a trusted source.
            with open(file, 'rb') as f:
                data = pickle.load(f)

            # Handle different possible key names for latent vector
            if 'z' in data:
                z = data['z']
            elif 'latent' in data:
                z = data['latent']
            else:
                print(f"Warning: No 'z' or 'latent' key in {file}, skipping")
                continue

            # Handle different possible structures for scores
            if isinstance(data.get('ligand_rmsd'), (int, float, np.number)):
                rmsd = data['ligand_rmsd']
            elif 'scores' in data and 'lig_rmsd' in data['scores']:
                rmsd = data['scores']['lig_rmsd']
            elif 'scores' in data and 'ligand_rmsd' in data['scores']:
                rmsd = data['scores']['ligand_rmsd']
            else:
                print(f"Warning: No RMSD found in {file}, skipping")
                continue

            # Handle energy (optional: a pose without one is still kept)
            if isinstance(data.get('energy'), (int, float, np.number)):
                energy = data['energy']
            elif 'scores' in data and 'energy' in data['scores']:
                energy = data['scores']['energy']
            else:
                energy = np.nan

            # Handle generation; a missing or non-integer value counts as generation 0
            if isinstance(data.get('curr_gen'), (int, np.integer)):
                gen = data['curr_gen']
            else:
                gen = 0

            # Filter by generation
            if gen > max_gen:
                continue

            # Convert to numpy if needed
            if isinstance(z, torch.Tensor):
                z = z.cpu().numpy()
            if not isinstance(z, np.ndarray):
                z = np.array(z)

            # Ensure z is 1D array of length 30
            if z.ndim > 1:
                z = z.flatten()
            if len(z) != 30:
                print(f"Warning: Latent vector has shape {z.shape}, expected 30, skipping")
                continue

            latents_list.append(z)
            rmsd_list.append(float(rmsd))
            # NOTE: a missing/NaN energy is stored as 0.0, so it cannot be told
            # apart from a genuine zero score downstream.
            energy_list.append(float(energy) if not np.isnan(energy) else 0.0)
            gen_list.append(int(gen))

        except Exception as e:
            # Best-effort loading: one unreadable or malformed file must not abort the run
            print(f"Error loading {file}: {e}")
            continue

    if not latents_list:
        # Without this guard the summary below fails on .min() of an empty
        # array, and the latents array would have shape (0,) instead of (0, 30).
        print(f"\nWarning: no usable samples found in {data_dir}")
        return (
            np.empty((0, 30)),
            np.empty(0),
            np.empty(0),
            np.empty(0, dtype=int),
        )

    latents = np.array(latents_list)
    rmsd = np.array(rmsd_list)
    energy = np.array(energy_list)
    gen = np.array(gen_list)

    print(f"\nLoaded {len(latents)} samples")
    print(f"Latent vectors shape: {latents.shape}")
    print(f"RMSD range: {rmsd.min():.2f} - {rmsd.max():.2f} Å")
    print(f"Good poses (RMSD < 2Å): {(rmsd < 2.0).sum()} ({(rmsd < 2.0).mean()*100:.1f}%)")
    print(f"Poor poses (RMSD >= 2Å): {(rmsd >= 2.0).sum()} ({(rmsd >= 2.0).mean()*100:.1f}%)")
    print(f"Generation range: {gen.min()} - {gen.max()}")

    return latents, rmsd, energy, gen

# Pool the poses of every protein system into one dataset
data_base_dir = '../data'
data_dirs = [
    'pim1_3vbt_pim1_4lmu_optimization',
    'pim1_4lmu_pim1_4bzo_withRL',
    'rho_2esm_rho_2etk_optimization'
]

# One accumulator per array returned by load_pickle_data, in its return order
all_latents, all_rmsd, all_energy, all_gen = [], [], [], []
accumulators = (all_latents, all_rmsd, all_energy, all_gen)

for data_subdir in data_dirs:
    data_dir = os.path.join(data_base_dir, data_subdir)
    print(f"\n{'='*60}")
    print(f"Loading from: {data_subdir}")
    print(f"{'='*60}")
    loaded = load_pickle_data(data_dir, max_gen=7)
    for bucket, values in zip(accumulators, loaded):
        bucket.append(values)

# Stack the per-system arrays along the sample axis
latents, rmsd, energy, gen = (
    np.concatenate(bucket, axis=0) for bucket in accumulators
)

# Report the pooled dataset
print(f"\n{'='*60}")
print(f"COMBINED DATA SUMMARY")
print(f"{'='*60}")
print(f"Total samples: {len(latents)}")
print(f"Latent vectors shape: {latents.shape}")
print(f"RMSD range: {rmsd.min():.2f} - {rmsd.max():.2f} Å")
print(f"Good poses (RMSD < 2Å): {(rmsd < 2.0).sum()} ({(rmsd < 2.0).mean()*100:.1f}%)")
print(f"Poor poses (RMSD >= 2Å): {(rmsd >= 2.0).sum()} ({(rmsd >= 2.0).mean()*100:.1f}%)")
print(f"Generation range: {gen.min()} - {gen.max()}")


## Preprocessing: Normalization and Train/Val Split


In [None]:
# Preprocessing: stratified train/val split, then z-score normalization
#
# The scaler is fitted on the training rows only and then applied to every
# row, so validation statistics never influence the normalization.

# Create binary labels for stratification (RMSD < 2Å = good pose)
is_good = (rmsd < 2.0).astype(int)

# 70/30 train/val split stratified by RMSD < 2Å label.
# Row indices are split (rather than the arrays themselves) so the same
# partition can be applied to any per-sample array.
train_idx, val_idx = train_test_split(
    np.arange(len(latents)),
    test_size=0.3,
    random_state=42,
    stratify=is_good,
)

# Z-score normalization with training-set mean and standard deviation
scaler = StandardScaler()
scaler.fit(latents[train_idx])
latents_normalized = scaler.transform(latents)
# These statistics cover all rows, so they are close to (not exactly) 0 and 1
print(f"Normalized latents shape: {latents_normalized.shape}")
print(f"Normalized latents mean: {latents_normalized.mean(axis=0).mean():.6f}")
print(f"Normalized latents std: {latents_normalized.std(axis=0).mean():.6f}")

X_train, X_val = latents_normalized[train_idx], latents_normalized[val_idx]
y_train_rmsd, y_val_rmsd = rmsd[train_idx], rmsd[val_idx]
y_train_label, y_val_label = is_good[train_idx], is_good[val_idx]

print(f"\nTrain set: {len(X_train)} samples")
print(f"  Good poses: {(y_train_label == 1).sum()} ({(y_train_label == 1).mean()*100:.1f}%)")
print(f"  Poor poses: {(y_train_label == 0).sum()} ({(y_train_label == 0).mean()*100:.1f}%)")

print(f"\nVal set: {len(X_val)} samples")
print(f"  Good poses: {(y_val_label == 1).sum()} ({(y_val_label == 1).mean()*100:.1f}%)")
print(f"  Poor poses: {(y_val_label == 0).sum()} ({(y_val_label == 0).mean()*100:.1f}%)")

# Convert to float32 tensors on the compute device
X_train_tensor = torch.as_tensor(X_train, dtype=torch.float32).to(device)
X_val_tensor = torch.as_tensor(X_val, dtype=torch.float32).to(device)
print(f"\nTensor shapes:")
print(f"X_train_tensor: {X_train_tensor.shape}")
print(f"X_val_tensor: {X_val_tensor.shape}")


## TopK Sparse Autoencoder Model Definition


In [None]:
# TopK Sparse Autoencoder following Adams et al. (2025)
# Architecture: 30D input -> 120D hidden (TopK K=6) -> 30D output

class TopKSAE(nn.Module):
    """
    TopK Sparse Autoencoder.

    Architecture:
        Input (30D) -> Linear -> Hidden (120D) -> TopK(K=6) -> Linear -> Output (30D)

    The TopK operation keeps only the K largest activations per sample, zeroing out the rest.
    This enforces sparsity without requiring a sparsity penalty term.
    No nonlinearity is applied before or after the selection, so the kept
    activations are raw encoder outputs and may be negative.

    Args:
        input_dim: Size of each input vector.
        hidden_dim: Number of hidden (dictionary) features.
        k: Number of hidden activations kept per sample.

    Raises:
        ValueError: If ``k`` is not between 1 and ``hidden_dim``.
    """

    def __init__(self, input_dim=30, hidden_dim=120, k=6):
        super().__init__()
        # Fail fast: torch.topk would otherwise only reject a bad k at the first forward pass
        if not 1 <= k <= hidden_dim:
            raise ValueError(
                f"k must be between 1 and hidden_dim ({hidden_dim}), got {k}"
            )
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.k = k  # Number of active features (K=6 gives ~5% sparsity for 120D)

        # Encoder: input -> hidden
        self.encoder = nn.Linear(input_dim, hidden_dim)

        # Decoder: hidden -> output
        self.decoder = nn.Linear(hidden_dim, input_dim)

        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize weights with Xavier uniform initialization and biases with zeros."""
        nn.init.xavier_uniform_(self.encoder.weight)
        nn.init.zeros_(self.encoder.bias)
        nn.init.xavier_uniform_(self.decoder.weight)
        nn.init.zeros_(self.decoder.bias)

    def encode(self, x):
        """
        Encode input to sparse hidden representation.

        Args:
            x: Input tensor of shape [batch_size, input_dim]

        Returns:
            Tensor of shape [batch_size, hidden_dim] that is zero everywhere
            except at each row's K largest encoder outputs.
        """
        h = self.encoder(x)  # [batch_size, hidden_dim]

        # topk_vals / topk_indices: [batch_size, k] - values and positions of the K largest entries
        topk_vals, topk_indices = torch.topk(h, self.k, dim=-1)

        # Write the kept values back into an all-zero tensor of the same shape
        return torch.zeros_like(h).scatter_(-1, topk_indices, topk_vals)

    def forward(self, x):
        """
        Forward pass with TopK sparsity.

        Args:
            x: Input tensor of shape [batch_size, input_dim]

        Returns:
            reconstructed: Reconstructed output [batch_size, input_dim]
            h_sparse: Sparse hidden activations [batch_size, hidden_dim]
        """
        # Single source of truth for the sparse code: reuse encode()
        h_sparse = self.encode(x)
        reconstructed = self.decoder(h_sparse)  # [batch_size, input_dim]
        return reconstructed, h_sparse

# Build the SAE with the notebook's configuration and place it on the compute device
sae_config = dict(input_dim=30, hidden_dim=120, k=6)
model = TopKSAE(**sae_config).to(device)
print(f"Model initialized:")
print(f"  Input dim: {model.input_dim}")
print(f"  Hidden dim: {model.hidden_dim}")
print(f"  K (sparsity): {model.k} ({model.k/model.hidden_dim*100:.1f}% active)")
print(f"  Total parameters: {sum(p.numel() for p in model.parameters())}")

# Smoke test: push the first five training rows through the untrained model
test_input = X_train_tensor[:5]
with torch.no_grad():
    test_recon, test_hidden = model(test_input)

print(f"\nTest forward pass:")
print(f"  Input shape: {test_input.shape}")
print(f"  Hidden shape: {test_hidden.shape}")
print(f"  Reconstructed shape: {test_recon.shape}")
print(f"  Hidden sparsity: {(test_hidden == 0).float().mean().item()*100:.1f}%")


## Training Loop


In [None]:
# Training loop: 100 epochs, Adam optimizer, lr=1e-3
#
# Reported epoch losses are per-sample means: each batch loss is weighted by
# its batch size, so a short final batch does not skew the average.

# Setup
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# Training history (one entry per epoch)
train_losses = []
val_losses = []

# Batch size
batch_size = 64
n_epochs = 100

# Fail with a clear message instead of a ZeroDivisionError when a split is empty
if len(X_train_tensor) == 0 or len(X_val_tensor) == 0:
    raise ValueError("Training and validation sets must both be non-empty")

print(f"Starting training for {n_epochs} epochs...")
print(f"Batch size: {batch_size}")
print(f"Learning rate: {optimizer.param_groups[0]['lr']}")

for epoch in range(n_epochs):
    # Training phase
    model.train()
    train_loss = 0.0

    # Shuffle training data (fresh permutation every epoch)
    indices = torch.randperm(len(X_train_tensor))

    for i in range(0, len(X_train_tensor), batch_size):
        batch_indices = indices[i:i+batch_size]
        batch_x = X_train_tensor[batch_indices]

        # Forward pass
        optimizer.zero_grad()
        reconstructed, h_sparse = model(batch_x)

        # Loss: MSE between input and reconstruction
        loss = criterion(reconstructed, batch_x)

        # Backward pass
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; re-weight by batch size
        train_loss += loss.item() * len(batch_x)

    avg_train_loss = train_loss / len(X_train_tensor)
    train_losses.append(avg_train_loss)

    # Validation phase
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for i in range(0, len(X_val_tensor), batch_size):
            batch_x = X_val_tensor[i:i+batch_size]
            reconstructed, h_sparse = model(batch_x)
            loss = criterion(reconstructed, batch_x)
            val_loss += loss.item() * len(batch_x)

    avg_val_loss = val_loss / len(X_val_tensor)
    val_losses.append(avg_val_loss)

    # Print progress every 10 epochs
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{n_epochs}: Train Loss = {avg_train_loss:.6f}, Val Loss = {avg_val_loss:.6f}")

print("\nTraining complete!")

# Save model (weights after the final epoch)
torch.save(model.state_dict(), 'topk_sae.pt')
print("Model saved to 'topk_sae.pt'")


## Training Loss Visualization


In [None]:
# Training vs. validation reconstruction loss per epoch
loss_fig, loss_ax = plt.subplots(figsize=(10, 6))
for curve, curve_label in ((train_losses, 'Train Loss'), (val_losses, 'Val Loss')):
    loss_ax.plot(curve, label=curve_label, linewidth=2)
loss_ax.set_xlabel('Epoch', fontsize=12)
loss_ax.set_ylabel('MSE Loss', fontsize=12)
loss_ax.set_title('TopK SAE Training Loss', fontsize=14, fontweight='bold')
loss_ax.legend(fontsize=11)
loss_ax.grid(True, alpha=0.3)
loss_fig.tight_layout()
plt.show()

# Last-epoch values for quick reference
print(f"Final train loss: {train_losses[-1]:.6f}")
print(f"Final val loss: {val_losses[-1]:.6f}")


## Feature Extraction: Get Sparse Hidden Activations


In [1]:
# Encode every normalized latent vector into its sparse SAE representation
model.eval()

# Chunked inference keeps peak memory bounded
batch_size_extract = 256
# NOTE: this rebinds `all_latents`, which the loading cell used for its list of per-system arrays
all_latents = torch.FloatTensor(latents_normalized).to(device)

with torch.no_grad():
    feature_chunks = [
        model.encode(all_latents[start:start + batch_size_extract]).cpu().numpy()  # [chunk, 120]
        for start in range(0, len(all_latents), batch_size_extract)
    ]

# Stitch the chunks back together in the original sample order
all_features = np.concatenate(feature_chunks, axis=0)
print(f"Extracted features shape: {all_features.shape}")
print(f"Feature sparsity: {(all_features == 0).sum() / all_features.size * 100:.1f}%")
print(f"Average active features per sample: {(all_features != 0).sum(axis=1).mean():.2f} (expected: {model.k})")


NameError: name 'model' is not defined

## Feature Analysis: Correlations with RMSD


In [2]:
# Calculate Spearman correlations between each feature and RMSD
#
# A feature whose activation is constant over all samples (e.g. never selected
# by TopK) has no defined rank correlation. Such features keep NaN for both
# correlation and p-value and are ranked last, so they cannot show up as
# "top" features.
n_features = all_features.shape[1]
feature_correlations = np.full(n_features, np.nan)
feature_pvalues = np.full(n_features, np.nan)

for feat_idx in range(n_features):
    activations = all_features[:, feat_idx]
    if activations.min() == activations.max():
        # Constant column: skip spearmanr and leave the NaN placeholders
        continue
    corr, pval = spearmanr(activations, rmsd)
    feature_correlations[feat_idx] = corr
    feature_pvalues[feat_idx] = pval

# Sort features by absolute correlation, strongest first.
# np.argsort places NaN at the end, which would put it FIRST after reversing,
# so undefined correlations are mapped to -inf before sorting.
abs_correlations = np.abs(feature_correlations)
ranking_key = np.where(np.isnan(abs_correlations), -np.inf, abs_correlations)
sorted_indices = np.argsort(ranking_key)[::-1]

n_undefined = int(np.isnan(feature_correlations).sum())
print(f"Features with undefined correlation (constant activation): {n_undefined} out of {n_features}")

print("Top 10 features by absolute correlation with RMSD:")
print("Feature | Correlation | P-value")
print("-" * 40)
for idx in sorted_indices[:10]:
    print(f"  {idx:3d}   |  {feature_correlations[idx]:+7.4f}   | {feature_pvalues[idx]:.2e}")

# Features with significant correlations (p < 0.05, not corrected for multiple
# comparisons); NaN p-values compare False and are excluded
significant_features = np.where(feature_pvalues < 0.05)[0]
print(f"\nSignificant features (p < 0.05): {len(significant_features)} out of {all_features.shape[1]}")


NameError: name 'all_features' is not defined

## Visualizations: Feature Activation Heatmap


In [None]:
# Feature activation heatmap: top 10 features, sorted by RMSD
top_n_features = 10
top_feature_indices = sorted_indices[:top_n_features]

# Sort samples by RMSD and take an evenly strided subset
sorted_by_rmsd = np.argsort(rmsd)
n_samples_plot = min(500, len(rmsd))  # Plot up to 500 samples for clarity
sample_indices = sorted_by_rmsd[::max(1, len(rmsd)//n_samples_plot)][:n_samples_plot]

# Extract feature activations for top features and selected samples
heatmap_data = all_features[np.ix_(sample_indices, top_feature_indices)]

# Create heatmap (features on the y axis, samples on the x axis).
# Tick labels are set by hand below: a short label list passed to
# sns.heatmap would be attached to the first columns only.
plt.figure(figsize=(12, 8))
heatmap_ax = sns.heatmap(heatmap_data.T,
                         xticklabels=False,
                         yticklabels=[f"Feat {idx}" for idx in top_feature_indices],
                         cmap='RdYlBu_r', center=0,
                         cbar_kws={'label': 'Activation'})

# Label roughly 20 evenly spaced columns with the RMSD of the sample they show;
# +0.5 centres each tick on its heatmap cell
tick_stride = max(1, len(sample_indices) // 20)
tick_columns = np.arange(0, len(sample_indices), tick_stride)
heatmap_ax.set_xticks(tick_columns + 0.5)
heatmap_ax.set_xticklabels(
    [f"RMSD={rmsd[sample_indices[col]]:.2f}" for col in tick_columns],
    rotation=45, ha='right',
)

plt.xlabel('Samples (sorted by RMSD)', fontsize=11)
plt.ylabel('Top Features (by |correlation|)', fontsize=11)
plt.title('Top 10 SAE Feature Activations vs RMSD', fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()


## Visualizations: Scatter Plots of Top 3 Features vs RMSD


In [None]:
# One scatter panel per feature for the three strongest RMSD correlates
top_3_features = sorted_indices[:3]

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for ax, feat_idx in zip(axes, top_3_features):
    # Points are coloured by RMSD, which is also the y axis
    scatter = ax.scatter(
        all_features[:, feat_idx],
        rmsd,
        c=rmsd,
        cmap='viridis',
        alpha=0.5,
        s=10,
    )
    ax.set_xlabel(f'Feature {feat_idx} Activation', fontsize=11)
    ax.set_ylabel('RMSD (Å)', fontsize=11)
    ax.set_title(
        f'Feature {feat_idx}\n(corr={feature_correlations[feat_idx]:.3f})',
        fontsize=11,
        fontweight='bold',
    )
    ax.grid(True, alpha=0.3)
    plt.colorbar(scatter, ax=ax, label='RMSD (Å)')

plt.suptitle('Top 3 SAE Features vs RMSD', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()


## Baseline Classifier: Logistic Regression on SAE Features


In [None]:
# Train logistic regression on SAE features
# Binary classification: RMSD < 2Å (good) vs RMSD >= 2Å (poor)

# Prepare data
is_good_pose = (rmsd < 2.0).astype(int)

# Split features into train/val (same split as SAE training)
X_train_features = all_features[:len(X_train)]
X_val_features = all_features[len(X_train):len(X_train)+len(X_val)]
y_train_binary = is_good_pose[:len(X_train)]
y_val_binary = is_good_pose[len(X_train):len(X_train)+len(X_val)]

print(f"Training logistic regression on SAE features...")
print(f"Train: {len(X_train_features)} samples, {y_train_binary.sum()} good poses")
print(f"Val: {len(X_val_features)} samples, {y_val_binary.sum()} good poses")

# Train logistic regression
lr_sae = LogisticRegression(max_iter=1000, random_state=42)
lr_sae.fit(X_train_features, y_train_binary)

# Predictions
y_train_pred_sae = lr_sae.predict_proba(X_train_features)[:, 1]
y_val_pred_sae = lr_sae.predict_proba(X_val_features)[:, 1]

# Calculate auPR
train_aupr_sae = average_precision_score(y_train_binary, y_train_pred_sae)
val_aupr_sae = average_precision_score(y_val_binary, y_val_pred_sae)

print(f"\nSAE Features Classifier:")
print(f"  Train auPR: {train_aupr_sae:.4f}")
print(f"  Val auPR: {val_aupr_sae:.4f}")


## Baseline Comparison: Logistic Regression on Raw Latent Vectors


In [None]:
# Train logistic regression on raw 30D latent vectors (no SAE) for comparison
X_train_raw = latents_normalized[:len(X_train)]
X_val_raw = latents_normalized[len(X_train):len(X_train)+len(X_val)]

print(f"Training logistic regression on raw latent vectors...")
print(f"Train: {len(X_train_raw)} samples")
print(f"Val: {len(X_val_raw)} samples")

# Train logistic regression
lr_raw = LogisticRegression(max_iter=1000, random_state=42)
lr_raw.fit(X_train_raw, y_train_binary)

# Predictions
y_train_pred_raw = lr_raw.predict_proba(X_train_raw)[:, 1]
y_val_pred_raw = lr_raw.predict_proba(X_val_raw)[:, 1]

# Calculate auPR
train_aupr_raw = average_precision_score(y_train_binary, y_train_pred_raw)
val_aupr_raw = average_precision_score(y_val_binary, y_val_pred_raw)

print(f"\nRaw Latent Vectors Classifier:")
print(f"  Train auPR: {train_aupr_raw:.4f}")
print(f"  Val auPR: {val_aupr_raw:.4f}")


## Results Summary


In [None]:
# Print comprehensive results summary
# Dimensions are read from the model, so the report stays correct if the
# architecture is reconfigured.
print("=" * 60)
print("RESULTS SUMMARY")
print("=" * 60)

print(f"\n1. MODEL ARCHITECTURE:")
print(f"   - Input dimension: {model.input_dim}D")
print(f"   - Hidden dimension: {model.hidden_dim}D")
print(f"   - TopK sparsity: K={model.k} ({model.k/model.hidden_dim*100:.1f}% active)")
print(f"   - Final train loss: {train_losses[-1]:.6f}")
print(f"   - Final val loss: {val_losses[-1]:.6f}")

print(f"\n2. TOP CORRELATED FEATURES WITH RMSD:")
print(f"   (Features with highest |Spearman correlation|)")
for idx in sorted_indices[:10]:
    # "*" marks features whose p-value is below 0.05
    sig = "*" if feature_pvalues[idx] < 0.05 else " "
    print(f"   {sig} Feature {idx:3d}: corr = {feature_correlations[idx]:+7.4f}, p = {feature_pvalues[idx]:.2e}")

print(f"\n3. CLASSIFICATION PERFORMANCE (auPR):")
print(f"   Raw {model.input_dim}D Latent Vectors:")
print(f"     Train auPR: {train_aupr_raw:.4f}")
print(f"     Val auPR:   {val_aupr_raw:.4f}")
print(f"   SAE Features ({model.hidden_dim}D, TopK={model.k}):")
print(f"     Train auPR: {train_aupr_sae:.4f}")
print(f"     Val auPR:   {val_aupr_sae:.4f}")
# The relative change is undefined when the raw baseline scores 0
if val_aupr_raw > 0:
    relative_change = f"{(val_aupr_sae / val_aupr_raw - 1) * 100:+.1f}%"
else:
    relative_change = "n/a"
print(f"   Improvement: {val_aupr_sae - val_aupr_raw:+.4f} ({relative_change})")

print(f"\n4. DATA STATISTICS:")
print(f"   Total samples: {len(latents)}")
print(f"   Good poses (RMSD < 2Å): {(rmsd < 2.0).sum()} ({(rmsd < 2.0).mean()*100:.1f}%)")
print(f"   Poor poses (RMSD >= 2Å): {(rmsd >= 2.0).sum()} ({(rmsd >= 2.0).mean()*100:.1f}%)")
print(f"   RMSD range: {rmsd.min():.2f} - {rmsd.max():.2f} Å")

print("\n" + "=" * 60)
