# Sparse Autoencoder Interpretability Analysis

This notebook provides an interactive environment for analyzing a Sparse Autoencoder (SAE) trained on protein-ligand docking poses. We'll explore:

1. **Feature Analysis**: Understanding which latent features correlate with pose quality
2. **Sparsity Patterns**: Analyzing the sparsity structure learned by the SAE
3. **Visualization**: Exploring the latent space and feature relationships
4. **Comparison**: Comparing SAE performance with traditional methods like PCA
5. **Biological Interpretation**: Connecting learned features to molecular properties
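For the feature-analysis step, a simple, model-free way to measure how strongly each feature tracks pose quality is a per-feature ROC AUC computed with scikit-learn's `roc_auc_score`. The sketch below is one possible approach, not necessarily the one used by the project's `analysis` module. It assumes a feature matrix of shape `(n_samples, n_features)` and binary labels (1 = good pose); `rank_features_by_auc` is a hypothetical helper. Because an AUC below 0.5 is as informative as one above it (the feature is anti-correlated with quality), the score folds the AUC around 0.5.

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def rank_features_by_auc(features: np.ndarray, labels: np.ndarray):
    """Hypothetical helper: score each feature by how well it separates the classes.

    Returns (order, scores): feature indices sorted from most to least
    discriminative, and the per-feature score |AUC - 0.5| * 2 in [0, 1].
    """
    scores = np.empty(features.shape[1])
    for j in range(features.shape[1]):
        auc = roc_auc_score(labels, features[:, j])
        # Fold around 0.5 so anti-correlated features also score highly.
        scores[j] = abs(auc - 0.5) * 2
    order = np.argsort(scores)[::-1]
    return order, scores


# Toy data: feature 2 is shifted upward for good poses; the others are noise.
rng = np.random.default_rng(0)
labels = rng.integers(0, 2, size=500)
features = rng.normal(size=(500, 5))
features[:, 2] += 2.0 * labels
order, scores = rank_features_by_auc(features, labels)
print(order[0], scores.round(2))
```

A univariate score like this ignores interactions between features, so it complements rather than replaces multivariate models such as logistic regression or random forests.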

## Scientific Motivation

The goal is to identify which features in the 30-dimensional latent space are most predictive of successful protein-ligand docking poses (RMSD < 2 Å) versus failures (RMSD ≥ 3 Å). This will help:
- Filter poor poses early in VAE-diffusion docking pipelines
- Understand what makes a docking pose successful
- Guide future model development and feature engineering
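The two RMSD thresholds translate directly into binary quality labels. The sketch below assumes that poses falling between 2 Å and 3 Å are treated as ambiguous and excluded from the binary task; the notebook does not state how that band is handled, and `label_poses` is a hypothetical helper rather than a function from `data_loader`.

```python
import numpy as np


def label_poses(rmsd, good_cutoff=2.0, bad_cutoff=3.0):
    """Hypothetical helper: map RMSD values (in Å) to binary quality labels.

    Returns (labels, mask): labels is 1 for good poses (RMSD < good_cutoff)
    and 0 otherwise; mask is False for the ambiguous band between the two
    cutoffs, which this sketch assumes is dropped from the binary task.
    """
    rmsd = np.asarray(rmsd, dtype=float)
    good = rmsd < good_cutoff
    bad = rmsd >= bad_cutoff
    mask = good | bad
    labels = good.astype(int)  # ambiguous entries get 0 here but are masked out
    return labels, mask


labels, mask = label_poses([0.8, 1.9, 2.5, 3.0, 6.1])
print(labels[mask].tolist(), mask.tolist())
# → [1, 1, 0, 0] [True, True, False, True, True]
```

Keeping a gap between the two cutoffs avoids forcing borderline poses into either class, at the cost of discarding some data.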


In [None]:
# Import required libraries
import sys
import os
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import yaml
import json
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
from sklearn.linear_model import LogisticRegression
import warnings
warnings.filterwarnings('ignore')

# Add src directory to path
sys.path.append(str(Path.cwd().parent / "src"))

from data_loader import load_docking_data, create_sample_data
from model import SparseAutoencoder, create_model
from analysis import SAEAnalyzer
from utils import set_seed, get_device, load_config

# Set up plotting
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 12

print("Libraries imported successfully!")


## 1. Load Data and Model

First, let's load the configuration, data, and trained model.


In [None]:
# Load configuration
config_path = "../configs/config.yaml"
config = load_config(config_path)
print("Configuration loaded:")
print(f"Model hidden dim: {config['model']['hidden_dim']}")
print(f"Sparsity lambda: {config['model']['sparsity_lambda']}")
print(f"Data path: {config['data']['data_path']}")
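The `sparsity_lambda` parameter weights the sparsity penalty in the SAE training objective. A common sparse-autoencoder formulation, assumed here rather than confirmed by the project's `model` module, adds an L1 penalty on the hidden activations to the reconstruction error:

$$
\mathcal{L}(\mathbf{x}) = \lVert \mathbf{x} - \hat{\mathbf{x}} \rVert_2^2 + \lambda \lVert \mathbf{h} \rVert_1,
\qquad \mathbf{h} = \mathrm{ReLU}(W_e \mathbf{x} + \mathbf{b}_e),
\qquad \hat{\mathbf{x}} = W_d \mathbf{h} + \mathbf{b}_d
$$

Here $\mathbf{x}$ is the SAE input (presumably the 30-dimensional latent vector), $\mathbf{h}$ is the hidden representation of size `hidden_dim`, and $\lambda$ is `sparsity_lambda`. Under this formulation, a larger $\lambda$ pushes more hidden units toward zero, trading reconstruction accuracy for sparsity.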


In [None]:
# Load data
data_path = config['data']['data_path']
if not os.path.exists(data_path):
    print("Creating sample data...")
    create_sample_data(data_path)

train_loader, val_loader, test_loader, scaler = load_docking_data(data_path, config['data'])
print("Data loaded successfully!")
print(f"Train samples: {len(train_loader.dataset)}")
print(f"Validation samples: {len(val_loader.dataset)}")
print(f"Test samples: {len(test_loader.dataset)}")


In [None]:
# Load the trained model (train one first, or point model_path at an existing checkpoint)
model_path = "../models/best_model.pt"  # Update this path as needed
device = get_device()  # Defined outside the check so later cells can always use it

if not os.path.exists(model_path):
    # Stop here with a clear message instead of failing later with a NameError on `model`
    raise FileNotFoundError(
        f"Model not found at {model_path}. "
        "Train a model first by running: python main.py --mode train"
    )

checkpoint = torch.load(model_path, map_location=device)
model = create_model(config['model'])
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()  # Inference mode for all downstream analysis
print(f"Model loaded from: {model_path}")
print(f"Training completed at epoch: {checkpoint.get('epoch', 'Unknown')}")
print(f"Best validation loss: {checkpoint.get('best_val_loss', 'Unknown')}")
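The loading cell expects a checkpoint dictionary with a `model_state_dict` entry and, optionally, `epoch` and `best_val_loss` entries. If you are supplying a model trained elsewhere, a checkpoint in that layout could be written and read back as in this sketch. The small `nn.Sequential` network is a hypothetical stand-in for `SparseAutoencoder`, its layer sizes are illustrative, and the epoch and loss values are placeholders.

```python
import io

import torch
import torch.nn as nn

# Hypothetical stand-in for SparseAutoencoder: 30-D input, 64 hidden units.
toy_model = nn.Sequential(nn.Linear(30, 64), nn.ReLU(), nn.Linear(64, 30))

# Save a checkpoint with the keys the loading cell reads.
buffer = io.BytesIO()  # in practice, a file path such as "../models/best_model.pt"
torch.save(
    {"model_state_dict": toy_model.state_dict(), "epoch": 42, "best_val_loss": 0.123},
    buffer,
)

# Load it the same way the notebook does, then restore the weights.
buffer.seek(0)
checkpoint = torch.load(buffer, map_location="cpu")
restored = nn.Sequential(nn.Linear(30, 64), nn.ReLU(), nn.Linear(64, 30))
restored.load_state_dict(checkpoint["model_state_dict"])
restored.eval()  # disable training-only behaviour before analysis
print(checkpoint["epoch"], checkpoint["best_val_loss"])
```

Storing only the `state_dict` (rather than the pickled module) keeps the checkpoint loadable as long as the architecture built by `create_model` matches.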


## 2. Feature Extraction and Basic Analysis

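Next, we pass the test set through the trained SAE to extract its hidden (sparse) feature activations, together with each pose's GA ranking, RMSD, and quality label, and then examine how sparse those activations are.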


In [None]:
# Create analyzer
analyzer = SAEAnalyzer(model, device)

# Extract features from test set
print("Extracting features from test set...")
latents, hidden_features, ga_rankings, rmsd_values, quality_labels = analyzer.extract_features(test_loader)

print(f"Extracted features shape: {hidden_features.shape}")
print(f"Quality distribution: {np.bincount(quality_labels.astype(int))}")
print(f"RMSD range: {rmsd_values.min():.2f} - {rmsd_values.max():.2f} Å")
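`SAEAnalyzer.extract_features` is project code whose internals are not shown here. Conceptually, feature extraction amounts to running the encoder over every batch without gradient tracking and stacking the hidden activations. The sketch below illustrates that pattern with a hypothetical `ToySAE` encoder and a synthetic `TensorDataset`; the batch layout `(latents, labels)` is an assumption, not the format produced by `load_docking_data`.

```python
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class ToySAE(nn.Module):
    """Hypothetical stand-in for SparseAutoencoder: ReLU encoder, linear decoder."""

    def __init__(self, input_dim=30, hidden_dim=64):
        super().__init__()
        self.encoder = nn.Linear(input_dim, hidden_dim)
        self.decoder = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        return torch.relu(self.encoder(x))


@torch.no_grad()
def extract_hidden(model, loader, device="cpu"):
    """Collect hidden activations and labels for every batch in `loader`."""
    model.to(device).eval()
    hidden, labels = [], []
    for x, y in loader:  # assumes each batch is (latents, labels)
        hidden.append(model.encode(x.to(device)).cpu().numpy())
        labels.append(y.numpy())
    return np.concatenate(hidden), np.concatenate(labels)


data = TensorDataset(torch.randn(100, 30), torch.randint(0, 2, (100,)))
h, y = extract_hidden(ToySAE(), DataLoader(data, batch_size=32))
print(h.shape, y.shape)  # → (100, 64) (100,)
```

Moving each batch back to the CPU before converting to NumPy keeps GPU memory flat regardless of test-set size.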


In [None]:
# Analyze sparsity patterns
sparsity_analysis = analyzer.analyze_sparsity(hidden_features, threshold=0.1)

print("Sparsity Analysis:")
print(f"Overall sparsity: {sparsity_analysis['overall_sparsity']:.3f}")
print(f"Most sparse features: {sparsity_analysis['most_sparse_features']}")
print(f"Least sparse features: {sparsity_analysis['least_sparse_features']}")

# Plot sparsity distribution
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Sparsity per sample
axes[0].hist(sparsity_analysis['sparsity_per_sample'], bins=50, alpha=0.7, edgecolor='black')
axes[0].set_xlabel('Sparsity (fraction of inactive features)')
axes[0].set_ylabel('Number of samples')
axes[0].set_title('Sparsity Distribution Across Samples')
axes[0].grid(True, alpha=0.3)

# Sparsity per feature
axes[1].hist(sparsity_analysis['sparsity_per_feature'], bins=50, alpha=0.7, edgecolor='black')
axes[1].set_xlabel('Sparsity (fraction of inactive samples)')
axes[1].set_ylabel('Number of features')
axes[1].set_title('Sparsity Distribution Across Features')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
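The axis labels above define sparsity as a fraction of inactive entries: per sample, the share of features that are off; per feature, the share of samples in which it is off. The sketch below computes those quantities directly with NumPy so the plotted values can be sanity-checked. It assumes a unit is inactive when its absolute activation is strictly below `threshold`; how `SAEAnalyzer.analyze_sparsity` treats the boundary, and how it selects its most and least sparse features, may differ. `sparsity_stats` is a hypothetical name.

```python
import numpy as np


def sparsity_stats(hidden, threshold=0.1, top_k=5):
    """Hypothetical re-implementation of the sparsity summaries plotted above.

    A unit counts as inactive when |activation| < threshold (an assumption).
    """
    inactive = np.abs(hidden) < threshold   # (n_samples, n_features) boolean mask
    per_sample = inactive.mean(axis=1)      # fraction of inactive features per sample
    per_feature = inactive.mean(axis=0)     # fraction of inactive samples per feature
    order = np.argsort(per_feature)         # least sparse feature first
    return {
        "overall_sparsity": float(inactive.mean()),
        "sparsity_per_sample": per_sample,
        "sparsity_per_feature": per_feature,
        "most_sparse_features": order[::-1][:top_k],
        "least_sparse_features": order[:top_k],
    }


hidden = np.array([
    [0.0, 0.5, 0.05],
    [0.0, 0.9, 0.20],
])
stats = sparsity_stats(hidden, threshold=0.1, top_k=1)
print(stats["overall_sparsity"], stats["sparsity_per_feature"])
```

A feature that is inactive for every sample (sparsity 1.0, like feature 0 in the toy data) is a "dead" unit and carries no information for the quality analysis.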
