# Ensemble models analysis

This notebook provides interactive analysis of pre-trained ensemble models for combinatorial perturbation prediction.

## 1. Setup and Imports

In [1]:
# Numerical, plotting and utility imports used by the cells below.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr
import os
import json
import warnings
import torch
# Suppresses every warning category for the rest of the kernel session.
warnings.filterwarnings('ignore')

# ensemble framework
# NOTE(review): the recorded run of this cell ended with
# "ModuleNotFoundError: No module named 'ensemble'", so these two modules
# (presumably project-local) must be importable from the kernel's sys.path
# before this cell can succeed.
from ensemble import Ensemble
from ensemble_analyze import EnsembleAnalyzer

# Plot appearance: matplotlib's default style with seaborn's "husl" palette.
plt.style.use('default')
sns.set_palette("husl")
%matplotlib inline

ModuleNotFoundError: No module named 'ensemble'

In [3]:
import os

# Show the kernel's working directory; the relative OUTPUT_DIR defined in the
# configuration cell resolves against it.
current_dir = os.getcwd()
print(current_dir)

/content


In [2]:
# Root locations. The defaults are the original author's local paths; each one
# can be overridden with an environment variable so the notebook also runs on
# other machines (the recorded working directory is /content) without editing
# this cell.
SCLAMBDA_REPO = os.environ.get('SCLAMBDA_REPO', '/Users/rikac/Documents/scLAMBDA')
DATA_DIR = os.environ.get('COMBO_PERT_DATA_DIR', '/Users/rikac/Documents/ml_stats/combo-pert-pred/data')
MODEL_DIR = os.environ.get('COMBO_PERT_MODEL_DIR', '/Users/rikac/Documents/ml_stats/combo-pert-pred/models')

# model paths
GEARS_MODEL_DIR = f'{MODEL_DIR}/gears_model'
GEARS_DATA_PATH = f'{MODEL_DIR}/gears_data'
SCLAMBDA_MODEL_PATH = f'{MODEL_DIR}/sclambda_model'
SCLAMBDA_EMBEDDINGS_PATH = f'{DATA_DIR}/GPT_3_5_gene_embeddings_3-large.pickle'
NORMAN_DATA_PATH = f'{DATA_DIR}/norman_perturbseq_preprocessed_hvg_filtered.h5ad'
# scLAMBDA reads the same preprocessed Norman file; alias it instead of
# repeating the literal so the two paths cannot drift apart.
SCLAMBDA_ADATA_PATH = NORMAN_DATA_PATH

# Results directory (relative to the kernel's working directory); created if absent.
OUTPUT_DIR = '../ensemble_results'
os.makedirs(OUTPUT_DIR, exist_ok=True)

## 2. Load pre-trained models

This loads:
- pre-trained GEARS model
- pre-trained scLAMBDA model
- baseline models

In [None]:
print('Loading ensemble models...')
print('=' * 70)

# initialize ensemble
ensemble = Ensemble(sclambda_repo_path=SCLAMBDA_REPO)

# Fail fast, before the checkpoint is read from disk, when the imported
# Ensemble class lacks the loader this cell relies on. The recorded run of this
# cell ended with exactly that AttributeError; listing the loader-like
# attributes that do exist makes the version mismatch easy to diagnose.
load_models = getattr(ensemble, 'load_models_with_device', None)
if load_models is None:
    available_loaders = [name for name in dir(ensemble) if name.startswith('load')]
    raise AttributeError(
        "'Ensemble' object has no attribute 'load_models_with_device'; "
        "the imported ensemble module is probably out of date. "
        f"Loader-like attributes found: {available_loaders}"
    )

# need to change to cpu for macbook
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
sclambda_checkpoint = torch.load(f'{SCLAMBDA_MODEL_PATH}/ckpt.pth', map_location=torch.device('cpu'))

# load all models
load_models(
    gears_model_dir=GEARS_MODEL_DIR,
    gears_data_path=GEARS_DATA_PATH,
    gears_data_name='norman',
    sclambda_model_or_path=sclambda_checkpoint,
    sclambda_adata_path=SCLAMBDA_ADATA_PATH,
    sclambda_embeddings_path=SCLAMBDA_EMBEDDINGS_PATH,
    norman_data_path=NORMAN_DATA_PATH
)

