# ST-Tahoe Baseline Predictions Analysis

This notebook analyzes predictions from the pretrained ST-Tahoe model on burn/sham wound healing data.

## Important Notes

⚠️ **ST-Tahoe was trained on DRUG perturbations** (Basak dataset), not burn injury:
- Input: 2000 highly variable genes from cell line experiments
- Perturbations: Drug treatments (DMSO control)
- Task: Drug response prediction

🔬 **Our burn/sham data**:
- Input: 2000-dim SE-600M embeddings (truncated from 2058)
- Perturbations: Burn vs Sham injury
- Task: Wound healing trajectory prediction

**Expected result**: ST-Tahoe predictions should NOT meaningfully differentiate burn from sham, as the model was not trained on this biological context. This serves as a **negative control baseline** for comparison with fine-tuned models.

## 1. Setup

In [None]:
import sys
import os
from pathlib import Path

import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial.distance import cdist
from scipy.stats import pearsonr

# Add project root
project_root = Path.cwd().parent.parent
sys.path.insert(0, str(project_root))
os.chdir(project_root)

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.dpi'] = 150
plt.rcParams['figure.figsize'] = (12, 8)

print("✅ Environment ready")

## 2. Load ST-Tahoe Predictions

In [None]:
# Load predictions
pred_path = "experiments/st_fine_tuning/results/burn_sham_st_tahoe_predictions.h5ad"

if os.path.exists(pred_path):
    adata_pred = ad.read_h5ad(pred_path)
    print("✅ ST-Tahoe predictions loaded")
    print(f"   Shape: {adata_pred.shape}")
    print(f"   Predictions: {adata_pred.obsm['X_state_2000'].shape}")
    print(f"\n   Conditions: {adata_pred.obs['condition'].value_counts().to_dict()}")
    print(f"   Timepoints: {adata_pred.obs['timepoint'].value_counts().to_dict()}")
    print(f"   Cell types: {adata_pred.obs['cell_types_simple_short'].nunique()} types")
else:
    print("❌ Predictions not found!")
    print(f"   Expected path: {pred_path}")
    print("\n   Run inference first:")
    print("   state tx infer --model-dir models/ST-Tahoe ...")
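
The later cells assume that the loaded file exposes the `X_state_2000` embedding and three `obs` columns. A minimal sketch of an up-front check follows. `missing_fields` is a hypothetical helper, not part of `anndata` or the `state` CLI; it takes plain collections of names, so in this notebook it could be called as `missing_fields(adata_pred.obs.columns, adata_pred.obsm.keys())`.

```python
# Hypothetical helper (not part of anndata): report which of the fields
# used by the downstream cells are absent from the loaded predictions.
REQUIRED_OBS = ("condition", "timepoint", "cell_types_simple_short")
REQUIRED_OBSM = ("X_state_2000",)


def missing_fields(obs_columns, obsm_keys):
    """Return the required obs columns and obsm keys that are not present."""
    obs_columns = set(obs_columns)
    obsm_keys = set(obsm_keys)
    return {
        "obs": [c for c in REQUIRED_OBS if c not in obs_columns],
        "obsm": [k for k in REQUIRED_OBSM if k not in obsm_keys],
    }


if __name__ == "__main__":
    # Toy example in which 'timepoint' is missing from obs.
    report = missing_fields(["condition", "cell_types_simple_short"], ["X_state_2000"])
    print(report)
```

An empty list under both keys means the remaining cells should find everything they index into.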

## 3. Prediction Statistics

In [None]:
if 'adata_pred' in locals():
    pred = adata_pred.obsm['X_state_2000']
    
    # Overall statistics
    print("Overall Prediction Statistics:")
    print("=" * 60)
    print(f"Shape: {pred.shape}")
    print(f"Range: [{np.min(pred):.4f}, {np.max(pred):.4f}]")
    print(f"Mean: {np.mean(pred):.4f}")
    print(f"Std: {np.std(pred):.4f}")
    print(f"Sparsity: {(pred == 0).sum() / pred.size * 100:.2f}% zeros")
    
    # By condition
    print("\nBy Condition:")
    print("=" * 60)
    for condition in ['Burn', 'Sham']:
        mask = adata_pred.obs['condition'] == condition
        cond_pred = pred[mask]
        print(f"\n{condition}:")
        print(f"  Cells: {mask.sum()}")
        print(f"  Mean: {np.mean(cond_pred):.4f}")
        print(f"  Std: {np.std(cond_pred):.4f}")
        print(f"  Range: [{np.min(cond_pred):.4f}, {np.max(cond_pred):.4f}]")
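
Global means and standard deviations can look alike even when individual embedding dimensions differ between conditions. One way to look closer, sketched here under assumptions, is a per-dimension standardized mean difference (Cohen's d with a pooled standard deviation). The function name, the `eps` guard, and the synthetic arrays are illustrative and not part of the original analysis.

```python
import numpy as np


def standardized_mean_difference(a, b, eps=1e-12):
    """Per-dimension Cohen's d between two (cells x dims) arrays."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    mean_diff = a.mean(axis=0) - b.mean(axis=0)
    # Pooled variance from the two sample (ddof=1) variances.
    pooled_var = (
        (a.shape[0] - 1) * a.var(axis=0, ddof=1)
        + (b.shape[0] - 1) * b.var(axis=0, ddof=1)
    ) / (a.shape[0] + b.shape[0] - 2)
    # eps avoids division by zero for constant dimensions.
    return mean_diff / np.sqrt(pooled_var + eps)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Synthetic stand-ins for burn/sham predictions (not real data).
    burn = rng.normal(size=(500, 20))
    sham = rng.normal(size=(400, 20))
    sham[:, 0] += 2.0  # shift one dimension so that it stands out
    d = standardized_mean_difference(burn, sham)
    print("dimension with largest |d|:", int(np.argmax(np.abs(d))))
```

With the notebook's variables, the inputs would be `pred[burn_mask]` and `pred[sham_mask]`, assuming `pred` is a dense NumPy array.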

## 4. Burn vs Sham Comparison

In [None]:
if 'adata_pred' in locals():
    burn_mask = adata_pred.obs['condition'] == 'Burn'
    sham_mask = adata_pred.obs['condition'] == 'Sham'
    
    burn_pred = pred[burn_mask]
    sham_pred = pred[sham_mask]
    
    # Distance between conditions
    burn_centroid = burn_pred.mean(axis=0)
    sham_centroid = sham_pred.mean(axis=0)
    between_dist = np.linalg.norm(burn_centroid - sham_centroid)
    
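    # Within-condition distances (reference scale): mean distance to the
    # condition centroid, estimated from the first 100 cells of each condition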
    burn_dists = cdist(burn_pred[:100], [burn_centroid]).mean()
    sham_dists = cdist(sham_pred[:100], [sham_centroid]).mean()
    within_dist = (burn_dists + sham_dists) / 2
    
    print("Burn vs Sham Separation:")
    print("=" * 60)
    print(f"Between-condition distance: {between_dist:.4f}")
    print(f"Within-condition distance: {within_dist:.4f}")
    print(f"Separation ratio: {between_dist / within_dist:.4f}")
    print("\n⚠️  If ratio ≈ 0 (well below 1), burn and sham are NOT differentiated")
    print("   (Expected for ST-Tahoe as it wasn't trained on wound healing)")
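
Because the ratio divides the centroid-to-centroid distance by the typical distance of a cell to its own centroid, it tends toward 0 when both groups come from the same distribution and grows as the groups move apart. The sketch below illustrates this on synthetic Gaussian data. It is an illustration under assumptions, not the notebook's exact computation: `separation_ratio` is a hypothetical helper, and it draws a seeded random subsample instead of taking the first 100 rows, which may share a timepoint or cell type depending on how the file is ordered.

```python
import numpy as np


def separation_ratio(a, b, n_sample=100, seed=0):
    """Centroid distance between a and b over mean distance-to-own-centroid."""
    rng = np.random.default_rng(seed)
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    centroid_a, centroid_b = a.mean(axis=0), b.mean(axis=0)
    between = np.linalg.norm(centroid_a - centroid_b)

    def mean_dist_to_centroid(x, centroid):
        # Random subsample of up to n_sample rows, drawn without replacement.
        idx = rng.choice(x.shape[0], size=min(n_sample, x.shape[0]), replace=False)
        return np.linalg.norm(x[idx] - centroid, axis=1).mean()

    within = 0.5 * (
        mean_dist_to_centroid(a, centroid_a) + mean_dist_to_centroid(b, centroid_b)
    )
    return between / within


if __name__ == "__main__":
    rng = np.random.default_rng(42)
    same_a = rng.normal(size=(1000, 50))
    same_b = rng.normal(size=(1000, 50))         # same distribution as same_a
    shifted = rng.normal(size=(1000, 50)) + 1.0  # every dimension shifted by 1
    print(f"same distribution: {separation_ratio(same_a, same_b):.3f}")
    print(f"shifted:           {separation_ratio(same_a, shifted):.3f}")
```

In the notebook this could be called as `separation_ratio(burn_pred, sham_pred)`, assuming both are dense arrays.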

## 5. Visualizations

In [None]:
if 'adata_pred' in locals():
    # Compute UMAP
    sc.pp.neighbors(adata_pred, use_rep='X_state_2000', n_neighbors=15)
    sc.tl.umap(adata_pred)
    
    # Plot
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    
    # UMAP by condition
    sc.pl.umap(adata_pred, color='condition', ax=axes[0, 0], show=False, title='By Condition')
    
    # UMAP by timepoint
    sc.pl.umap(adata_pred, color='timepoint', ax=axes[0, 1], show=False, title='By Timepoint')
    
    # UMAP by cell type
    sc.pl.umap(adata_pred, color='cell_types_simple_short', ax=axes[0, 2], show=False, title='By Cell Type')
    
    # Magnitude distribution
    magnitudes = np.linalg.norm(pred, axis=1)
    burn_mag = magnitudes[burn_mask]
    sham_mag = magnitudes[sham_mask]
    
    axes[1, 0].hist(burn_mag, bins=50, alpha=0.5, label='Burn', color='#E74C3C')
    axes[1, 0].hist(sham_mag, bins=50, alpha=0.5, label='Sham', color='#3498DB')
    axes[1, 0].set_xlabel('Prediction Magnitude')
    axes[1, 0].set_ylabel('Count')
    axes[1, 0].set_title('Prediction Magnitude Distribution')
    axes[1, 0].legend()
    
    # Cell counts by cell type and condition
    counts_by_celltype = adata_pred.obs.groupby(['cell_types_simple_short', 'condition']).size().unstack(fill_value=0)
    counts_by_celltype.plot(kind='bar', ax=axes[1, 1], color=['#E74C3C', '#3498DB'])
    axes[1, 1].set_xlabel('Cell Type')
    axes[1, 1].set_ylabel('Cell Count')
    axes[1, 1].set_title('Cell Type Distribution')
    axes[1, 1].legend(title='Condition')
    axes[1, 1].tick_params(axis='x', rotation=45)
    
    # Sparsity
    sparsity_burn = (burn_pred == 0).sum(axis=1) / burn_pred.shape[1] * 100
    sparsity_sham = (sham_pred == 0).sum(axis=1) / sham_pred.shape[1] * 100
    
    axes[1, 2].hist(sparsity_burn, bins=50, alpha=0.5, label='Burn', color='#E74C3C')
    axes[1, 2].hist(sparsity_sham, bins=50, alpha=0.5, label='Sham', color='#3498DB')
    axes[1, 2].set_xlabel('Sparsity (%)')
    axes[1, 2].set_ylabel('Count')
    axes[1, 2].set_title('Prediction Sparsity')
    axes[1, 2].legend()
    
    plt.tight_layout()
    plt.savefig('experiments/st_fine_tuning/results/st_tahoe_baseline_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    print("\n✅ Visualizations saved to: experiments/st_fine_tuning/results/st_tahoe_baseline_analysis.png")

## 6. Summary and Conclusion

In [None]:
print("=" * 80)
print("ST-TAHOE BASELINE SUMMARY")
print("=" * 80)

print("\n📊 Key Findings:")
if 'adata_pred' in locals():
    print(f"  - Processed {adata_pred.shape[0]:,} cells")
    print(f"  - Burn cells: {burn_mask.sum():,}")
    print(f"  - Sham cells: {sham_mask.sum():,}")
    print(f"  - Burn/Sham separation ratio: {between_dist / within_dist:.3f}")
    print("  - Per-condition prediction statistics: compare the Burn and Sham values printed in Section 3")

print("\n⚠️  Expected Limitations:")
print("  1. ST-Tahoe trained on DRUG perturbations, not burn injury")
print("  2. Model uses perturbation vocabulary (drug names), not burn/sham")
print("  3. Predictions do NOT capture wound healing biology")
print("  4. Burn and Sham predictions are nearly identical")

print("\n✅ Use as Baseline:")
print("  - Negative control for comparison")
print("  - Fine-tuned models (ST-LoRA variants) should SIGNIFICANTLY outperform")
print("  - Expected improvement: higher burn/sham separation, better correlations")

print("\n📝 Next Steps:")
print("  1. Train ST-LoRA models (see train_all_st_variants.ipynb)")
print("  2. Compare fine-tuned vs baseline (see compare_st_results.ipynb)")
print("  3. Validate biological predictions (macrophage polarization, etc.)")

print("\n" + "=" * 80)