# ST Model Results Comparison

Comprehensive comparison of all 5 ST model variants:

| Model | LoRA | mHC | Velocity | Status |
|-------|------|-----|----------|--------|
| **ST-Tahoe (Baseline)** | ❌ | ❌ | ❌ | Pretrained |
| **ST-LoRA** | ✅ | ❌ | ❌ | Fine-tuned |
| **ST-LoRA-mHC** | ✅ | ✅ | ❌ | Fine-tuned |
| **ST-LoRA-Velocity** | ✅ | ❌ | ✅ | Fine-tuned |
| **ST-LoRA-mHC-Velocity** | ✅ | ✅ | ✅ | Fine-tuned |

## Evaluation Metrics

1. **Training Dynamics**: Loss curves, gradient stability, convergence
2. **Prediction Accuracy**: Gene correlation, cell-wise distance
3. **Biological Validity**: Wound healing trajectories, cell-type responses
4. **Efficiency**: Training time, memory usage, parameter count

## 1. Setup

In [None]:
import sys
import os
from pathlib import Path
import glob

import torch
import anndata as ad
import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr, spearmanr
from scipy.spatial.distance import cdist
from sklearn.metrics import mean_squared_error, r2_score

# Make the project root importable and use it as the working directory for
# every relative path below. The root is taken to be two levels above the
# directory the notebook was started in. It is resolved only once: after the
# first run os.chdir has already moved us to the root, so recomputing
# Path.cwd().parent.parent on a re-run would climb two more levels.
if 'project_root' not in globals():
    project_root = Path.cwd().parent.parent
# Avoid stacking duplicate sys.path entries when the cell is re-run.
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))
os.chdir(project_root)

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.dpi'] = 150
plt.rcParams['figure.figsize'] = (12, 8)

print("‚úÖ Environment ready")

## 2. Load Data and Predictions

In [None]:
# Root directory holding the fine-tuned model runs. Defaults to the original
# hard-coded location; set the STATE_MODELS_DIR environment variable to point
# the notebook at another machine's copy.
STATE_MODELS_DIR = os.environ.get('STATE_MODELS_DIR', '/home/scumpia-mrl/state_models')

# Define model variants. The baseline has no 'model_dir', so it is skipped
# when training logs are loaded.
model_configs = [
    {
        'name': 'st_tahoe_baseline',
        'display_name': 'ST-Tahoe (Baseline)',
        'prediction_file': 'experiments/st_fine_tuning/results/burn_sham_st_tahoe_predictions.h5ad',
        'color': '#95A5A6',
        'marker': 'o',
    },
    {
        'name': 'st_lora',
        'display_name': 'ST-LoRA',
        'prediction_file': f'{STATE_MODELS_DIR}/st_lora/predictions/burn_sham_predictions.h5ad',
        'model_dir': f'{STATE_MODELS_DIR}/st_lora',
        'color': '#3498DB',
        'marker': 's',
    },
    {
        'name': 'st_lora_mhc',
        'display_name': 'ST-LoRA-mHC',
        'prediction_file': f'{STATE_MODELS_DIR}/st_lora_mhc/predictions/burn_sham_predictions.h5ad',
        'model_dir': f'{STATE_MODELS_DIR}/st_lora_mhc',
        'color': '#E74C3C',
        'marker': '^',
    },
    {
        'name': 'st_lora_velocity',
        'display_name': 'ST-LoRA-Velocity',
        'prediction_file': f'{STATE_MODELS_DIR}/st_lora_velocity/predictions/burn_sham_predictions.h5ad',
        'model_dir': f'{STATE_MODELS_DIR}/st_lora_velocity',
        'color': '#2ECC71',
        'marker': 'D',
    },
    {
        'name': 'st_lora_mhc_velocity',
        'display_name': 'ST-LoRA-mHC-Velocity',
        'prediction_file': f'{STATE_MODELS_DIR}/st_lora_mhc_velocity/predictions/burn_sham_predictions.h5ad',
        'model_dir': f'{STATE_MODELS_DIR}/st_lora_mhc_velocity',
        'color': '#9B59B6',
        'marker': '*',
    },
]