print('\n All models loaded successfully!')
print(f'Number of genes: {len(ensemble.gene_names)}')
print(f'Single perturbations: {ensemble.X_single.shape[0]}')
print(f'Combo perturbations: {ensemble.X_combo.shape[0]}')

Loading ensemble models...


AttributeError: 'Ensemble' object has no attribute 'load_models_with_device'

## 4. Create Data Splits

Split the data for evaluation, GEARS-style: train on all single perturbations plus a subset of the combinations, and test on held-out combinations.

In [None]:
print("Creating data splits...")

# Build the split dictionary from the single and combo perturbation arrays
# held by the ensemble, with a fixed random_state.
splits = ensemble.data_processor.create_combo_splits(
    X_single=ensemble.X_single,
    y_single=ensemble.y_single,
    X_combo=ensemble.X_combo,
    y_combo=ensemble.y_combo,
    combo_test_ratio=0.2,
    random_state=42,
)

# Row counts reported below.
n_train = splits['X_train'].shape[0]
n_singles = ensemble.X_single.shape[0]
n_val = splits['X_val'].shape[0]
n_test = splits['X_test'].shape[0]
split_kind = splits['split_type']

print(f"\nSplit created:")
print(f"  Training samples: {n_train} (includes all {n_singles} singles)")
print(f"  Validation samples: {n_val}")
print(f"  Test samples: {n_test}")
print(f"  Split type: {split_kind}")

## 5. Initialize Analyzer

Create the analyzer object that will handle all evaluations and visualizations.

In [None]:
print("Initializing analyzer...")

# The analyzer receives the loaded ensemble, the split dictionary and the
# directory used for generated files.
analyzer = EnsembleAnalyzer(
    ensemble=ensemble,
    splits=splits,
    output_dir=OUTPUT_DIR
)

# The status marker was mis-encoded (UTF-8 bytes decoded as MacRoman); restored.
print("✅ Analyzer initialized")
print(f"   Output directory: {analyzer.output_dir}")

## 6. Evaluate Individual Models

Get predictions from all models and compute performance metrics.

In [None]:
# Evaluate all models
metrics = analyzer.evaluate_individual_models()

# Create a nice summary table: after the transpose each row is a model and
# each column a metric, rounded to six decimals for display.
metrics_df = pd.DataFrame(metrics).T
metrics_df = metrics_df.round(6)

print("\n" + "="*70)
print("PERFORMANCE SUMMARY")
print("="*70)
display(metrics_df)

# Highlight best performers (lowest error, highest correlation / R²).
# The emoji and the superscript in "R²" were mis-encoded; restored here.
print("\n🏆 Best Performers:")
print(f"   Lowest MSE: {metrics_df['mse'].idxmin()} ({metrics_df['mse'].min():.6f})")
print(f"   Highest Pearson r: {metrics_df['pearson_r'].idxmax()} ({metrics_df['pearson_r'].max():.6f})")
print(f"   Highest R²: {metrics_df['r2'].idxmax()} ({metrics_df['r2'].max():.6f})")

## 7. Model Comparison Visualization

Generate comprehensive comparison plots across all models.

In [None]:
from IPython.display import Image, display

analyzer.plot_model_comparison()

# Display the saved plot
comparison_png = f'{OUTPUT_DIR}/model_comparison.png'
display(Image(filename=comparison_png))

## 8. Uncertainty Distribution Analysis

Analyze epistemic uncertainty across test samples.

In [None]:
per_sample_uncertainty = analyzer.plot_uncertainty_distribution()

# Display the saved plot
display(Image(filename=f'{OUTPUT_DIR}/uncertainty_distribution.png'))

# Print uncertainty statistics
banner = "=" * 70
print("\n" + banner)
print("UNCERTAINTY STATISTICS")
print(banner)

# (label, reducer) pairs, printed in this order with four decimals each.
summary_stats = (
    ("Mean uncertainty", np.mean),
    ("Median uncertainty", np.median),
    ("Std uncertainty", np.std),
    ("Min uncertainty", np.min),
    ("Max uncertainty", np.max),
    ("95th percentile", lambda values: np.percentile(values, 95)),
)
for label, reducer in summary_stats:
    print(f"{label}: {reducer(per_sample_uncertainty):.4f}")

