# State Transition Model Comparison: Baseline vs LoRA vs LoRA+mHC

This notebook compares three State Transition (ST) model variants for burn/sham wound healing perturbation prediction:

1. **ST-Tahoe (Baseline)** - Pretrained model from Arc Institute (no fine-tuning)
2. **ST-LoRA** - Fine-tuned with LoRA adapters (parameter-efficient)
3. **ST-LoRA-mHC** - Fine-tuned with LoRA + mHC (manifold-constrained for stable gradients)

## Key Features

- **Input**: SE-600M embeddings (from baseline_analysis)
- **Task**: Predict cellular state changes from sham → burn conditions
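- **Innovation**: mHC constrains per-layer mixing matrices with Sinkhorn iterations to stabilize optimal transport (OT) loss gradients
- **Efficiency**: LoRA trains only low-rank adapters (~1-5% of parameters); the pretrained weights stay frozen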
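The ~1-5% range follows from how LoRA is parameterized. For each adapted weight matrix W of shape d_out × d_in, LoRA freezes W and trains two low-rank factors, B (d_out × r) and A (r × d_in), which adds r·(d_in + d_out) parameters. The sketch below computes that fraction for a single matrix. The 1024 × 1024 layer size is a hypothetical value chosen for illustration, not the actual ST-Tahoe architecture, and the model-wide fraction also depends on which modules are targeted.

```python
def lora_param_fraction(d_in: int, d_out: int, r: int) -> float:
    """Trainable LoRA parameters as a fraction of one frozen weight matrix."""
    frozen = d_in * d_out           # parameters in the original W
    trainable = r * (d_in + d_out)  # B (d_out x r) plus A (r x d_in)
    return trainable / frozen


# Hypothetical 1024 x 1024 projection with the rank used in this notebook (r = 16)
print(f"{lora_param_fraction(1024, 1024, 16):.3%}")  # 3.125%
```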

## Expected Outcomes

- LoRA: Faster fine-tuning, lower memory
- mHC: More stable training, better convergence
- Comparison: Which approach best captures wound healing dynamics?

## 1. Environment Setup

In [None]:
import sys
import os
from pathlib import Path
import yaml

import torch
import anndata as ad
import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Add project root to path
project_root = Path.cwd().parent.parent
sys.path.insert(0, str(project_root))

# Check GPU availability
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.dpi'] = 100

## 2. Load SE-600M Baseline Embeddings

Load the baseline SE-600M embeddings that will be used as input for all three ST model variants.

In [None]:
# Load baseline embeddings
data_path = "../baseline_analysis/data/burn_sham_baseline_embedded.h5ad"

print(f"Loading data from: {data_path}")
adata = ad.read_h5ad(data_path)

print("\n✓ Loaded AnnData:")
print(f"  Shape: {adata.shape[0]} cells x {adata.shape[1]} genes")
print(f"  Embeddings: {list(adata.obsm.keys())}")
print(f"  Observations: {list(adata.obs.columns)}")

# Verify required columns
required = ['condition', 'timepoint', 'cell_types_simple_short', 'mouse_id']
for col in required:
    assert col in adata.obs.columns, f"Missing column: {col}"
    print(f"  ✓ {col}: {adata.obs[col].nunique()} unique values")

# Verify embeddings
assert 'X_state' in adata.obsm, "Missing X_state embeddings"
print(f"\n✓ SE-600M embeddings: {adata.obsm['X_state'].shape}")
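Before training, it can help to confirm that the `X_state` matrix is usable as model input: one row per cell, no missing or infinite values, and no all-zero vectors. A minimal sketch with numpy, assuming the embeddings are stored as a dense array; the function name `check_embeddings` is hypothetical and not part of the `state` package.

```python
import numpy as np


def check_embeddings(emb: np.ndarray, n_cells: int) -> dict:
    """Basic sanity checks for a (cells x dims) embedding matrix."""
    emb = np.asarray(emb)
    return {
        "rows_match_cells": emb.shape[0] == n_cells,   # one embedding per cell
        "all_finite": bool(np.isfinite(emb).all()),    # no NaN or Inf values
        # all-zero vectors often indicate cells that were never embedded
        "n_zero_rows": int((np.abs(emb).sum(axis=1) == 0).sum()),
        "dim": emb.shape[1],
    }


# Usage in this notebook:
# print(check_embeddings(adata.obsm['X_state'], adata.n_obs))
```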

## 3. Data Preparation

Inspect how cells are distributed across conditions, timepoints, and cell types before ST model training.

In [None]:
# Data distribution summary
print("Data Distribution:")
print("=" * 60)

print("\n1. Condition:")
print(adata.obs['condition'].value_counts())

print("\n2. Timepoint:")
print(adata.obs['timepoint'].value_counts())

print("\n3. Condition × Timepoint:")
cross_tab = pd.crosstab(adata.obs['condition'], adata.obs['timepoint'])
print(cross_tab)

print("\n4. Top 5 Cell Types:")
print(adata.obs['cell_types_simple_short'].value_counts().head(5))

# Visualize
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

adata.obs['condition'].value_counts().plot(kind='bar', ax=axes[0], color=['#FF6B6B', '#4ECDC4'])
axes[0].set_title('Condition Distribution')
axes[0].set_ylabel('Cell Count')

adata.obs['timepoint'].value_counts().sort_index().plot(kind='bar', ax=axes[1], color='#95E1D3')
axes[1].set_title('Timepoint Distribution')
axes[1].set_ylabel('Cell Count')

sns.heatmap(cross_tab, annot=True, fmt='d', cmap='YlOrRd', ax=axes[2])
axes[2].set_title('Condition × Timepoint')

plt.tight_layout()
os.makedirs('results', exist_ok=True)  # savefig raises an error if the folder is missing
plt.savefig('results/data_distribution.png', dpi=300, bbox_inches='tight')
plt.show()

print("\n✓ Saved distribution plot to: results/data_distribution.png")
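Cells from the same mouse share technical and biological variation, so a held-out evaluation set should contain whole mice rather than randomly sampled cells. The cells above do not define a split. The helper below is a hypothetical sketch of a mouse-level split using only pandas and numpy; it is not the splitting logic used by `state tx train`.

```python
import numpy as np
import pandas as pd


def split_by_mouse(obs: pd.DataFrame, mouse_col: str = "mouse_id",
                   test_frac: float = 0.2, seed: int = 0) -> pd.Series:
    """Label each cell 'train' or 'test' so that no mouse appears in both."""
    rng = np.random.default_rng(seed)
    ids = obs[mouse_col].astype(str)           # works for categorical or string columns
    mice = np.array(sorted(ids.unique()))      # sort first so the shuffle is reproducible
    rng.shuffle(mice)
    n_test = max(1, int(round(test_frac * len(mice))))  # hold out at least one mouse
    test_mice = list(mice[:n_test])
    labels = np.where(ids.isin(test_mice), "test", "train")
    return pd.Series(labels, index=obs.index, name="split")
```

Usage: `adata.obs['split'] = split_by_mouse(adata.obs)`, then `pd.crosstab(adata.obs['split'], adata.obs['condition'])`. With few mice per condition, check that both burn and sham cells land in each split; if not, split within each condition separately.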

## 4. Model 1: ST-Tahoe Baseline (Pretrained)

Load and evaluate the pretrained ST-Tahoe model without any fine-tuning.

In [None]:
print("="*80)
print("MODEL 1: ST-Tahoe Baseline (Pretrained)")
print("="*80)

# TODO: Load pretrained ST-Tahoe from HuggingFace
# model_baseline = load_pretrained_st("arcinstitute/ST-Tahoe")

print("\nTo use pretrained ST-Tahoe:")
print("1. Download model from: https://huggingface.co/arcinstitute/ST-Tahoe")
print("2. Run inference using:")
print("   state tx infer --model_dir /path/to/ST-Tahoe --adata burn_sham_baseline_embedded.h5ad")
print("\nFor this experiment, we'll skip the baseline and focus on fine-tuned models.")
print("(Baseline results can be added later for comparison)")

## 5. Model 2: ST-LoRA (Fine-tuning with LoRA)

Fine-tune ST model with LoRA adapters only (no mHC).

In [None]:
print("="*80)
print("MODEL 2: ST-LoRA (Fine-tuning with LoRA)")
print("="*80)

# Load configuration
config_path = "configs/lora_config.yaml"
with open(config_path, 'r') as f:
    lora_config = yaml.safe_load(f)

print("\nConfiguration:")
print(f"  LoRA rank: {lora_config['model']['lora']['r']}")
print(f"  LoRA alpha: {lora_config['model']['lora']['alpha']}")
print(f"  Target modules: {lora_config['model']['lora']['target_modules']}")
print(f"  mHC enabled: {lora_config['model']['use_mhc']}")
print(f"  Learning rate: {lora_config['training']['learning_rate']}")
print(f"  Max epochs: {lora_config['training']['max_epochs']}")

print("\n" + "="*80)
print("TRAINING COMMAND")
print("="*80)
print("\nTo train ST-LoRA model, run:")
print("\nstate tx train \\")
print("  data.kwargs.embed_key=X_state \\")
print("  data.kwargs.pert_col=condition \\")
print("  data.kwargs.control_pert=sham \\")
print("  data.kwargs.cell_type_key=cell_types_simple_short \\")
print("  data.kwargs.batch_col=mouse_id \\")
print("  model.kwargs.lora.enable=true \\")
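# Values below come from the loaded YAML so the command matches the printed configuration
print(f"  model.kwargs.lora.r={lora_config['model']['lora']['r']} \\")
print(f"  model.kwargs.lora.alpha={lora_config['model']['lora']['alpha']} \\")
print(f"  model.kwargs.use_mhc={str(lora_config['model']['use_mhc']).lower()} \\")
print(f"  training.max_epochs={lora_config['training']['max_epochs']} \\")
print(f"  training.learning_rate={lora_config['training']['learning_rate']} \\")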
print("  output_dir=/home/scumpia-mrl/state_models/st_lora \\")
print("  name=st_lora_burn_sham")

print("\n✓ Configuration loaded from:", config_path)
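For reference, the update that the `r` and `alpha` settings control is h = W x + (alpha / r) · B A x, where W is frozen and only A and B are trained. B is conventionally initialized to zero, so the adapted layer starts out identical to the pretrained one. The numpy class below is a minimal illustration of that formula under those standard LoRA conventions; it is not the adapter implementation used inside the `state` package.

```python
import numpy as np


class LoRALinear:
    """Frozen weight W plus a trainable low-rank update scaled by alpha / r."""

    def __init__(self, W: np.ndarray, r: int = 16, alpha: float = 32.0, seed: int = 0):
        rng = np.random.default_rng(seed)
        d_out, d_in = W.shape
        self.W = W                                      # pretrained weight, frozen
        self.A = rng.normal(0.0, 0.01, size=(r, d_in))  # down-projection, trainable
        self.B = np.zeros((d_out, r))                   # up-projection, trainable, starts at zero
        self.scale = alpha / r                          # with r=16, alpha=32 this is 2.0

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # x: (batch, d_in) -> output: (batch, d_out)
        return x @ self.W.T + self.scale * (x @ self.A.T) @ self.B.T

    def n_trainable(self) -> int:
        return self.A.size + self.B.size
```

Because B starts at zero, fine-tuning begins exactly at the pretrained model's outputs, and only `n_trainable()` parameters receive gradients.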

## 6. Model 3: ST-LoRA-mHC (Fine-tuning with LoRA + mHC)

Fine-tune ST model with both LoRA adapters AND mHC for gradient stabilization.

In [None]:
print("="*80)
print("MODEL 3: ST-LoRA-mHC (Fine-tuning with LoRA + mHC)")
print("="*80)

# Load configuration
config_path = "configs/lora_mhc_config.yaml"
with open(config_path, 'r') as f:
    lora_mhc_config = yaml.safe_load(f)

print("\nConfiguration:")
print(f"  LoRA rank: {lora_mhc_config['model']['lora']['r']}")
print(f"  LoRA alpha: {lora_mhc_config['model']['lora']['alpha']}")
print(f"  Target modules: {lora_mhc_config['model']['lora']['target_modules']}")
print(f"  mHC enabled: {lora_mhc_config['model']['use_mhc']}")
print(f"  Sinkhorn iterations: {lora_mhc_config['model']['mhc']['sinkhorn_iters']}")
print(f"  Learning rate: {lora_mhc_config['training']['learning_rate']}")
print(f"  Max epochs: {lora_mhc_config['training']['max_epochs']}")

print("\n" + "="*80)
print("TRAINING COMMAND")
print("="*80)
print("\nTo train ST-LoRA-mHC model, run:")
print("\nstate tx train \\")
print("  data.kwargs.embed_key=X_state \\")
print("  data.kwargs.pert_col=condition \\")
print("  data.kwargs.control_pert=sham \\")
print("  data.kwargs.cell_type_key=cell_types_simple_short \\")
print("  data.kwargs.batch_col=mouse_id \\")
print("  model.kwargs.lora.enable=true \\")
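# Values below come from the loaded YAML so the command matches the printed configuration
print(f"  model.kwargs.lora.r={lora_mhc_config['model']['lora']['r']} \\")
print(f"  model.kwargs.lora.alpha={lora_mhc_config['model']['lora']['alpha']} \\")
print(f"  model.kwargs.use_mhc={str(lora_mhc_config['model']['use_mhc']).lower()} \\")
print(f"  model.kwargs.mhc.sinkhorn_iters={lora_mhc_config['model']['mhc']['sinkhorn_iters']} \\")
print(f"  training.max_epochs={lora_mhc_config['training']['max_epochs']} \\")
print(f"  training.learning_rate={lora_mhc_config['training']['learning_rate']} \\")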
print("  output_dir=/home/scumpia-mrl/state_models/st_lora_mhc \\")
print("  name=st_lora_mhc_burn_sham")

print("\n✓ Configuration loaded from:", config_path)
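The `sinkhorn_iters` setting suggests that the mHC variant pushes each layer's mixing matrix toward the set of doubly stochastic matrices (non-negative, with every row and column summing to 1) using Sinkhorn-Knopp normalization. Assuming that is what this configuration does, the sketch below shows the procedure: exponentiate unconstrained logits, then alternate row and column normalization for a fixed number of iterations. A doubly stochastic matrix has spectral norm at most 1, which is one way to see why the mixing step itself cannot amplify activations or gradients from layer to layer.

```python
import numpy as np


def sinkhorn_knopp(logits: np.ndarray, n_iters: int = 10) -> np.ndarray:
    """Approximately project a square matrix of logits onto the doubly stochastic set."""
    # Subtracting the max before exp avoids overflow; Sinkhorn normalization
    # is invariant to a global positive scale, so the result is unchanged.
    P = np.exp(logits - logits.max())
    for _ in range(n_iters):
        P = P / P.sum(axis=1, keepdims=True)  # rows sum to 1
        P = P / P.sum(axis=0, keepdims=True)  # columns sum to 1
    return P


rng = np.random.default_rng(0)
P = sinkhorn_knopp(rng.normal(size=(4, 4)), n_iters=10)  # 10 iterations, as in the mHC settings
print(P.sum(axis=1), P.sum(axis=0))  # columns are exact after the last step; rows are close to 1
```

More iterations bring the rows closer to 1 at the cost of extra compute per layer, which is the overhead noted for ST-LoRA-mHC.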

## 7. Training Comparison

After training both models, compare training dynamics.

In [None]:
print("="*80)
print("TRAINING MONITORING")
print("="*80)

print("\nTo monitor training in real-time:")
print("\n1. ST-LoRA:")
print("   tensorboard --logdir=/home/scumpia-mrl/state_models/st_lora")
print("\n2. ST-LoRA-mHC:")
print("   tensorboard --logdir=/home/scumpia-mrl/state_models/st_lora_mhc")

print("\n" + "="*80)
print("EXPECTED DIFFERENCES")
print("="*80)

print("\n📊 Training Stability:")
print("  - ST-LoRA: May show loss spikes with OT loss")
print("  - ST-LoRA-mHC: Smoother loss curves (mHC stabilizes gradients)")

print("\n⚡ Training Speed:")
print("  - ST-LoRA: Faster (no mHC overhead)")
print("  - ST-LoRA-mHC: Slightly slower (Sinkhorn iterations)")

print("\n💾 Memory Usage:")
print("  - Both: Similar (LoRA adapters ~1-5% of model)")
print("  - mHC: Adds per-layer mixing matrices")

print("\n🎯 Expected Training Time:")
print("  - ST-LoRA: ~2-3 hours (5 epochs, 2 GPUs)")
print("  - ST-LoRA-mHC: ~3-4 hours (5 epochs, 2 GPUs)")
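To make "smoother loss curves" and "loss spikes" measurable rather than a visual judgment, each logged training-loss series can be summarized with a couple of simple statistics. The function below is a hypothetical sketch; the metric definitions are this example's choice, not a standard. It reports the standard deviation of step-to-step changes and counts spikes, defined here as steps where the loss exceeds the median of the preceding window by a chosen factor.

```python
import numpy as np


def loss_stability(loss, window: int = 20, spike_factor: float = 1.5) -> dict:
    """Summarize how noisy a training-loss curve is (assumes non-negative losses)."""
    loss = np.asarray(loss, dtype=float)
    diffs = np.diff(loss)
    spikes = 0
    for t in range(window, len(loss)):
        baseline = np.median(loss[t - window:t])  # median of the preceding window
        if loss[t] > spike_factor * baseline:
            spikes += 1
    return {
        "diff_std": float(diffs.std()) if diffs.size else 0.0,
        "n_spikes": spikes,
    }
```

Applied to the loss values logged for each run, a lower `diff_std` and fewer `n_spikes` for ST-LoRA-mHC would support the expected stability benefit.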

## 8. Model Evaluation

After training, evaluate and compare all models on perturbation prediction accuracy.

In [None]:
print("="*80)
print("EVALUATION METRICS")
print("="*80)

print("\nAfter training both models, run predictions:")

print("\n1. ST-LoRA:")
print("   state tx predict \\")
print("     --output-dir /home/scumpia-mrl/state_models/st_lora \\")
print("     --checkpoint best.ckpt")

print("\n2. ST-LoRA-mHC:")
print("   state tx predict \\")
print("     --output-dir /home/scumpia-mrl/state_models/st_lora_mhc \\")
print("     --checkpoint best.ckpt")

print("\n" + "="*80)
print("COMPARISON METRICS")
print("="*80)

print("\nWe will compare:")
print("\n1. Perturbation Prediction Accuracy")
print("   - Nearest neighbor distance (burn vs sham)")
print("   - Gene correlation (predicted vs actual)")
print("   - Cell-type-specific response accuracy")

print("\n2. Training Stability")
print("   - Loss curve smoothness")
print("   - Gradient norm stability")
print("   - Convergence speed")

print("\n3. Efficiency")
print("   - Number of trainable parameters")
print("   - Training time per epoch")
print("   - Memory usage")

print("\n4. Biological Interpretability")
print("   - Temporal coherence (day10 → day14 → day19)")
print("   - Wound healing trajectory quality")
print("   - Cell-type-specific perturbation signatures")
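For "correlation (predicted vs actual)", one common formulation in perturbation-prediction benchmarks compares perturbation effects rather than raw profiles: subtract the mean control (sham) profile from both the mean predicted and the mean observed burn profiles, then correlate the two difference vectors. Raw-profile correlation tends to be high even for a model that predicts no change, because burn and sham cells share most of their signal. The sketch below is hypothetical, assumes predictions and observations are aligned on the same features (genes or embedding dimensions), and would typically be applied per cell type and timepoint.

```python
import numpy as np


def pearson_delta(pred: np.ndarray, true: np.ndarray, control: np.ndarray) -> float:
    """Pearson correlation between predicted and observed shifts from the control mean.

    Each argument is a (cells x features) matrix over the same features.
    """
    ctrl_mean = control.mean(axis=0)
    pred_delta = pred.mean(axis=0) - ctrl_mean  # predicted burn effect
    true_delta = true.mean(axis=0) - ctrl_mean  # observed burn effect
    return float(np.corrcoef(pred_delta, true_delta)[0, 1])
```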

## 9. Results Visualization (Post-Training)

Load predictions and create comparison visualizations.

In [None]:
# Placeholder for post-training analysis
print("This cell will be populated after training completes.")
print("\nExpected outputs:")
print("  1. Loss curves comparison (LoRA vs LoRA-mHC)")
print("  2. Prediction accuracy comparison")
print("  3. UMAP of predicted vs actual perturbations")
print("  4. Cell-type-specific perturbation heatmaps")
print("  5. Training stability metrics (gradient norms)")
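Once loss values for both runs have been exported (for example from the TensorBoard logs), the first planned output, a loss-curve comparison, needs only matplotlib. A minimal sketch; the input format (a dict mapping run name to a sequence of per-step loss values) is an assumption of this example.

```python
import numpy as np
import matplotlib.pyplot as plt


def plot_loss_curves(curves: dict, smooth: int = 1, title: str = "ST-LoRA vs ST-LoRA-mHC"):
    """Overlay training-loss curves; `smooth` > 1 applies a trailing moving average."""
    fig, ax = plt.subplots(figsize=(7, 4))
    for name, loss in curves.items():
        loss = np.asarray(loss, dtype=float)
        steps = np.arange(len(loss))
        if smooth > 1 and len(loss) >= smooth:
            loss = np.convolve(loss, np.ones(smooth) / smooth, mode="valid")
            steps = steps[smooth - 1:]  # each average ends at this step
        ax.plot(steps, loss, label=name)
    ax.set_xlabel("Logged step")
    ax.set_ylabel("Training loss")
    ax.set_title(title)
    ax.legend()
    fig.tight_layout()
    return fig
```

Usage: `plot_loss_curves({"ST-LoRA": lora_loss, "ST-LoRA-mHC": mhc_loss}, smooth=20)`, where `lora_loss` and `mhc_loss` are hypothetical arrays of exported loss values.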

## 10. Summary and Next Steps

In [None]:
print("="*80)
print("EXPERIMENT SUMMARY")
print("="*80)

print("\n✅ Completed:")
print("  1. Loaded SE-600M baseline embeddings")
print("  2. Loaded LoRA fine-tuning configuration")
print("  3. Loaded LoRA+mHC fine-tuning configuration")
print("  4. Prepared training commands")

print("\n🔄 Next Steps:")
print("  1. Train ST-LoRA model (2-3 hours)")
print("  2. Train ST-LoRA-mHC model (3-4 hours)")
print("  3. Run predictions on test set")
print("  4. Compare results and create visualizations")
print("  5. Document findings")

print("\n📊 Expected Outcomes:")
print("  - mHC should show more stable training (smoother loss)")
print("  - Both should be parameter-efficient (95%+ frozen)")
print("  - Comparison will reveal best approach for wound healing")

print("\n" + "="*80)