# Phase 3a: State Transition Data Preparation

This notebook prepares the burn/sham wound healing data for State Transition (ST) model training.

## Approach: Strategy 2 - Baseline SE-600M + ST Training

We will:
1. Load baseline SE-600M embeddings
2. Validate data format and required columns
3. Create TOML configuration for data splits
4. Create YAML configuration for ST training
5. Validate everything is ready for training

**Key Decision**: Using baseline SE-600M embeddings (not LoRA) because cell type preservation is critical for perturbation prediction.

## 1. Environment Setup

In [None]:
import sys
import os
from pathlib import Path
import yaml

import anndata as ad
import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

print(f"Working directory: {os.getcwd()}")
print(f"Python version: {sys.version}")

## 2. Load Baseline Embeddings

In [None]:
# Load baseline SE-600M embeddings
data_path = "/home/scumpia-mrl/Desktop/Sujit/Projects/state-experimentation/burn_sham_baseline_embedded.h5ad"

print(f"Loading data from: {data_path}")
adata = ad.read_h5ad(data_path)

print(f"\n✓ Loaded AnnData:")
print(f"  Shape: {adata.shape[0]} cells x {adata.shape[1]} genes")
print(f"  Observations: {list(adata.obs.columns)}")
print(f"  Embeddings: {list(adata.obsm.keys())}")

## 3. Validate Required Columns

In [None]:
# Required columns for ST training
required_obs_cols = ['condition', 'timepoint', 'time_days', 'cell_types_simple_short', 'mouse_id']
required_obsm_keys = ['X_state']  # Baseline SE-600M embeddings

print("=" * 80)
print("VALIDATION: Required Columns")
print("=" * 80)

# Check observation columns
print("\n1. Observation Columns:")
all_valid = True
for col in required_obs_cols:
    if col in adata.obs.columns:
        unique_vals = adata.obs[col].unique()
        print(f"   ✓ '{col}': {len(unique_vals)} unique values")
        # Show at most the first 10 values to keep the output readable
        if len(unique_vals) > 10:
            print(f"      Values: {list(unique_vals)[:10]}...")
        else:
            print(f"      Values: {list(unique_vals)}")
    else:
        print(f"   ✗ '{col}': MISSING")
        all_valid = False

# Check embedding keys
print("\n2. Embedding Keys:")
for key in required_obsm_keys:
    if key in adata.obsm:
        print(f"   ✓ '{key}': shape {adata.obsm[key].shape}")
    else:
        print(f"   ✗ '{key}': MISSING")
        all_valid = False

print("\n" + "=" * 80)
if all_valid:
    print("✓ All required columns present!")
else:
    print("✗ Some required columns are missing!")
print("=" * 80)
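
The cell above only prints a status line, so a missing column would not stop later cells from running. A stricter check could look like the minimal sketch below. `find_missing` is a hypothetical helper, not part of the State codebase, and the column list is a stand-in so the example runs without the dataset.

```python
from typing import Iterable, List


def find_missing(required: Iterable[str], available: Iterable[str]) -> List[str]:
    """Return the required names that are absent from `available`, in the order given."""
    present = set(available)
    return [name for name in required if name not in present]


# Stand-in column names; in the notebook these would come from
# adata.obs.columns (and adata.obsm.keys() for the embedding check).
example_obs_cols = ["condition", "timepoint", "cell_types_simple_short", "mouse_id"]
missing = find_missing(
    ["condition", "timepoint", "time_days", "cell_types_simple_short", "mouse_id"],
    example_obs_cols,
)
print(missing)
```

In the notebook, raising a `ValueError` when the returned list is non-empty would halt execution before any configuration files are written.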

## 4. Data Distribution Summary

In [None]:
# Summary statistics
print("=" * 80)
print("DATA DISTRIBUTION")
print("=" * 80)

print("\n1. Condition Distribution:")
print(adata.obs['condition'].value_counts())

print("\n2. Timepoint Distribution:")
print(adata.obs['timepoint'].value_counts())

print("\n3. Condition × Timepoint Distribution:")
cross_tab = pd.crosstab(adata.obs['condition'], adata.obs['timepoint'])
print(cross_tab)

print("\n4. Cell Type Distribution (top 10):")
print(adata.obs['cell_types_simple_short'].value_counts().head(10))

print("\n5. Mouse ID Distribution:")
print(adata.obs['mouse_id'].value_counts())

In [None]:
# Visualize distributions
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plot 1: Condition distribution
adata.obs['condition'].value_counts().plot(kind='bar', ax=axes[0, 0], color=['#FF6B6B', '#4ECDC4'])
axes[0, 0].set_title('Condition Distribution', fontsize=12, fontweight='bold')
axes[0, 0].set_xlabel('Condition')
axes[0, 0].set_ylabel('Number of Cells')

# Plot 2: Timepoint distribution
adata.obs['timepoint'].value_counts().sort_index().plot(kind='bar', ax=axes[0, 1], color='#95E1D3')
axes[0, 1].set_title('Timepoint Distribution', fontsize=12, fontweight='bold')
axes[0, 1].set_xlabel('Timepoint')
axes[0, 1].set_ylabel('Number of Cells')

# Plot 3: Condition × Timepoint heatmap
sns.heatmap(cross_tab, annot=True, fmt='d', cmap='YlOrRd', ax=axes[1, 0], cbar_kws={'label': 'Cell Count'})
axes[1, 0].set_title('Condition × Timepoint', fontsize=12, fontweight='bold')

# Plot 4: Top cell types
top_cell_types = adata.obs['cell_types_simple_short'].value_counts().head(8)
top_cell_types.plot(kind='barh', ax=axes[1, 1], color='#A8E6CF')
axes[1, 1].set_title('Top 8 Cell Types', fontsize=12, fontweight='bold')
axes[1, 1].set_xlabel('Number of Cells')

plt.tight_layout()
os.makedirs('figures', exist_ok=True)  # savefig raises if the target directory is missing
plt.savefig('figures/st_data_distribution.png', dpi=300, bbox_inches='tight')
plt.show()

print("\n✓ Saved distribution plot to: figures/st_data_distribution.png")

## 5. Create TOML Configuration

In [None]:
# Create examples directory if it doesn't exist
os.makedirs('examples', exist_ok=True)

# TOML configuration for data splits
toml_content = """# Burn/Sham Wound Healing Dataset Configuration
# For State Transition model training

[datasets]
burn_sham = "/home/scumpia-mrl/Desktop/Sujit/Projects/state-experimentation/"

[training]
burn_sham = "train"

# Optional: Hold out specific conditions for few-shot evaluation
# [fewshot]
# [fewshot."burn_sham.macrophage"]
# val = ["day19"]
# test = []
"""

toml_path = "examples/burn_sham.toml"
with open(toml_path, 'w') as f:
    f.write(toml_content)

print(f"✓ Created TOML config: {toml_path}")
print("\nContents:")
print(toml_content)

## 6. Create YAML Training Configuration

In [None]:
# Create configs directory if it doesn't exist
os.makedirs('configs', exist_ok=True)

# YAML configuration for ST training
config = {
    'data': {
        'toml_config_path': 'examples/burn_sham.toml',
        'embed_key': 'X_state',
        'pert_col': 'condition',
        'control_pert': 'sham',
        'cell_type_key': 'cell_types_simple_short',
        'batch_col': 'mouse_id',
        'batch_size': 16,
        'num_workers': 8,
    },
    'model': {
        'model_class': 'state_transition',
        'input_dim': 2048,
        'output_dim': 2048,
        'hidden_dim': 512,
        'cell_set_len': 256,
        # Timepoint embedding (requires code modification)
        'use_timepoint_embedding': True,
        'timepoint_dim': 128,
    },
    'training': {
        'max_steps': 20000,
        'learning_rate': 1e-4,
        'weight_decay': 0.01,
        'warmup_steps': 1000,
        'gradient_clip_val': 1.0,
        'devices': 2,
        'strategy': 'ddp',
        'log_every_n_steps': 50,
        'val_check_interval': 500,
    },
    'output': {
        'output_dir': '/home/scumpia-mrl/state_models/st_burn_sham',
        'experiment_name': 'st_burn_sham_v1',
        'save_top_k': 3,
        'monitor': 'val_loss',
    }
}

yaml_path = "configs/state_transition_burn_sham.yaml"
with open(yaml_path, 'w') as f:
    yaml.dump(config, f, default_flow_style=False, sort_keys=False)

print(f"✓ Created YAML config: {yaml_path}")
print("\nContents:")
print(yaml.dump(config, default_flow_style=False, sort_keys=False))
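
Several values in this config are hard-coded but have to agree with the data, for example `input_dim` with the width of `X_state` and `control_pert` with a value that occurs in the `condition` column. A minimal sketch of such a consistency check follows. `check_config` is a hypothetical helper, and the condition names `burn` and `sham` are assumed from the dataset name and the `control_pert` setting.

```python
from typing import Dict, List, Sequence


def check_config(config: Dict, embed_dim: int, conditions: Sequence[str]) -> List[str]:
    """Collect human-readable problems; an empty list means both checks passed."""
    problems = []
    input_dim = config["model"]["input_dim"]
    if input_dim != embed_dim:
        problems.append(f"input_dim={input_dim} but embeddings have width {embed_dim}")
    control = config["data"]["control_pert"]
    if control not in conditions:
        problems.append(f"control_pert {control!r} not found in conditions")
    return problems


# Stand-in values; in the notebook these would be adata.obsm['X_state'].shape[1]
# and adata.obs['condition'].unique().
example_config = {
    "data": {"control_pert": "sham"},
    "model": {"input_dim": 2048},
}
problems = check_config(example_config, embed_dim=2048, conditions=["burn", "sham"])
print(problems)
```

Returning a list instead of raising on the first mismatch lets all problems be reported in a single run.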

## 7. Validation Summary

In [None]:
print("=" * 80)
print("PREPARATION COMPLETE")
print("=" * 80)

print("\n✓ Data validated:")
print(f"   - {adata.n_obs:,} cells")
print(f"   - {adata.n_vars:,} genes")
print(f"   - {len(adata.obs['condition'].unique())} conditions: {list(adata.obs['condition'].unique())}")
print(f"   - {len(adata.obs['timepoint'].unique())} timepoints: {list(adata.obs['timepoint'].unique())}")
print(f"   - {len(adata.obs['cell_types_simple_short'].unique())} cell types")
print(f"   - {len(adata.obs['mouse_id'].unique())} mice")

print("\n✓ Configuration files created:")
print(f"   - TOML: {toml_path}")
print(f"   - YAML: {yaml_path}")

print("\n✓ Code modifications completed:")
print("   - src/state/tx/models/state_transition.py (timepoint embedding added)")
print("   Lines 200-210: Added timepoint_encoder initialization")
print("   Lines 431-448: Added timepoint embedding in forward pass")

print("\n📋 Next Steps:")
print("   1. Run phase3b_st_model_training.ipynb to start training")
print("\n   2. Monitor training with TensorBoard:")
print("      tensorboard --logdir=/home/scumpia-mrl/state_models/st_burn_sham")

print("\n" + "=" * 80)

## Summary

This notebook has prepared all the data and configuration files needed for State Transition model training:

### ✅ Completed
- Loaded baseline SE-600M embeddings from `burn_sham_baseline_embedded.h5ad`
- Validated all required columns are present
- Analyzed data distribution across conditions, timepoints, and cell types
- Created TOML configuration (`examples/burn_sham.toml`)
- Created YAML training configuration (`configs/state_transition_burn_sham.yaml`)

### ✅ Code Modifications Completed
Modified Arc Institute's State Transition model to add timepoint embedding support:

**File**: `src/state/tx/models/state_transition.py`

1. **Lines 200-210**: Added timepoint encoder initialization in `__init__`:
   ```python
   if kwargs.get("use_timepoint_embedding", False):
       num_timepoints = kwargs.get("num_timepoints", 3)
       self.timepoint_encoder = nn.Embedding(
           num_embeddings=num_timepoints,
           embedding_dim=hidden_dim,
       )
   ```

2. **Lines 431-448**: Added timepoint embedding in forward pass:
   ```python
   if self.timepoint_encoder is not None:
       timepoint_indices = batch.get("timepoint_ids")
       if timepoint_indices is not None:
           timepoint_embeddings = self.timepoint_encoder(timepoint_indices.long())
           seq_input = seq_input + timepoint_embeddings
   ```
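
`nn.Embedding` expects integer indices in the range `[0, num_embeddings)`, so the `timepoint_ids` read from the batch have to be integer codes, not labels such as `day19`. How the data loader builds them is not shown in this notebook. The following is a minimal sketch, assuming labels of the form `day<N>`; `build_timepoint_index` is a hypothetical helper and the example labels are placeholders.

```python
import re
from typing import Dict, Iterable


def build_timepoint_index(labels: Iterable[str]) -> Dict[str, int]:
    """Map each distinct label to an integer id, ordered by its day number."""

    def day_number(label: str) -> int:
        match = re.search(r"\d+", label)
        if match is None:
            raise ValueError(f"No day number found in label {label!r}")
        return int(match.group())

    ordered = sorted(set(labels), key=day_number)
    return {label: idx for idx, label in enumerate(ordered)}


# Placeholder labels, not the dataset's actual timepoints.
index = build_timepoint_index(["day19", "day3", "day7", "day3"])
timepoint_ids = [index[label] for label in ["day3", "day19"]]
print(index, timepoint_ids)
```

With a mapping like this, `num_timepoints` would be `len(index)`. The snippet above falls back to 3 when the key is not supplied, so a dataset with more timepoints would need the value passed explicitly.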

### ⏭️ Next Notebook
Continue to **phase3b_st_model_training.ipynb** to run the training.