## 9. Uncertainty Calibration

Check whether high uncertainty correlates with high prediction error, which would indicate well-calibrated uncertainty.

In [None]:
analyzer.plot_uncertainty_vs_error()

# Display the saved plot
calibration_png = f'{OUTPUT_DIR}/uncertainty_vs_error.png'
display(Image(filename=calibration_png))

## 10. Generate Summary Report

Save comprehensive text and JSON summaries.

In [None]:
analyzer.save_summary()

# Display the summary
summary_path = f'{OUTPUT_DIR}/summary.txt'
with open(summary_path, 'r') as summary_file:
    summary_text = summary_file.read()
print(summary_text)

## 11. Interactive Exploration: Make Custom Predictions

Test the ensemble on specific perturbations of interest.

In [None]:
# Example: predict a specific combo perturbation
# Modify gene names below to test specific perturbations

gene1 = 'CBL'  # Replace with gene of interest
gene2 = 'CNN1'  # Replace with gene of interest

# Create perturbation vector (one row, one column per entry of gene_names)
X_custom = np.zeros((1, len(ensemble.gene_names)))

# Collect the requested genes that are absent so the error path can name them.
missing_genes = [gene for gene in (gene1, gene2) if gene not in ensemble.gene_names]

if not missing_genes:
    idx1 = ensemble.gene_names.index(gene1)
    idx2 = ensemble.gene_names.index(gene2)
    X_custom[0, [idx1, idx2]] = 1.0

    # Get predictions
    pred_mean, uncertainties, individual_preds = ensemble.predict_ensemble(X_custom)

    print(f"\nPrediction for: {gene1} + {gene2}")
    print("="*60)
    print(f"\nIndividual model predictions (mean across genes):")
    for model_name, preds in individual_preds.items():
        print(f"  {model_name:12s}: {np.mean(preds[0]):.4f}")

    print(f"\nEnsemble prediction (mean): {np.mean(pred_mean[0]):.4f}")
    print(f"Total uncertainty: {np.sum(uncertainties[0]):.4f}")
    print(f"Mean uncertainty per gene: {np.mean(uncertainties[0]):.6f}")

    # Visualize prediction distribution across genes
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Never request more bars than there are predicted values; otherwise the
    # bar positions and the bar widths would have different lengths.
    top_genes = min(20, len(pred_mean[0]))
    bar_positions = range(top_genes)

    # Plot 1: Predicted expression changes (largest absolute values last)
    sorted_idx = np.argsort(np.abs(pred_mean[0]))[-top_genes:]
    top_effects = pred_mean[0][sorted_idx]

    axes[0].barh(bar_positions, top_effects,
                color=['red' if x < 0 else 'blue' for x in top_effects])
    axes[0].set_yticks(bar_positions)
    axes[0].set_yticklabels([f'Gene {i}' for i in sorted_idx], fontsize=8)
    axes[0].set_xlabel('Predicted Expression Change', fontweight='bold')
    axes[0].set_title(f'Top {top_genes} Affected Genes: {gene1}+{gene2}', fontweight='bold')
    axes[0].axvline(0, color='black', linestyle='--', alpha=0.5)
    axes[0].grid(axis='x', alpha=0.3)

    # Plot 2: Uncertainty per gene (most uncertain last)
    sorted_unc_idx = np.argsort(uncertainties[0])[-top_genes:]

    axes[1].barh(bar_positions, uncertainties[0][sorted_unc_idx], color='orange')
    axes[1].set_yticks(bar_positions)
    axes[1].set_yticklabels([f'Gene {i}' for i in sorted_unc_idx], fontsize=8)
    axes[1].set_xlabel('Epistemic Uncertainty', fontweight='bold')
    axes[1].set_title(f'Top {top_genes} Most Uncertain Genes', fontweight='bold')
    axes[1].grid(axis='x', alpha=0.3)

    plt.tight_layout()
    plt.show()

else:
    print(f"Error: One or both genes not found in dataset")
    print(f"Missing: {', '.join(missing_genes)}")
    print(f"Available genes: {ensemble.gene_names[:10]}...")  # Show first 10