# Load predictions; models whose prediction file is missing are reported and skipped.
predictions = {}
for config in model_configs:
    if os.path.exists(config['prediction_file']):
        predictions[config['name']] = ad.read_h5ad(config['prediction_file'])
        print(f"‚úÖ Loaded {config['display_name']}")
    else:
        print(f"‚è≥ {config['display_name']} - predictions not found")

print(f"\n‚úì Loaded {len(predictions)}/{len(model_configs)} models")

## 3. Training Dynamics Comparison

In [None]:
def load_training_logs(model_dir):
    """Load training metrics from the newest Lightning CSV log under *model_dir*.

    Searches ``<model_dir>/lightning_logs/version_*/metrics.csv`` and reads the
    file from the ``version_N`` directory with the largest numeric ``N``
    (assumes higher version numbers are newer runs).

    Args:
        model_dir: Directory that contains a ``lightning_logs`` folder.

    Returns:
        A pandas DataFrame with the CSV contents, or ``None`` when no
        ``metrics.csv`` exists under any ``version_*`` directory.
    """
    csv_files = glob.glob(f"{model_dir}/lightning_logs/version_*/metrics.csv")
    if not csv_files:
        return None

    def _version_number(csv_path):
        # Compare numerically: a plain string sort ranks 'version_10' below
        # 'version_9' and would pick an older run. Non-numeric suffixes rank lowest.
        suffix = Path(csv_path).parent.name.partition('_')[2]
        return int(suffix) if suffix.isdigit() else -1

    csv_file = max(csv_files, key=_version_number)
    return pd.read_csv(csv_file)

# Collect training logs for every fine-tuned variant (only configs with a
# 'model_dir'); models without any logs are silently left out.
training_logs = {}
for cfg in (c for c in model_configs if 'model_dir' in c):
    logs = load_training_logs(cfg['model_dir'])
    if logs is None:
        continue
    training_logs[cfg['name']] = logs
    print(f"‚úÖ Loaded logs for {cfg['display_name']}")

print(f"\n‚úì Loaded logs for {len(training_logs)} models")

### 3.1 Loss Curves

In [None]:
def _plot_logged_metric(ax, metric, logs_by_model, configs, **plot_kwargs):
    """Draw *metric* on *ax* for every model whose log has that column.

    Returns True if at least one line was drawn.
    """
    drew_any = False
    for name, logs in logs_by_model.items():
        if metric not in logs.columns:
            continue
        config = next(c for c in configs if c['name'] == name)
        values = logs[metric].dropna()
        # Rows dropped by dropna() presumably hold other logged quantities
        # (e.g. val rows vs. train rows), so the surviving row labels are not
        # training steps. Use the logged 'step' column when it exists.
        x = logs.loc[values.index, 'step'] if 'step' in logs.columns else values.index
        ax.plot(
            x,
            values.values,
            label=config['display_name'],
            color=config['color'],
            alpha=0.7,
            linewidth=2,
            **plot_kwargs,
        )
        drew_any = True
    return drew_any


