# Phase 3b: State Transition Model Training

This notebook trains the Arc Institute State Transition model on burn/sham wound healing data.

## Approach: Baseline SE-600M + Temporal Covariates

We will:
1. Validate data and configuration files
2. Initialize the State Transition model with timepoint embeddings
3. Train the model using PyTorch Lightning
4. Monitor training progress
5. Save best checkpoints

**Key Decision**: Using baseline SE-600M embeddings (96% cell type accuracy) + temporal covariates instead of LoRA embeddings (75% cell type accuracy).

## 1. Environment Setup

In [None]:
import sys
import os
from pathlib import Path
import yaml
import subprocess

import torch
import anndata as ad
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Check GPU availability
print(f"Working directory: {os.getcwd()}")
print(f"Python version: {sys.version}")
print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
        print(f"    Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB")

## 2. Validate Prerequisites

In [None]:
# Check that all required files exist
required_files = {
    'Data': '/home/scumpia-mrl/Desktop/Sujit/Projects/state-experimentation/burn_sham_baseline_embedded.h5ad',
    'TOML Config': 'examples/burn_sham.toml',
    'YAML Config': 'configs/state_transition_burn_sham.yaml',
}

print("=" * 80)
print("PREREQUISITE CHECK")
print("=" * 80)

all_exist = True
for name, path in required_files.items():
    exists = os.path.exists(path)
    status = "✓" if exists else "✗"
    print(f"{status} {name}: {path}")
    if not exists:
        all_exist = False

if all_exist:
    print("\n✓ All required files exist!")
else:
    print("\n✗ Some required files are missing. Please run phase3a first.")
    raise FileNotFoundError("Missing required files")

print("=" * 80)

## 3. Load and Inspect Configuration

In [None]:
# Load YAML configuration
config_path = "configs/state_transition_burn_sham.yaml"
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

print("=" * 80)
print("TRAINING CONFIGURATION")
print("=" * 80)
print(yaml.dump(config, default_flow_style=False, sort_keys=False))
print("=" * 80)
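
The training command built later indexes nested keys such as `config['training']['max_steps']` directly, so a missing key surfaces only as a bare `KeyError`. A minimal sketch of an up-front check follows. The helper name `missing_keys` and the dotted-path convention are illustrative assumptions, and the key list covers only names this notebook reads, not a full schema.

```python
def missing_keys(config, dotted_paths):
    """Return the dotted paths that cannot be resolved in a nested dict."""
    missing = []
    for dotted in dotted_paths:
        node = config
        for part in dotted.split("."):
            if not isinstance(node, dict) or part not in node:
                missing.append(dotted)
                break
            node = node[part]
    return missing


# Keys read elsewhere in this notebook (not an exhaustive schema)
REQUIRED_KEYS = [
    "data.toml_config_path",
    "data.embed_key",
    "data.batch_size",
    "model.input_dim",
    "training.max_steps",
    "training.devices",
    "output.output_dir",
    "output.experiment_name",
]

# Toy config for illustration only
example = {"data": {"embed_key": "X_state"}, "training": {"max_steps": 20000}}
print(missing_keys(example, REQUIRED_KEYS))
```

Running `missing_keys(config, REQUIRED_KEYS)` on the loaded configuration and raising if the result is non-empty would report every absent key at once.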

## 4. Data Validation

In [None]:
# Load data to verify it's ready
data_path = required_files['Data']  # same path validated in the prerequisite check
print(f"Loading data from: {data_path}")
adata = ad.read_h5ad(data_path)

print(f"\n✓ Loaded AnnData:")
print(f"  Shape: {adata.shape[0]} cells x {adata.shape[1]} genes")
print(f"  Embeddings: {list(adata.obsm.keys())}")
print(f"  Metadata: {list(adata.obs.columns)}")

# Verify embeddings exist
if 'X_state' not in adata.obsm:
    raise ValueError("Baseline SE-600M embeddings (X_state) not found!")

print(f"\n✓ Baseline embeddings shape: {adata.obsm['X_state'].shape}")

# Check data distribution
print(f"\nData Distribution:")
print(f"  Conditions: {adata.obs['condition'].value_counts().to_dict()}")
print(f"  Timepoints: {adata.obs['timepoint'].value_counts().to_dict()}")
print(f"  Cell types: {len(adata.obs['cell_types_simple_short'].unique())} unique")
print(f"  Mice: {len(adata.obs['mouse_id'].unique())} unique")

## 5. Prepare Training Command

In [None]:
# Construct training command using Hydra-style overrides
# The State CLI uses Hydra for configuration management

output_dir = config['output']['output_dir']
experiment_name = config['output']['experiment_name']

# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

training_command = [
    "state", "tx", "train",
    
    # Data configuration
    f"data.kwargs.toml_config_path={config['data']['toml_config_path']}",
    f"data.kwargs.embed_key={config['data']['embed_key']}",
    f"data.kwargs.pert_col={config['data']['pert_col']}",
    f"data.kwargs.control_pert={config['data']['control_pert']}",
    f"data.kwargs.cell_type_key={config['data']['cell_type_key']}",
    f"data.kwargs.batch_col={config['data']['batch_col']}",
    
    # Model configuration
    "model=state",  # Use state transition model
    f"model.kwargs.input_dim={config['model']['input_dim']}",
    f"model.kwargs.output_dim={config['model']['output_dim']}",
    f"model.kwargs.hidden_dim={config['model']['hidden_dim']}",
    f"model.kwargs.cell_set_len={config['model']['cell_set_len']}",
    
    # Timepoint embedding (our modification)
    f"model.kwargs.use_timepoint_embedding={config['model']['use_timepoint_embedding']}",
    f"model.kwargs.num_timepoints=3",  # day10, day14, day19
    
    # Training configuration
    f"training.max_steps={config['training']['max_steps']}",
    f"training.batch_size={config['data']['batch_size']}",
    f"training.learning_rate={config['training']['learning_rate']}",
    f"training.devices={config['training']['devices']}",
    f"training.strategy={config['training']['strategy']}",
    f"training.gradient_clip_val={config['training']['gradient_clip_val']}",
    f"training.val_check_interval={config['training']['val_check_interval']}",
    f"training.log_every_n_steps={config['training']['log_every_n_steps']}",
    
    # Output configuration
    f"output_dir={output_dir}",
    f"name={experiment_name}",
]

print("=" * 80)
print("TRAINING COMMAND")
print("=" * 80)
print(" \\\n  ".join(training_command))
print("=" * 80)
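
Joining the arguments with backslashes is convenient for reading, but an override whose value contains spaces or shell metacharacters would not survive a copy-paste into a terminal. A small sketch using the standard library's `shlex.join` (Python 3.8+) is shown below. The argument list is a shortened stand-in for `training_command`, and the path containing a space is hypothetical.

```python
import shlex

# Shortened stand-in for training_command; the output_dir value is hypothetical
demo_command = [
    "state", "tx", "train",
    "data.kwargs.embed_key=X_state",
    "output_dir=/tmp/my models/st_burn_sham",
]

quoted = shlex.join(demo_command)  # quotes only the arguments that need it
print(quoted)

# Round trip: splitting the quoted string recovers the original argument list
assert shlex.split(quoted) == demo_command
```

This handles shell quoting only; it does not address any quoting rules the configuration parser itself may apply to override values.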

## 6. Training Information

In [None]:
# Estimate training time and resources
num_cells = adata.shape[0]
batch_size = config['data']['batch_size']
max_steps = config['training']['max_steps']
val_interval = config['training']['val_check_interval']
num_gpus = config['training']['devices']

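# Rough estimate that counts individual cells per step; clamp to 1 so a tiny
# dataset cannot produce a zero and trigger a division by zero below
steps_per_epoch = max(1, num_cells // (batch_size * num_gpus))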
num_epochs = max_steps / steps_per_epoch
num_validations = max_steps // val_interval

print("=" * 80)
print("TRAINING ESTIMATES")
print("=" * 80)
print(f"\nData:")
print(f"  Total cells: {num_cells:,}")
print(f"  Batch size: {batch_size} per GPU")
print(f"  Effective batch size: {batch_size * num_gpus} (across {num_gpus} GPUs)")

print(f"\nTraining:")
print(f"  Max steps: {max_steps:,}")
print(f"  Steps per epoch: ~{steps_per_epoch:,}")
print(f"  Estimated epochs: ~{num_epochs:.1f}")
print(f"  Validation frequency: every {val_interval} steps ({num_validations} times total)")

print(f"\nResources:")
print(f"  GPUs: {num_gpus}x {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A'}")
print(f"  Strategy: {config['training']['strategy'].upper()}")

print(f"\nExpected Duration:")
print(f"  Training time: ~4-6 hours (estimated)")
print(f"  Peak GPU memory: ~25-30 GB per GPU (estimated)")

print(f"\nOutputs:")
print(f"  Checkpoints: {output_dir}/{experiment_name}/checkpoints/")
print(f"  Logs: {output_dir}/{experiment_name}/logs/")
print(f"  TensorBoard: tensorboard --logdir={output_dir}/{experiment_name}")

print("=" * 80)
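
The estimate above is plain integer arithmetic. The sketch below works it through with the settings used in this notebook (batch size 16, 2 GPUs, 20,000 steps, validation every 500 steps). The cell count of 100,000 is a made-up figure for illustration. The `cells_per_sample` parameter is an assumption: if each training sample is a set of `cell_set_len` cells rather than a single cell, one step consumes that many more cells and the epoch count rises accordingly.

```python
def estimate_schedule(num_cells, batch_size, num_gpus, max_steps,
                      val_interval, cells_per_sample=1):
    """Rough schedule estimate; cells_per_sample is a hypothetical knob."""
    cells_per_step = batch_size * num_gpus * cells_per_sample
    steps_per_epoch = max(1, num_cells // cells_per_step)  # avoid divide-by-zero
    return {
        "steps_per_epoch": steps_per_epoch,
        "epochs": max_steps / steps_per_epoch,
        "validations": max_steps // val_interval,
    }


# 100_000 cells is illustrative only; the real count comes from adata.shape[0]
per_cell = estimate_schedule(100_000, 16, 2, 20_000, 500)
per_set = estimate_schedule(100_000, 16, 2, 20_000, 500, cells_per_sample=256)
print(per_cell)
print(per_set)
```

With these illustrative numbers, the per-cell reading gives 3,125 steps per epoch (about 6.4 epochs over the run), while the per-set reading gives only 12 steps per epoch. Both readings yield 40 validation passes.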

## 7. Start Training

⚠️ **Important**: Training is estimated to take 4-6 hours. Monitor progress with TensorBoard in a separate terminal:

```bash
tensorboard --logdir=/home/scumpia-mrl/state_models/st_burn_sham
```

In [None]:
# Option 1: Run training directly (blocking)
# Uncomment to run training in this notebook (will block for 4-6 hours)

# print("Starting training...\n")
# result = subprocess.run(training_command, capture_output=False, text=True)
# if result.returncode == 0:
#     print("\n✓ Training completed successfully!")
# else:
#     print(f"\n✗ Training failed with return code {result.returncode}")

print("⚠️ Training not started yet. Please choose one of the options below:")
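
If training is launched from the notebook anyway, streaming the output into a log file keeps a record even if the browser session is lost. A minimal sketch with `subprocess.Popen` follows. The helper `run_and_log` and the log filename are hypothetical, and the demo runs a harmless Python one-liner in place of `training_command`.

```python
import os
import subprocess
import sys
import tempfile


def run_and_log(command, log_path):
    """Run a command, echo its output line by line, and append it to log_path."""
    with open(log_path, "a", encoding="utf-8") as log:
        proc = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr into the same stream
            text=True,
        )
        for line in proc.stdout:
            print(line, end="")
            log.write(line)
        return proc.wait()


# Demo with a stand-in command; swap in training_command for a real run
demo_log = os.path.join(tempfile.gettempdir(), "st_training_demo.log")
returncode = run_and_log([sys.executable, "-c", "print('hello from child')"], demo_log)
print(f"Exit code: {returncode}")
```

The call still blocks until the process exits, so the terminal and tmux options below remain the safer choice for a multi-hour run.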

### Option A: Run in Terminal (Recommended)

Copy and paste this command into a terminal:

```bash
state tx train \
  data.kwargs.toml_config_path=examples/burn_sham.toml \
  data.kwargs.embed_key=X_state \
  data.kwargs.pert_col=condition \
  data.kwargs.control_pert=sham \
  data.kwargs.cell_type_key=cell_types_simple_short \
  data.kwargs.batch_col=mouse_id \
  model=state \
  model.kwargs.input_dim=2048 \
  model.kwargs.output_dim=2048 \
  model.kwargs.hidden_dim=512 \
  model.kwargs.cell_set_len=256 \
  model.kwargs.use_timepoint_embedding=true \
  model.kwargs.num_timepoints=3 \
  training.max_steps=20000 \
  training.batch_size=16 \
  training.learning_rate=0.0001 \
  training.devices=2 \
  training.strategy=ddp \
  training.gradient_clip_val=1.0 \
  training.val_check_interval=500 \
  training.log_every_n_steps=50 \
  output_dir=/home/scumpia-mrl/state_models/st_burn_sham \
  name=st_burn_sham_v1
```

### Option B: Run in Background (tmux/screen)

For long-running training, use tmux or screen:

```bash
# Start tmux session
tmux new -s st_training

# Run training command (same as above)
state tx train ...

# Detach: Ctrl+B, then D
# Reattach: tmux attach -t st_training
```

## 8. Monitor Training Progress

In [None]:
# Check if training has started by looking for output directory
training_dir = f"{output_dir}/{experiment_name}"

if os.path.exists(training_dir):
    print(f"✓ Training directory exists: {training_dir}\n")
    
    # Check for checkpoints
    checkpoint_dir = f"{training_dir}/checkpoints"
    if os.path.exists(checkpoint_dir):
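        # Sort by modification time so the slice below really shows the newest files
        # (a plain name sort would place "epoch=10" before "epoch=9")
        checkpoints = sorted(Path(checkpoint_dir).glob("*.ckpt"), key=lambda p: p.stat().st_mtime)
        print(f"Checkpoints found: {len(checkpoints)}")
        for ckpt in checkpoints[-5:]:  # Show the 5 most recent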
            size_mb = ckpt.stat().st_size / 1e6
            print(f"  - {ckpt.name} ({size_mb:.1f} MB)")
    else:
        print("No checkpoints yet (training may just be starting)")
    
    # Check for logs
    print("\nTo monitor training in real-time:")
    print(f"  tensorboard --logdir={training_dir}")
else:
    print("✗ Training has not started yet.")
    print(f"  Expected directory: {training_dir}")

## 9. Load Training Logs (After Training)

In [None]:
# Parse TensorBoard event files to extract training metrics
from tensorboard.backend.event_processing import event_accumulator

training_dir = f"{output_dir}/{experiment_name}"
version_dir = Path(training_dir) / "lightning_logs" / "version_0"  # Adjust version as needed

if version_dir.exists():
    event_files = list(version_dir.glob("events.out.tfevents.*"))
    
    if event_files:
        print(f"Loading metrics from: {version_dir}\n")
        
        ea = event_accumulator.EventAccumulator(str(version_dir))
        ea.Reload()
        
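        # Extract training and validation loss
        # (ea.Scalars raises KeyError if a tag was never logged; ea.Tags()['scalars'] lists what exists)
        train_loss = ea.Scalars('train_loss')
        val_loss = ea.Scalars('val_loss')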
        
        # Convert to pandas for easier plotting
        train_df = pd.DataFrame([(s.step, s.value) for s in train_loss], 
                               columns=['step', 'train_loss'])
        val_df = pd.DataFrame([(s.step, s.value) for s in val_loss], 
                             columns=['step', 'val_loss'])
        
        # Plot training curves
        fig, axes = plt.subplots(1, 2, figsize=(14, 5))
        
        # Training loss
        axes[0].plot(train_df['step'], train_df['train_loss'], alpha=0.7)
        axes[0].set_xlabel('Step')
        axes[0].set_ylabel('Training Loss')
        axes[0].set_title('Training Loss Curve')
        axes[0].grid(True, alpha=0.3)
        
        # Validation loss
        axes[1].plot(val_df['step'], val_df['val_loss'], color='orange', marker='o', markersize=3)
        axes[1].set_xlabel('Step')
        axes[1].set_ylabel('Validation Loss')
        axes[1].set_title('Validation Loss Curve')
        axes[1].grid(True, alpha=0.3)
        
        plt.tight_layout()
        os.makedirs('figures', exist_ok=True)  # savefig fails if the directory is missing
        plt.savefig('figures/st_training_curves.png', dpi=300, bbox_inches='tight')
        plt.show()
        
        print(f"\n✓ Training curves saved to: figures/st_training_curves.png")
        
        # Print final metrics
        print(f"\nFinal Metrics:")
        print(f"  Final training loss: {train_df['train_loss'].iloc[-1]:.4f}")
        print(f"  Final validation loss: {val_df['val_loss'].iloc[-1]:.4f}")
        print(f"  Best validation loss: {val_df['val_loss'].min():.4f}")
    else:
        print("No TensorBoard event files found yet.")
else:
    print(f"Training logs directory not found: {version_dir}")
    print("Training may not have started or logs are in a different location.")
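
The cell above reads `version_0`, but PyTorch Lightning's default loggers create `version_1`, `version_2`, and so on when a run is repeated in the same directory. A sketch for picking the highest-numbered version directory is below. `latest_version_dir` is a hypothetical helper, and it assumes the `lightning_logs/version_N` layout used above.

```python
import re
import tempfile
from pathlib import Path


def latest_version_dir(log_root):
    """Return the version_N subdirectory with the largest N, or None."""
    best, best_n = None, -1
    for child in Path(log_root).glob("version_*"):
        match = re.fullmatch(r"version_(\d+)", child.name)
        if child.is_dir() and match and int(match.group(1)) > best_n:
            best, best_n = child, int(match.group(1))
    return best


# Demo on a throwaway directory tree
with tempfile.TemporaryDirectory() as tmp:
    for n in (0, 2, 10):
        (Path(tmp) / f"version_{n}").mkdir()
    newest = latest_version_dir(tmp)
    print(newest.name)
```

Comparing the parsed integers rather than the directory names matters once there are ten or more runs, since `version_10` sorts before `version_2` as a string.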

## 10. Identify Best Checkpoint

In [None]:
# Find the checkpoint with the lowest validation loss
checkpoint_dir = f"{output_dir}/{experiment_name}/checkpoints"

if os.path.exists(checkpoint_dir):
    checkpoints = list(Path(checkpoint_dir).glob("*.ckpt"))
    
    if checkpoints:
        print(f"Found {len(checkpoints)} checkpoints:\n")
        
        # Parse checkpoint filenames to extract validation loss
        checkpoint_info = []
        for ckpt in checkpoints:
            name = ckpt.stem
            size_mb = ckpt.stat().st_size / 1e6
            
            # Try to extract val_loss from filename (e.g., "epoch=5-val_loss=0.1234.ckpt")
            try:
                val_loss = float(name.split('val_loss=')[1].split('-')[0])
            except (IndexError, ValueError):
                # No parseable "val_loss=<number>" in the name (e.g., "last.ckpt")
                val_loss = float('inf')
            
            checkpoint_info.append({
                'path': ckpt,
                'name': ckpt.name,
                'val_loss': val_loss,
                'size_mb': size_mb
            })
        
        # Sort by validation loss
        checkpoint_info = sorted(checkpoint_info, key=lambda x: x['val_loss'])
        
        # Display top 5 checkpoints
        print("Top 5 checkpoints (by validation loss):\n")
        for i, info in enumerate(checkpoint_info[:5], 1):
            print(f"{i}. {info['name']}")
            print(f"   Val Loss: {info['val_loss']:.4f}")
            print(f"   Size: {info['size_mb']:.1f} MB\n")
        
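        # Save best checkpoint path
        best_checkpoint = checkpoint_info[0]['path']
        if checkpoint_info[0]['val_loss'] == float('inf'):
            print("Warning: no checkpoint filename contains a val_loss value; the selection below is arbitrary.")
        print(f"✓ Best checkpoint: {best_checkpoint}")
        print(f"\nUse this checkpoint for evaluation in phase3c_st_evaluation.ipynb")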
    else:
        print("No checkpoints found yet.")
else:
    print(f"Checkpoint directory not found: {checkpoint_dir}")
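
To hand the selected checkpoint over to `phase3c_st_evaluation.ipynb` without copying paths by hand, the path could be written to a small JSON file. This is only a sketch: the helper names, the `best_checkpoint.json` filename, and the example checkpoint name are all hypothetical, and nothing in the evaluation notebook is known to read such a file.

```python
import json
import tempfile
from pathlib import Path


def save_best_checkpoint(record_path, checkpoint_path, val_loss):
    """Write the chosen checkpoint path and its validation loss to JSON."""
    record = {"checkpoint": str(checkpoint_path), "val_loss": val_loss}
    Path(record_path).write_text(json.dumps(record, indent=2), encoding="utf-8")


def load_best_checkpoint(record_path):
    """Read the record back; returns (Path, val_loss)."""
    record = json.loads(Path(record_path).read_text(encoding="utf-8"))
    return Path(record["checkpoint"]), record["val_loss"]


# Demo with made-up values in a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    record_file = Path(tmp) / "best_checkpoint.json"
    save_best_checkpoint(record_file, "checkpoints/epoch=5-val_loss=0.1234.ckpt", 0.1234)
    ckpt_path, ckpt_loss = load_best_checkpoint(record_file)
    print(ckpt_path.name, ckpt_loss)
```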

## Summary

This notebook provides the training setup for the State Transition model.

### ✅ Completed
- Validated all prerequisites (data, configs)
- Constructed training command with Hydra overrides
- Provided options for running training (terminal, tmux/screen)
- Set up monitoring tools (TensorBoard, checkpoint tracking)

### 📋 Training Details
- **Model**: Arc Institute State Transition with timepoint embeddings
- **Input**: Baseline SE-600M embeddings (2048-dim)
- **Covariates**: Timepoint (day10/14/19), condition (burn/sham), cell type, mouse ID
- **Training**: 20,000 steps, batch size 16, 2 GPUs (DDP)
- **Duration**: ~4-6 hours

### 🎯 Modified Components
- **State Transition Model** ([state_transition.py:200-210, 431-448](src/state/tx/models/state_transition.py)): Added timepoint encoder
- **Dataset** ([scgpt_perturbation_dataset.py:175-212](src/state/tx/data/dataset/scgpt_perturbation_dataset.py)): Extracts timepoint_ids
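
The modified source files are not reproduced in this notebook, so the following is only a sketch of what extracting `timepoint_ids` might look like: each timepoint label is mapped to an integer index that an embedding table with `num_timepoints` rows can look up. The function name and the sort-by-day ordering are assumptions. The labels `day10`, `day14`, and `day19` come from the training command in section 5.

```python
import re


def build_timepoint_ids(labels):
    """Map labels like 'day10' to 0-based ids ordered by day number."""
    def day_number(label):
        match = re.fullmatch(r"day(\d+)", label)
        if match is None:
            raise ValueError(f"Unexpected timepoint label: {label!r}")
        return int(match.group(1))

    ordered = sorted(set(labels), key=day_number)
    mapping = {label: idx for idx, label in enumerate(ordered)}
    return mapping, [mapping[label] for label in labels]


# Toy per-cell labels; real values would come from adata.obs['timepoint']
cell_labels = ["day14", "day10", "day19", "day10", "day14"]
mapping, timepoint_ids = build_timepoint_ids(cell_labels)
print(mapping)
print(timepoint_ids)

# The embedding table size passed as num_timepoints must cover every id
num_timepoints = 3
assert max(timepoint_ids) < num_timepoints
```

A check like the final assertion is worth running on the real data, since an id outside the table size would fail only once training reaches the embedding lookup.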

### ⏭️ Next Steps
1. **Start training** using one of the methods above
2. **Monitor progress** with TensorBoard
3. **Wait for completion** (~4-6 hours)
4. **Continue to phase3c_st_evaluation.ipynb** to evaluate predictions