## 12. Experiment Recommendations

Identify high-priority experiments based on epistemic uncertainty.

In [None]:
# Get all test set predictions with uncertainties
_, test_uncertainties, test_individual_preds = ensemble.predict_ensemble(splits['X_test'])
# One score per test sample: uncertainties summed over axis 1.
test_uncertainty_scores = np.sum(test_uncertainties, axis=1)

# Find top uncertain samples (argsort is ascending, so the last entries are the largest)
n_recommend = 10
top_uncertain_idx = np.argsort(test_uncertainty_scores)[-n_recommend:]

# Emoji/arrow characters in the messages below were mis-encoded; restored.
print(f"\n🔬 TOP {n_recommend} RECOMMENDED EXPERIMENTS (Highest Uncertainty)")
print("="*70)

recommendations_data = []

# Walk the selection from most to least uncertain.
for rank, idx in enumerate(reversed(top_uncertain_idx), 1):
    perturbation = splits['X_test'][idx]
    # Positions with a positive entry identify the perturbed genes.
    perturbed_gene_idx = np.where(perturbation > 0)[0]
    perturbed_genes = [ensemble.gene_names[i] for i in perturbed_gene_idx]
    uncertainty = test_uncertainty_scores[idx]

    # Get model predictions (mean of each model's prediction for this sample)
    model_preds = {
        model_name: np.mean(preds[idx])
        for model_name, preds in test_individual_preds.items()
    }

    print(f"\n{rank:2d}. Perturbation: {' + '.join(perturbed_genes)}")
    print(f"    Uncertainty score: {uncertainty:.4f}")
    print(f"    Model predictions (mean expression):")
    for model_name, pred_val in model_preds.items():
        print(f"      {model_name:12s}: {pred_val:7.4f}")
    print(f"    → Model disagreement = high learning potential")

    recommendations_data.append({
        'rank': rank,
        'genes': ' + '.join(perturbed_genes),
        'uncertainty': uncertainty,
        **{f'{model_name}_pred': pred_val for model_name, pred_val in model_preds.items()}
    })

# Create recommendations dataframe
recommendations_df = pd.DataFrame(recommendations_data)
print("\n" + "="*70)
print("RECOMMENDATIONS SUMMARY")
print("="*70)
display(recommendations_df.round(4))

# Save recommendations
recommendations_df.to_csv(f'{OUTPUT_DIR}/experiment_recommendations.csv', index=False)
print(f"\n✅ Recommendations saved to: {OUTPUT_DIR}/experiment_recommendations.csv")

## 13. Model Agreement Analysis

Examine how well different models agree with each other.

In [None]:
# Compute pairwise correlations between models
base_models = ['gears', 'sclambda', 'mean', 'additive']

# Get ensemble prediction: element-wise, unweighted mean over the four base models
ensemble_pred = np.mean(
    np.stack([analyzer.predictions[name] for name in base_models], axis=0),
    axis=0,
)

# Single source of truth for the compared predictions; `models` is derived
# from it so labels and data cannot get out of sync.
all_preds = {name: analyzer.predictions[name] for name in base_models}
all_preds['ensemble'] = ensemble_pred
models = list(all_preds)  # base models followed by 'ensemble'

# Compute correlation matrix: every model's predictions are flattened into one
# row and np.corrcoef returns the Pearson correlation of each pair of rows in a
# single call (instead of n_models**2 separate pearsonr calls).
n_models = len(models)
flat_preds = np.stack([np.asarray(all_preds[name]).ravel() for name in models])
corr_matrix = np.corrcoef(flat_preds)

# Visualize correlation matrix
fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(corr_matrix, cmap='RdYlBu_r', vmin=0, vmax=1)

# Set ticks and labels
ax.set_xticks(np.arange(n_models))
ax.set_yticks(np.arange(n_models))
ax.set_xticklabels(models)
ax.set_yticklabels(models)

# Rotate the tick labels
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

# Add correlation values (imshow puts column j on the x axis, row i on the y axis)
for i in range(n_models):
    for j in range(n_models):
        ax.text(j, i, f'{corr_matrix[i, j]:.3f}',
                ha="center", va="center", color="black", fontweight='bold')