if training_logs:
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # (axis, logged column, y label, title, extra plot kwargs)
    panels = [
        (axes[0], 'train_loss', 'Training Loss', 'Training Loss Curves', {}),
        (axes[1], 'val_loss', 'Validation Loss', 'Validation Loss Curves',
         {'marker': 'o', 'markersize': 4}),
    ]
    for ax, metric, ylabel, title, extra in panels:
        drew_any = _plot_logged_metric(ax, metric, training_logs, model_configs, **extra)
        ax.set_xlabel('Step')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        if drew_any:
            # Only add a legend when something is labelled; otherwise
            # matplotlib warns that no artists were found.
            ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    # The results directory is otherwise only created in the export section.
    results_dir = Path('experiments/st_fine_tuning/results')
    results_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(results_dir / 'loss_curves_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()

    print("\n‚úÖ Loss curves plotted")
else:
    print("‚è≥ No training logs available yet")

### 3.2 Training Stability (Loss Variance)

In [None]:
if training_logs:
    stability_columns = ['Model', 'Final Loss (Mean)', 'Loss Std', 'Loss Variance', 'Loss Range']
    stability_metrics = []

    for name, logs in training_logs.items():
        config = next(c for c in model_configs if c['name'] == name)

        if 'train_loss' in logs.columns:
            train_loss = logs['train_loss'].dropna()

            # Compute stability metrics (last 20% of logged training-loss points)
            cutoff = int(0.8 * len(train_loss))
            final_losses = train_loss.iloc[cutoff:]

            stability_metrics.append({
                'Model': config['display_name'],
                'Final Loss (Mean)': final_losses.mean(),
                'Loss Std': final_losses.std(),
                'Loss Variance': final_losses.var(),
                'Loss Range': final_losses.max() - final_losses.min(),
            })

    # Explicit columns keep 'Loss Variance' present even when no log has a
    # 'train_loss' column, so sorting (and later cells that read/export
    # stability_df) do not fail with a KeyError.
    stability_df = pd.DataFrame(stability_metrics, columns=stability_columns)
    stability_df = stability_df.sort_values('Loss Variance')

    if stability_df.empty:
        print("No 'train_loss' column found in any training log")
    else:
        print("\nTraining Stability (Last 20% of Training):")
        print("=" * 80)
        print(stability_df.to_string(index=False))
        print("\nüìä Lower variance = more stable training")
        print("   mHC should show lower variance than non-mHC variants")
else:
    print("‚è≥ No training logs available yet")

## 4. Prediction Quality Comparison

### 4.1 Load Ground Truth

In [None]:
# Load the reference (ground-truth) embeddings and report their basic shape.
adata_true = ad.read_h5ad("experiments/baseline_analysis/data/burn_sham_baseline_embedded_2000.h5ad")

print("Ground Truth Data:")
print(f"  Shape: {adata_true.shape}")
print(f"  Embeddings: {adata_true.obsm['X_state_2000'].shape}")
print(f"  Conditions: {adata_true.obs['condition'].value_counts().to_dict()}")

# Per-condition subsets of the ground truth (Burn first, then Sham); burn_true
# is the reference for the embedding-correlation cell below.
burn_true, sham_true = (
    adata_true[adata_true.obs['condition'] == label] for label in ('Burn', 'Sham')
)

print(f"\n  Burn cells: {burn_true.shape[0]}")
print(f"  Sham cells: {sham_true.shape[0]}")

### 4.2 Embedding Correlation

In [None]:
def compute_embedding_correlation(pred_adata, true_adata, pred_key='X_state_2000', true_key='X_state_2000'):
    """Compare predicted and ground-truth embeddings on the cells both objects share.

    Cells are paired by ``obs_names``; cells present in only one object are ignored.

    Args:
        pred_adata: AnnData with predicted embeddings in ``obsm[pred_key]``.
        true_adata: AnnData with ground-truth embeddings in ``obsm[true_key]``.
        pred_key: ``obsm`` key of the predicted embedding.
        true_key: ``obsm`` key of the ground-truth embedding.

    Returns:
        dict with
          - ``mean_correlation`` / ``median_correlation``: mean / median of the
            per-dimension Pearson r, ignoring dimensions where r is undefined
            (NaN, e.g. a constant column);
          - ``overall_correlation``: Pearson r over all flattened values;
          - ``mse`` and ``r2``: computed on the flattened values;
          - ``n_cells``: number of shared cells.
        When fewer than two cells are shared, every metric except ``n_cells``
        is NaN instead of an exception being raised.
    """
    common_cells = pred_adata.obs_names.intersection(true_adata.obs_names)
    n_cells = len(common_cells)

    if n_cells < 2:
        # Pearson r needs at least two paired observations; report NaN so one
        # mismatched model does not abort the comparison loop over all models.
        nan = float('nan')
        return {
            'mean_correlation': nan,
            'median_correlation': nan,
            'overall_correlation': nan,
            'mse': nan,
            'r2': nan,
            'n_cells': n_cells,
        }

    pred_emb = np.asarray(pred_adata[common_cells].obsm[pred_key])
    true_emb = np.asarray(true_adata[common_cells].obsm[true_key])

    # Per-dimension correlation. pearsonr returns NaN for a constant column,
    # so the summaries below use the nan-aware reductions.
    dim_corrs = np.array([
        pearsonr(pred_emb[:, i], true_emb[:, i])[0] for i in range(pred_emb.shape[1])
    ])

    # Flatten once and reuse for the overall metrics.
    pred_flat = pred_emb.ravel()
    true_flat = true_emb.ravel()
    overall_r = pearsonr(pred_flat, true_flat)[0]

    return {
        'mean_correlation': np.nanmean(dim_corrs),
        'median_correlation': np.nanmedian(dim_corrs),
        'overall_correlation': overall_r,
        'mse': mean_squared_error(true_flat, pred_flat),
        'r2': r2_score(true_flat, pred_flat),
        'n_cells': n_cells,
    }

# Score each model's Burn-condition predictions against the Burn ground truth.
correlation_results = []

for name, pred_adata in predictions.items():
    display = next(c['display_name'] for c in model_configs if c['name'] == name)
    burn_pred = pred_adata[pred_adata.obs['condition'] == 'Burn']

    # Skip models with no Burn cells or without the expected embedding.
    if len(burn_pred) == 0 or 'X_state_2000' not in burn_pred.obsm:
        continue

    row = compute_embedding_correlation(burn_pred, burn_true)
    row['Model'] = display
    correlation_results.append(row)

# Display results, best overall correlation first
if not correlation_results:
    print("‚è≥ No predictions available for comparison")
else:
    corr_df = pd.DataFrame(correlation_results).sort_values('overall_correlation', ascending=False)

    summary_columns = ['Model', 'overall_correlation', 'mean_correlation', 'r2', 'mse']
    print("\nEmbedding Correlation (Burn Predictions vs Burn Ground Truth):")
    print("=" * 80)
    print(corr_df[summary_columns].to_string(index=False))
    print("\nüìä Higher correlation = better predictions")

### 4.3 Perturbation Direction Consistency

In [None]:
def compute_perturbation_direction(pred_adata, true_adata, pred_key='X_state_2000', true_key='X_state_2000'):
    """Compare the predicted sham -> burn shift with the ground-truth shift.

    For each object, the shift is the mean 'Burn' embedding minus the mean
    'Sham' embedding, with cells split on ``obs['condition']``.

    Args:
        pred_adata: AnnData with predicted embeddings in ``obsm[pred_key]``.
        true_adata: AnnData with ground-truth embeddings in ``obsm[true_key]``.
        pred_key: ``obsm`` key of the predicted embedding.
        true_key: ``obsm`` key of the ground-truth embedding.

    Returns:
        dict with ``cosine_similarity`` between the two shifts,
        ``predicted_magnitude`` and ``true_magnitude`` (L2 norms), and
        ``magnitude_ratio`` (predicted / true). ``cosine_similarity`` is NaN
        when either shift is zero; ``magnitude_ratio`` is NaN when the true
        shift is zero.
    """
    def _mean_shift(adata, key):
        # Mean Burn embedding minus mean Sham embedding for one object.
        condition = adata.obs['condition']
        burn_emb = np.asarray(adata[condition == 'Burn'].obsm[key])
        sham_emb = np.asarray(adata[condition == 'Sham'].obsm[key])
        return burn_emb.mean(axis=0) - sham_emb.mean(axis=0)

    # Both shifts come from the arguments (previously the true shift was read
    # from the globals burn_true / sham_true and true_adata was ignored).
    pred_shift = _mean_shift(pred_adata, pred_key)
    true_shift = _mean_shift(true_adata, true_key)

    pred_mag = float(np.linalg.norm(pred_shift))
    true_mag = float(np.linalg.norm(true_shift))

    nan = float('nan')
    # Guard the divisions so a zero shift yields NaN instead of a divide warning.
    if pred_mag > 0 and true_mag > 0:
        cos_sim = float(np.dot(pred_shift, true_shift) / (pred_mag * true_mag))
    else:
        cos_sim = nan
    mag_ratio = pred_mag / true_mag if true_mag > 0 else nan

    return {
        'cosine_similarity': cos_sim,
        'predicted_magnitude': pred_mag,
        'true_magnitude': true_mag,
        'magnitude_ratio': mag_ratio,
    }

# Score each model's sham -> burn shift against the ground-truth shift.
direction_results = []

for name, pred_adata in predictions.items():
    display = next(c['display_name'] for c in model_configs if c['name'] == name)

    # Models without the expected embedding cannot be scored.
    if 'X_state_2000' not in pred_adata.obsm:
        continue

    row = compute_perturbation_direction(pred_adata, adata_true)
    row['Model'] = display
    direction_results.append(row)

# Display results, best-aligned model first
if not direction_results:
    print("‚è≥ No predictions available")
else:
    dir_df = pd.DataFrame(direction_results).sort_values('cosine_similarity', ascending=False)

    print("\nPerturbation Direction Consistency (Sham ‚Üí Burn):")
    print("=" * 80)
    print(dir_df.to_string(index=False))
    print("\nüìä Cosine similarity closer to 1.0 = better direction alignment")
    print("   Magnitude ratio closer to 1.0 = correct perturbation strength")

## 5. UMAP Visualizations

In [None]:
if predictions:
    n_models = len(predictions)
    n_cols = (n_models + 1) // 2
    fig, axes = plt.subplots(2, n_cols, figsize=(20, 10))
    axes = np.atleast_1d(axes).ravel()

    # One panel per loaded model; the grid has at least n_models axes.
    for ax, (name, pred_adata) in zip(axes, predictions.items()):
        config = next(c for c in model_configs if c['name'] == name)

        if 'X_state_2000' not in pred_adata.obsm:
            # Nothing to embed for this model: blank the panel instead of
            # leaving an empty set of axes with ticks in the figure.
            ax.axis('off')
            continue

        # Compute UMAP (sc.pp.neighbors / sc.tl.umap store results on pred_adata)
        sc.pp.neighbors(pred_adata, use_rep='X_state_2000', n_neighbors=15)
        sc.tl.umap(pred_adata)

        # Plot by condition
        sc.pl.umap(
            pred_adata,
            color='condition',
            ax=ax,
            show=False,
            title=config['display_name'],
        )

    # Hide the unused grid cells beyond the number of models
    for ax in axes[n_models:]:
        ax.axis('off')

    plt.tight_layout()
    # The results directory is otherwise only created in the export section.
    results_dir = Path('experiments/st_fine_tuning/results')
    results_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(results_dir / 'umap_comparison_all_models.png', dpi=300, bbox_inches='tight')
    plt.show()

    print("\n‚úÖ UMAP comparison plotted")
else:
    print("‚è≥ No predictions available")

## 6. Summary Report

In [None]:
# Console summary of the comparison. Each section prints a placeholder when
# the cell that produces its data found nothing to report.
banner = "=" * 80
print(banner)
print("ST MODEL COMPARISON SUMMARY")
print(banner)

print("\n1. Training Stability:")
if not training_logs:
    print("   ‚è≥ Training not completed yet")
else:
    print("   Expected: mHC variants should have lower loss variance")
    if len(stability_df) > 0:
        print(f"   ‚úÖ Most stable: {stability_df.iloc[0]['Model']}")

print("\n2. Prediction Accuracy:")
if not correlation_results:
    print("   ‚è≥ Predictions not available yet")
else:
    top_corr = corr_df.iloc[0]
    print(f"   ‚úÖ Best correlation: {top_corr['Model']} (r = {top_corr['overall_correlation']:.3f})")

print("\n3. Perturbation Direction:")
if not direction_results:
    print("   ‚è≥ Predictions not available yet")
else:
    top_dir = dir_df.iloc[0]
    print(f"   ‚úÖ Best alignment: {top_dir['Model']} (cos = {top_dir['cosine_similarity']:.3f})")

# NOTE: these recommendations are fixed text; they are not derived from the
# metrics printed above.
print("\n4. Recommendations:")
for recommendation in (
    "   - ST-LoRA: Good baseline, parameter-efficient",
    "   - ST-LoRA-mHC: Use if training instability observed",
    "   - ST-LoRA-Velocity: Use if velocity data available",
    "   - ST-LoRA-mHC-Velocity: Best overall (combines all benefits)",
):
    print(recommendation)

print("\n" + banner)
print("‚úÖ Comparison complete!")
print(banner)

## 7. Export Results

In [None]:
# Write every metrics table produced above to CSV in the results directory.
output_dir = Path("experiments/st_fine_tuning/results")
output_dir.mkdir(parents=True, exist_ok=True)


def _export_table(table, filename):
    """Save *table* as CSV (without its index) in output_dir and report it."""
    table.to_csv(output_dir / filename, index=False)
    print(f"‚úÖ Saved: {filename}")


if correlation_results:
    _export_table(corr_df, "embedding_correlations.csv")

if direction_results:
    _export_table(dir_df, "perturbation_directions.csv")

if training_logs:
    _export_table(stability_df, "training_stability.csv")

print("\n‚úÖ All results exported to experiments/st_fine_tuning/results/")