ax.set_title("Model Agreement: Pairwise Correlations", fontsize=14, fontweight='bold', pad=20)
fig.colorbar(im, ax=ax, label='Pearson Correlation')
plt.tight_layout()
plt.savefig(f'{OUTPUT_DIR}/model_agreement.png', dpi=300, bbox_inches='tight')
plt.show()

# Exclude the trivial self-correlations on the diagonal before searching for
# the most similar pair.
off_diagonal = corr_matrix.copy()
np.fill_diagonal(off_diagonal, -np.inf)
max_corr_idx = np.unravel_index(np.argmax(off_diagonal), off_diagonal.shape)

# The emoji in the heading was mis-encoded; restored.
print("\n📊 Model Agreement Analysis:")
print(f"   Average pairwise correlation: {np.mean(corr_matrix[np.triu_indices(n_models, k=1)]):.3f}")
print(f"   Most similar models: {models[max_corr_idx[0]]} vs {models[max_corr_idx[1]]} (r={corr_matrix[max_corr_idx]:.3f})")

## 14. Summary and Conclusions

Key findings from the analysis.

In [None]:
# Final recap. Relies on names created by earlier cells: metrics_df,
# per_sample_uncertainty, n_recommend, test_uncertainty_scores, ensemble_pred
# and splits. Emoji and "R²" in the messages were mis-encoded; restored.
print("\n" + "="*70)
print("🎉 ENSEMBLE ANALYSIS COMPLETE")
print("="*70)

ensemble_mse = metrics_df.loc['ensemble', 'mse']
ensemble_r2 = metrics_df.loc['ensemble', 'r2']

# Best model: the 'ensemble' row is excluded here. Taking the minimum over all
# rows (ensemble included) can never be larger than ensemble_mse, which made
# the "ensemble improves" branch below unreachable.
individual_mse = metrics_df['mse'].drop(index='ensemble')
best_model = individual_mse.idxmin()
best_mse = individual_mse.min()

print(f"\n📊 Overall Performance:")
print(f"   Best individual model: {best_model} (MSE: {best_mse:.6f})")
print(f"   Ensemble performance: MSE: {ensemble_mse:.6f}, R²: {ensemble_r2:.6f}")

if ensemble_mse < best_mse:
    # Relative MSE reduction of the ensemble versus the best individual model.
    improvement = (best_mse - ensemble_mse) / best_mse * 100
    print(f"   ✨ Ensemble improves over best individual by {improvement:.2f}%")

print(f"\n🔬 Uncertainty Insights:")
print(f"   Mean uncertainty: {np.mean(per_sample_uncertainty):.4f}")
print(f"   High-priority experiments identified: {n_recommend}")

# Compute uncertainty-error correlation: per-sample mean squared error of the
# averaged prediction against the per-sample summed uncertainty.
y_test = splits['y_test']
per_sample_error = np.mean((ensemble_pred - y_test) ** 2, axis=1)
unc_err_corr, p_val = pearsonr(test_uncertainty_scores, per_sample_error)

print(f"\n📈 Calibration:")
print(f"   Uncertainty-error correlation: {unc_err_corr:.3f} (p={p_val:.2e})")
if unc_err_corr > 0.3 and p_val < 0.05:
    print(f"   ✅ Well-calibrated: High uncertainty → High error")
elif unc_err_corr > 0 and p_val < 0.05:
    print(f"   ⚠️  Moderately calibrated")
else:
    print(f"   ⚠️  Uncertainty may need recalibration")

print(f"\n💾 Generated Files:")
print(f"   - {OUTPUT_DIR}/summary.txt")
print(f"   - {OUTPUT_DIR}/metrics.json")
print(f"   - {OUTPUT_DIR}/model_comparison.png")
print(f"   - {OUTPUT_DIR}/uncertainty_distribution.png")
print(f"   - {OUTPUT_DIR}/uncertainty_vs_error.png")
print(f"   - {OUTPUT_DIR}/model_agreement.png")
print(f"   - {OUTPUT_DIR}/experiment_recommendations.csv")

print("\n" + "="*70)
print("✅ Analysis notebook complete!")
print("="*70)