# Phase 2: LoRA Multi-Covariate Fine-Tuning of SE-600M

This notebook implements parameter-efficient fine-tuning of the STATE SE-600M embedding model using LoRA (Low-Rank Adaptation) adapters with multi-covariate conditioning.

## Approach

1. **Load Pretrained SE-600M**: Load the 600M parameter transformer model
2. **Freeze Base Model**: Keep all pretrained weights frozen
3. **Add LoRA Adapters**: Add low-rank trainable adapters to attention layers
4. **Add Covariate Encoders**: Create embeddings for timepoint + condition
5. **Condition Embeddings**: Combine base embeddings with covariate information
6. **Fine-Tune**: Train only LoRA + covariate parameters (~1-5% of total params)
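
Steps 2 and 3 appear to follow the standard LoRA formulation, in which each adapted projection keeps its pretrained weight $W_0$ frozen and learns a low-rank update. The symbols below ($A$, $B$, $\alpha$, $r$) are the conventional ones and are assumed here; they are not taken from the project's code:

$$
h = W_0 x + \frac{\alpha}{r} B A x, \qquad B \in \mathbb{R}^{d_{\text{out}} \times r},\; A \in \mathbb{R}^{r \times d_{\text{in}}},\; r \ll \min(d_{\text{in}}, d_{\text{out}})
$$

Only $A$ and $B$ receive gradients, so each adapted projection adds $r\,(d_{\text{in}} + d_{\text{out}})$ trainable parameters. $B$ is conventionally initialized to zero so that training starts from the pretrained model's behavior.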

## Key Differences from CPA Approach (Previous Incorrect Attempt)

- ✅ **LoRA Fine-Tuning**: Works in embedding space, not perturbation prediction space
- ✅ **SE-600M**: Modifies the foundation model itself, not a downstream task model
- ✅ **Parameter Efficient**: Only trains ~1-5% of parameters vs. training entire CPA model
- ✅ **Embedding Conditioning**: Covariates directly influence cell embeddings

## Configuration

- **Base Model**: SE-600M (600M parameters, 16 transformer layers)
- **LoRA Rank**: 16 (low-rank dimension)
- **Covariates**: timepoint (3 categories) + condition (2 categories)
- **Fusion**: Concatenation + MLP (512 → 256 → 2048)
- **Training**: 2x RTX 5000 Ada, DDP, batch size 16
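
As a rough sanity check on the trainable-parameter budget, the LoRA adapter count can be estimated from the rank and layer count above. This is a back-of-envelope sketch: it assumes square 2048 x 2048 attention projections (the hidden size reported in the architecture summary later in this notebook) and adapters on the query and value projections only.

```python
# Back-of-envelope LoRA parameter count. Values marked "assumed" are not
# confirmed by the config and are used only for illustration.
hidden_dim = 2048          # assumed: square 2048 x 2048 attention projections
num_layers = 16            # from the Configuration section
rank = 16                  # LoRA rank from the Configuration section
targets_per_layer = 2      # assumed: query and value projections only

# Each adapter adds A (rank x d_in) and B (d_out x rank).
params_per_adapter = rank * (hidden_dim + hidden_dim)
lora_params = params_per_adapter * targets_per_layer * num_layers

base_params = 600_000_000  # nominal size of SE-600M
lora_fraction = lora_params / base_params

print(f"Params per adapter: {params_per_adapter:,}")
print(f"Total LoRA params:  {lora_params:,}")
print(f"Share of base model: {lora_fraction:.2%}")
```

Under these assumptions the adapters alone come to well under 1% of the base model, which suggests that most of the ~1-5% trainable share would come from the covariate encoder and the conditioning projection.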

## 1. Environment Setup

In [None]:
import sys
import os
from pathlib import Path
import yaml

import torch
import anndata as ad
import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Check GPU availability
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
        print(f"    Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB")

## 2. Load and Validate Data

In [None]:
# Load burn/sham dataset
data_path = "/home/scumpia-mrl/Desktop/Sujit/Projects/state-experimentation/burn_sham_data/burn_sham_processed.h5ad"
adata = ad.read_h5ad(data_path)

print(f"Dataset shape: {adata.shape[0]} cells x {adata.shape[1]} genes")
print(f"\nObservations (metadata columns): {adata.obs.columns.tolist()}")
print(f"\nVariables (gene info): {adata.var.columns.tolist()}")

In [None]:
# Validate covariate columns
required_cols = ['condition', 'timepoint', 'cell_types_simple_short', 'mouse_id']

for col in required_cols:
    if col in adata.obs.columns:
        unique_vals = adata.obs[col].unique()
        print(f"✓ '{col}': {len(unique_vals)} unique values")
        print(f"  Values: {unique_vals}")
        print(f"  Distribution:\n{adata.obs[col].value_counts()}\n")
    else:
        print(f"✗ '{col}' NOT FOUND")

In [None]:
# Visualize data distribution
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Condition distribution
adata.obs['condition'].value_counts().plot(kind='bar', ax=axes[0])
axes[0].set_title('Condition Distribution')
axes[0].set_xlabel('Condition')
axes[0].set_ylabel('Number of Cells')

# Timepoint distribution
adata.obs['timepoint'].value_counts().plot(kind='bar', ax=axes[1])
axes[1].set_title('Timepoint Distribution')
axes[1].set_xlabel('Timepoint')
axes[1].set_ylabel('Number of Cells')

# Cell type distribution
cell_type_counts = adata.obs['cell_types_simple_short'].value_counts().head(10)
cell_type_counts.plot(kind='barh', ax=axes[2])
axes[2].set_title('Top 10 Cell Types')
axes[2].set_xlabel('Number of Cells')

plt.tight_layout()
plt.show()

## 3. Load Configuration

In [None]:
# Load LoRA config
config_path = "configs/lora_multicov_config.yaml"
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

print("Configuration:")
print(yaml.dump(config, default_flow_style=False))
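
Later cells index into the loaded config with fixed keys, so a missing key only surfaces as a `KeyError` at model-construction time. The following sketch checks for those keys up front. The helper `find_missing_keys` and the example values are hypothetical; the key paths are the ones this notebook reads.

```python
def find_missing_keys(config, required_paths):
    """Return the dotted key paths from required_paths that are absent in config."""
    missing = []
    for path in required_paths:
        node = config
        for key in path.split("."):
            if not isinstance(node, dict) or key not in node:
                missing.append(path)
                break
            node = node[key]
    return missing


# Key paths this notebook reads from the loaded YAML.
REQUIRED_PATHS = [
    "base_checkpoint",
    "covariates.covariates",
    "covariates.combination.mlp_hidden_dims",
    "covariates.combination.mlp_output_dim",
    "lora.r",
    "lora.lora_alpha",
    "lora.lora_dropout",
    "lora.target_modules",
    "training.learning_rate",
    "training.warmup_steps",
]

# Hypothetical, deliberately incomplete config used only to exercise the helper.
example_config = {
    "base_checkpoint": "path/to/checkpoint.ckpt",
    "lora": {"r": 16, "lora_alpha": 32, "lora_dropout": 0.1, "target_modules": ["q", "v"]},
    "training": {"learning_rate": 1e-4},
}
missing = find_missing_keys(example_config, REQUIRED_PATHS)
print(missing)
```

In the notebook, calling `find_missing_keys(config, REQUIRED_PATHS)` right after `yaml.safe_load` would report every absent key at once.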

## 4. Initialize LoRA Model

In [None]:
from src.state.emb.nn.lora_covariate_model import LoRACovariateStateModel

# Initialize model
print("Loading LoRA model...")
model = LoRACovariateStateModel(
    base_checkpoint_path=config['base_checkpoint'],
    covariate_config=config['covariates'],
    lora_config=config['lora'],
    learning_rate=config['training']['learning_rate'],
    warmup_steps=config['training']['warmup_steps'],
)

# Print trainable parameters
print("\nTrainable Parameters:")
model.print_trainable_parameters()

## 5. Model Architecture Summary

In [None]:
# Print model architecture
print("\n" + "="*80)
print("LoRA Multi-Covariate Model Architecture")
print("="*80)

print("\n1. BASE MODEL (Frozen):")
print(f"   - SE-600M Transformer: 16 layers, 16 heads, 2048 hidden dim")
print(f"   - Token Encoder: Linear(5120 → 2048) + LayerNorm + SiLU")
print(f"   - Transformer Encoder: 16x FlashTransformerEncoderLayer")
print(f"   - Decoder: SkipBlock + Linear(2048 → 2048)")
print(f"   - Status: ❄️ FROZEN (all 600M parameters)")

print("\n2. LoRA ADAPTERS (Trainable):")
print(f"   - Target: Attention Q, V projections")
print(f"   - Rank: {config['lora']['r']}")
print(f"   - Alpha: {config['lora']['lora_alpha']}")
print(f"   - Dropout: {config['lora']['lora_dropout']}")
print(f"   - Applied to: {len(config['lora']['target_modules'])} projection types × 16 layers")

print("\n3. COVARIATE ENCODER (Trainable):")
for cov in config['covariates']['covariates']:
    print(f"   - {cov['name']}: {cov['type']} ({cov.get('num_categories', 'N/A')} categories) → {cov.get('embed_dim', 'N/A')} dim")
print(f"   - Combination MLP: {config['covariates']['combination']['mlp_hidden_dims']} → {config['covariates']['combination']['mlp_output_dim']}")

print("\n4. CONDITIONING PROJECTION (Trainable):")
print(f"   - Input: Concat(base_embedding, covariate_embedding) = 4096 dim")
print(f"   - Output: Conditioned embedding = 2048 dim")
print(f"   - Architecture: Linear + LayerNorm + SiLU")

print("\n" + "="*80)
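
The covariate encoder and conditioning projection summarized above can be sketched in PyTorch as follows. This is a minimal illustration, not the project's `LoRACovariateStateModel`: the class name `CovariateConditioner`, the per-covariate embedding size of 64, and the SiLU activations inside the combination MLP are all assumptions.

```python
import torch
import torch.nn as nn


class CovariateConditioner(nn.Module):
    """Hypothetical sketch of covariate encoding and fusion; not the project's class."""

    def __init__(self, num_timepoints=3, num_conditions=2, embed_dim=64,
                 hidden_dims=(512, 256), out_dim=2048):
        super().__init__()
        # One learned embedding table per categorical covariate.
        self.timepoint_emb = nn.Embedding(num_timepoints, embed_dim)
        self.condition_emb = nn.Embedding(num_conditions, embed_dim)

        # Combination MLP: concatenated covariate embeddings -> 512 -> 256 -> 2048.
        layers, in_dim = [], 2 * embed_dim
        for hidden in hidden_dims:
            layers += [nn.Linear(in_dim, hidden), nn.SiLU()]
            in_dim = hidden
        layers.append(nn.Linear(in_dim, out_dim))
        self.combine = nn.Sequential(*layers)

        # Conditioning projection: concat(base, covariate) = 4096 -> 2048.
        self.project = nn.Sequential(
            nn.Linear(2 * out_dim, out_dim), nn.LayerNorm(out_dim), nn.SiLU()
        )

    def forward(self, base_emb, timepoint_idx, condition_idx):
        cov = torch.cat(
            [self.timepoint_emb(timepoint_idx), self.condition_emb(condition_idx)], dim=-1
        )
        cov = self.combine(cov)
        return self.project(torch.cat([base_emb, cov], dim=-1))


conditioner = CovariateConditioner()
base = torch.randn(4, 2048)              # stand-in for frozen SE-600M embeddings
timepoint = torch.tensor([0, 1, 2, 0])   # integer-coded timepoints
condition = torch.tensor([0, 1, 1, 0])   # integer-coded condition labels
conditioned = conditioner(base, timepoint, condition)
print(conditioned.shape)
```

Concatenating before projecting lets the covariates shift every output dimension, which is consistent with the strong timepoint and condition signal reported in the summary.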

## 6. Training

In [None]:
# NOTE: Training is resource-intensive and should be run via the training script
# This notebook demonstrates the setup and validation

print("To start training, run:")
print("\n" + "="*80)
print("python train_lora_multicov.py --config configs/lora_multicov_config.yaml")
print("="*80)
print("\nExpected training time: 4-6 hours on 2x RTX 5000 Ada")
print("\nMonitor training with TensorBoard:")
print("tensorboard --logdir=/home/scumpia-mrl/state_models/burn_sham_lora_multicov")

## 7. Load Trained Model (After Training)

In [None]:
# Load best trained checkpoint (lowest validation loss)
checkpoint_path = "/home/scumpia-mrl/state_models/burn_sham_lora_multicov/checkpoints/epoch=09-val_loss=0.5877.ckpt"

print(f"Loading checkpoint: {checkpoint_path}")
trained_model = LoRACovariateStateModel.load_from_checkpoint(
    checkpoint_path,
    base_checkpoint_path=config['base_checkpoint'],
    covariate_config=config['covariates'],
    lora_config=config['lora'],
)
trained_model.eval()
trained_model = trained_model.cuda()  # Move to GPU for faster inference

print("\n✓ Trained model loaded successfully!")
print(f"  Model device: {next(trained_model.parameters()).device}")
print(f"  Validation loss: 0.5877")

## 8. Extract Covariate-Conditioned Embeddings

In [None]:
# Check if embeddings exist and verify which checkpoint was used
import os
from datetime import datetime

embeddings_path = "/home/scumpia-mrl/Desktop/Sujit/Projects/state-experimentation/burn_sham_lora_embedded.h5ad"
best_checkpoint_path = "/home/scumpia-mrl/state_models/burn_sham_lora_multicov/checkpoints/epoch=09-val_loss=0.5877.ckpt"

if os.path.exists(embeddings_path):
    # Compare modification times (a proxy for, not proof of, which checkpoint was used)
    emb_time = datetime.fromtimestamp(os.path.getmtime(embeddings_path))
    ckpt_time = datetime.fromtimestamp(os.path.getmtime(best_checkpoint_path))
    
    print(f"Embeddings file: {embeddings_path}")
    print(f"  Last modified: {emb_time}")
    print(f"\nBest checkpoint: {best_checkpoint_path}")
    print(f"  Last modified: {ckpt_time}")
    if emb_time >= ckpt_time:
        print(f"\n⚠ Embeddings were written {(emb_time - ckpt_time).days} days AFTER the best checkpoint")
        print("⚠ They may come from the best checkpoint, but timestamps alone cannot confirm it.")
    else:
        print("\n⚠ Embeddings are OLDER than the best checkpoint; re-extract them before analysis.")
    
    # Load and verify
    print(f"\nLoading embeddings file to verify...")
    adata = ad.read_h5ad(embeddings_path)
    
    print(f"\n✓ Loaded AnnData:")
    print(f"  Shape: {adata.shape}")
    print(f"  Embeddings in .obsm: {list(adata.obsm.keys())}")
    
    if 'X_lora_conditioned' in adata.obsm:
        print(f"\n✓ Found 'X_lora_conditioned': {adata.obsm['X_lora_conditioned'].shape}")
        print(f"\nEmbeddings appear to exist. Proceeding with visualization and analysis...")
    else:
        print("\n✗ No 'X_lora_conditioned' found.")
        print("Need to extract embeddings from best checkpoint.")
        
        # Offer to run extraction
        print("\nTo extract embeddings from best checkpoint, run:")
        print(f"python extract_embeddings.py --config configs/lora_multicov_config.yaml --checkpoint {best_checkpoint_path}")
else:
    print(f"✗ Embeddings file not found: {embeddings_path}")
    print("\nNeed to extract embeddings from checkpoint.")
    print(f"\nRun:")
    print(f"python extract_embeddings.py --config configs/lora_multicov_config.yaml --checkpoint {best_checkpoint_path}")
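
The timestamp comparison in the cell above reduces to a single modification-time test. The sketch below isolates it and exercises it on temporary files; `modified_after` and the file names are hypothetical. A newer embeddings file is only weak evidence of provenance, since any later write would also update the modification time.

```python
import os
import tempfile


def modified_after(path_a, path_b):
    """Return True if path_a was last modified strictly after path_b."""
    return os.path.getmtime(path_a) > os.path.getmtime(path_b)


# Demonstration with two temporary files and explicit modification times.
with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "best.ckpt")        # hypothetical file names
    emb = os.path.join(tmp, "embedded.h5ad")
    for path in (ckpt, emb):
        with open(path, "w") as fh:
            fh.write("placeholder")
    os.utime(ckpt, (1_000, 1_000))   # (atime, mtime) in seconds since the epoch
    os.utime(emb, (2_000, 2_000))
    embeddings_newer = modified_after(emb, ckpt)
    checkpoint_newer = modified_after(ckpt, emb)

print(embeddings_newer, checkpoint_newer)
```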

## 9. Evaluation & Comparison

In [None]:
# Generate UMAP visualization with LoRA-conditioned embeddings
if 'X_lora_conditioned' not in adata.obsm:
    print("ERROR: No LoRA conditioned embeddings found!")
    print("Please run the embedding extraction first.")
else:
    print("Computing UMAP on LoRA-conditioned embeddings...")
    
    # Compute neighbors and UMAP
    sc.pp.neighbors(adata, use_rep='X_lora_conditioned', n_neighbors=15, random_state=42)
    sc.tl.umap(adata, random_state=42)
    
    print("✓ UMAP computed successfully!")
    
    # Create comprehensive visualization
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    
    # Plot 1: Timepoint
    sc.pl.umap(adata, color='timepoint', ax=axes[0, 0], show=False, title='LoRA Embeddings: Timepoint')
    
    # Plot 2: Condition
    sc.pl.umap(adata, color='condition', ax=axes[0, 1], show=False, title='LoRA Embeddings: Condition')
    
    # Plot 3: Cell Type
    sc.pl.umap(adata, color='cell_types_simple_short', ax=axes[0, 2], show=False, 
               title='LoRA Embeddings: Cell Type', legend_loc='on data', legend_fontsize=6)
    
    # Plot 4: Mouse ID
    sc.pl.umap(adata, color='mouse_id', ax=axes[1, 0], show=False, title='LoRA Embeddings: Mouse ID')
    
    # Plot 5: Combined (timepoint + condition)
    adata.obs['timepoint_condition'] = adata.obs['timepoint'].astype(str) + '_' + adata.obs['condition'].astype(str)
    sc.pl.umap(adata, color='timepoint_condition', ax=axes[1, 1], show=False, 
               title='LoRA Embeddings: Timepoint × Condition')
    
    # Plot 6: Number of genes detected
    if 'n_genes' not in adata.obs.columns:
        # Count genes with non-zero expression per cell; works for dense and sparse X
        adata.obs['n_genes'] = np.asarray((adata.X > 0).sum(axis=1)).ravel()
    sc.pl.umap(adata, color='n_genes', ax=axes[1, 2], show=False, 
               title='LoRA Embeddings: Gene Count', cmap='viridis')
    
    plt.tight_layout()
    plt.savefig('/home/scumpia-mrl/Desktop/Sujit/Projects/state-experimentation/lora_embeddings_umap_overview.png', 
                dpi=300, bbox_inches='tight')
    plt.show()
    
    print("\n✓ UMAP visualization saved to: lora_embeddings_umap_overview.png")

In [None]:
# Quantitative evaluation: Compare baseline vs LoRA embeddings
if 'adata_baseline' in locals() and baseline_key is not None:
    # Compute evaluation metrics
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import accuracy_score, f1_score
    from scipy.spatial.distance import cdist
    
    def evaluate_embeddings(adata_eval, embedding_key, name):
        """Evaluate embedding quality using multiple metrics."""
        
        print(f"\n{'='*60}")
        print(f"Evaluating: {name}")
        print(f"{'='*60}")
        
        X = adata_eval.obsm[embedding_key]
        
        # 1. Cell Type Classification (kNN)
        print("\n1. Cell Type Classification (10 random stratified 80/20 splits):")
        y = adata_eval.obs['cell_types_simple_short'].values
        
        accuracies = []
        f1_scores = []
        for seed in range(10):
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.2, random_state=seed, stratify=y
            )
            knn = KNeighborsClassifier(n_neighbors=15)
            knn.fit(X_train, y_train)
            y_pred = knn.predict(X_test)
            
            accuracies.append(accuracy_score(y_test, y_pred))
            f1_scores.append(f1_score(y_test, y_pred, average='weighted'))
        
        print(f"   Accuracy: {np.mean(accuracies):.4f} ± {np.std(accuracies):.4f}")
        print(f"   F1 Score: {np.mean(f1_scores):.4f} ± {np.std(f1_scores):.4f}")
        
        # 2. Temporal Coherence
        print("\n2. Temporal Coherence (lower = more coherent):")
        timepoints = sorted(adata_eval.obs['timepoint'].unique())
        
        temporal_distances = []
        for i in range(len(timepoints) - 1):
            tp1, tp2 = timepoints[i], timepoints[i+1]
            
            # Get embeddings for each timepoint
            idx1 = adata_eval.obs['timepoint'] == tp1
            idx2 = adata_eval.obs['timepoint'] == tp2
            
            emb1 = X[idx1]
            emb2 = X[idx2]
            
            # Compute average distance between consecutive timepoints
            distances = cdist(emb1, emb2, metric='cosine')
            avg_dist = np.mean(distances)
            temporal_distances.append(avg_dist)
            
            print(f"   {tp1} → {tp2}: {avg_dist:.4f}")
        
        avg_temporal_coherence = np.mean(temporal_distances)
        print(f"   Average: {avg_temporal_coherence:.4f}")
        
        # 3. Condition Separation
        print("\n3. Condition Separation (higher = better separation):")
        condition_separations = []
        
        for tp in timepoints:
            tp_mask = adata_eval.obs['timepoint'] == tp
            
            burn_emb = X[(tp_mask) & (adata_eval.obs['condition'] == 'burn')]
            sham_emb = X[(tp_mask) & (adata_eval.obs['condition'] == 'sham')]
            
            if len(burn_emb) > 0 and len(sham_emb) > 0:
                # Average distance between burn and sham at this timepoint
                distances = cdist(burn_emb, sham_emb, metric='cosine')
                avg_dist = np.mean(distances)
                condition_separations.append(avg_dist)
                print(f"   {tp}: burn ↔ sham distance = {avg_dist:.4f}")
        
        avg_condition_separation = np.mean(condition_separations)
        print(f"   Average: {avg_condition_separation:.4f}")
        
        # 4. Batch Mixing (Silhouette score)
        print("\n4. Batch Mixing - Mouse ID (lower = better mixing):")
        from sklearn.metrics import silhouette_score
        
        if 'mouse_id' in adata_eval.obs.columns:
            # Subsample for speed, with a fixed seed so the score is reproducible
            sample_size = min(5000, len(X))
            sample_idx = np.random.default_rng(42).choice(len(X), sample_size, replace=False)
            sil_score = silhouette_score(X[sample_idx], adata_eval.obs['mouse_id'].iloc[sample_idx], 
                                        metric='cosine')
            print(f"   Silhouette Score: {sil_score:.4f}")
        else:
            sil_score = np.nan
        
        return {
            'name': name,
            'accuracy': np.mean(accuracies),
            'accuracy_std': np.std(accuracies),
            'f1_score': np.mean(f1_scores),
            'f1_std': np.std(f1_scores),
            'temporal_coherence': avg_temporal_coherence,
            'condition_separation': avg_condition_separation,
            'batch_silhouette': sil_score
        }
    
    # Evaluate both embeddings
    baseline_metrics = evaluate_embeddings(adata_baseline, baseline_key, 'Baseline SE-600M')
    lora_metrics = evaluate_embeddings(adata, 'X_lora_conditioned', 'LoRA Multi-Covariate')
    
    # Create comparison table
    comparison_df = pd.DataFrame([baseline_metrics, lora_metrics])
    comparison_df = comparison_df.set_index('name')
    
    print("\n" + "="*80)
    print("SUMMARY: Baseline vs LoRA Multi-Covariate")
    print("="*80)
    print(comparison_df[['accuracy', 'f1_score', 'temporal_coherence', 'condition_separation', 'batch_silhouette']])
    
    # Compute improvement
    print("\n" + "="*80)
    print("IMPROVEMENT (LoRA vs Baseline)")
    print("="*80)
    
    improvements = {
        'Accuracy': ((lora_metrics['accuracy'] - baseline_metrics['accuracy']) / baseline_metrics['accuracy']) * 100,
        'F1 Score': ((lora_metrics['f1_score'] - baseline_metrics['f1_score']) / baseline_metrics['f1_score']) * 100,
        'Temporal Coherence': ((baseline_metrics['temporal_coherence'] - lora_metrics['temporal_coherence']) / baseline_metrics['temporal_coherence']) * 100,  # Lower is better
        'Condition Separation': ((lora_metrics['condition_separation'] - baseline_metrics['condition_separation']) / baseline_metrics['condition_separation']) * 100,
        'Batch Mixing': ((baseline_metrics['batch_silhouette'] - lora_metrics['batch_silhouette']) / abs(baseline_metrics['batch_silhouette'])) * 100,  # Lower is better
    }
    
    for metric, value in improvements.items():
        print(f"{metric:25s}: {value:+.2f}%")
    
    # Save metrics
    comparison_df.to_csv('/home/scumpia-mrl/Desktop/Sujit/Projects/state-experimentation/lora_vs_baseline_metrics.csv')
    print("\n✓ Metrics saved to: lora_vs_baseline_metrics.csv")
else:
    print("Skipping quantitative evaluation (baseline data not available)")
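
The improvement percentages in the cell above all follow one pattern: a relative change whose sign is flipped for lower-is-better metrics, with `abs()` in the denominator so that a negative baseline (possible for silhouette scores) does not invert the sign. A minimal sketch of that pattern, using a hypothetical helper `relative_improvement` and illustrative numbers rather than results from this notebook:

```python
def relative_improvement(baseline, candidate, higher_is_better=True):
    """Percent change from baseline to candidate, signed so positive means better."""
    if baseline == 0:
        raise ValueError("baseline must be non-zero for a relative change")
    delta = candidate - baseline if higher_is_better else baseline - candidate
    # abs() keeps the sign meaningful when the baseline is negative,
    # which can happen for silhouette scores.
    return delta / abs(baseline) * 100.0


# Illustrative numbers only; not results from this notebook.
print(relative_improvement(0.80, 0.88))                            # accuracy-like metric
print(relative_improvement(0.50, 0.40, higher_is_better=False))    # distance-like metric
print(relative_improvement(-0.10, -0.20, higher_is_better=False))  # negative baseline
```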

In [None]:
# Side-by-side comparison: Baseline vs LoRA UMAPs
if 'adata_baseline' in locals() and baseline_key is not None:
    print("Computing UMAP for baseline embeddings...")
    
    # Compute UMAP for baseline
    sc.pp.neighbors(adata_baseline, use_rep=baseline_key, n_neighbors=15, random_state=42)
    sc.tl.umap(adata_baseline, random_state=42)
    
    # Create comparison figure
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    
    # Row 1: Baseline embeddings
    sc.pl.umap(adata_baseline, color='timepoint', ax=axes[0, 0], show=False, 
               title='Baseline SE-600M: Timepoint')
    sc.pl.umap(adata_baseline, color='condition', ax=axes[0, 1], show=False, 
               title='Baseline SE-600M: Condition')
    sc.pl.umap(adata_baseline, color='cell_types_simple_short', ax=axes[0, 2], show=False, 
               title='Baseline SE-600M: Cell Type', legend_loc='on data', legend_fontsize=6)
    
    # Row 2: LoRA embeddings (recompute UMAP for consistency)
    sc.pp.neighbors(adata, use_rep='X_lora_conditioned', n_neighbors=15, random_state=42)
    sc.tl.umap(adata, random_state=42)
    
    sc.pl.umap(adata, color='timepoint', ax=axes[1, 0], show=False, 
               title='LoRA Multi-Cov: Timepoint')
    sc.pl.umap(adata, color='condition', ax=axes[1, 1], show=False, 
               title='LoRA Multi-Cov: Condition')
    sc.pl.umap(adata, color='cell_types_simple_short', ax=axes[1, 2], show=False, 
               title='LoRA Multi-Cov: Cell Type', legend_loc='on data', legend_fontsize=6)
    
    plt.tight_layout()
    plt.savefig('/home/scumpia-mrl/Desktop/Sujit/Projects/state-experimentation/baseline_vs_lora_comparison.png', 
                dpi=300, bbox_inches='tight')
    plt.show()
    
    print("\n✓ Comparison plot saved to: baseline_vs_lora_comparison.png")
else:
    print("Skipping baseline comparison (baseline data not available)")

In [None]:
# Load baseline embeddings for comparison
baseline_path = "/home/scumpia-mrl/Desktop/Sujit/Projects/state-experimentation/burn_sham_baseline_embedded.h5ad"

if os.path.exists(baseline_path):
    print(f"Loading baseline embeddings from: {baseline_path}")
    adata_baseline = ad.read_h5ad(baseline_path)
    
    print(f"✓ Baseline data loaded: {adata_baseline.shape}")
    print(f"  Embeddings in .obsm: {list(adata_baseline.obsm.keys())}")
    
    # Check for baseline embedding key
    baseline_key = None
    for key in ['X_state', 'X_state_baseline', 'X_emb']:
        if key in adata_baseline.obsm:
            baseline_key = key
            print(f"\n✓ Using baseline embeddings: '{baseline_key}' {adata_baseline.obsm[baseline_key].shape}")
            break
    
    if baseline_key is None:
        print("\n⚠ No baseline embeddings found in standard keys")
else:
    print(f"⚠ Baseline embeddings file not found at: {baseline_path}")
    print("Will skip baseline comparison")

In [None]:
# Save AnnData with both baseline and LoRA embeddings
output_path = "/home/scumpia-mrl/Desktop/Sujit/Projects/state-experimentation/burn_sham_data/burn_sham_with_lora_embeddings.h5ad"

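# Copy UMAP coordinates to obsm (requires the UMAP cells above to have been run)
adata.obsm['X_umap_lora'] = adata.obsm['X_umap'].copy()

# Merge baseline UMAP into main adata only if the baseline was loaded and cells align
if ('adata_baseline' in locals() and 'X_umap' in adata_baseline.obsm
        and adata_baseline.obs_names.equals(adata.obs_names)):
    adata_baseline.obsm['X_umap_baseline'] = adata_baseline.obsm['X_umap'].copy()
    adata.obsm['X_umap_baseline'] = adata_baseline.obsm['X_umap_baseline']
else:
    print("⚠ Baseline UMAP not merged (baseline missing or cell order differs)")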

# Save
adata.write_h5ad(output_path)
print(f"✓ Saved AnnData with embeddings to: {output_path}")
print(f"\nEmbeddings in adata.obsm:")
for key in adata.obsm.keys():
    print(f"  - {key}: {adata.obsm[key].shape}")

print("\n" + "="*60)
print("ANALYSIS COMPLETE!")
print("="*60)
print("\nGenerated files:")
print("  1. lora_embeddings_umap_overview.png - Comprehensive UMAP visualization")
print("  2. baseline_vs_lora_comparison.png - Side-by-side comparison")
print("  3. lora_vs_baseline_metrics.csv - Quantitative metrics")
print("  4. burn_sham_with_lora_embeddings.h5ad - AnnData with all embeddings")
print("\nBest checkpoint: epoch=09-val_loss=0.5877.ckpt")

## 12. Save Results

In [None]:
# Create summary comparison table
comparison_df = pd.DataFrame([baseline_metrics, lora_metrics])
comparison_df = comparison_df.set_index('name')

print("\n" + "="*60)
print("SUMMARY: Baseline vs LoRA Multi-Covariate")
print("="*60)
print(comparison_df)

# Compute improvement
improvement = {
    'accuracy': ((lora_metrics['accuracy'] - baseline_metrics['accuracy']) / baseline_metrics['accuracy']) * 100,
    'f1_score': ((lora_metrics['f1_score'] - baseline_metrics['f1_score']) / baseline_metrics['f1_score']) * 100,
}

print("\n" + "="*60)
print("IMPROVEMENT (%)")
print("="*60)
for metric, value in improvement.items():
    print(f"{metric}: {value:+.2f}%")

# Save metrics
comparison_df.to_csv('/home/scumpia-mrl/Desktop/Sujit/Projects/state-experimentation/lora_vs_baseline_metrics.csv')
print("\n✓ Metrics saved to: lora_vs_baseline_metrics.csv")

In [None]:
# Compute evaluation metrics
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from scipy.spatial.distance import cdist

def evaluate_embeddings(adata, embedding_key, name):
    """Evaluate embedding quality using multiple metrics."""
    
    print(f"\n{'='*60}")
    print(f"Evaluating: {name}")
    print(f"{'='*60}")
    
    X = adata.obsm[embedding_key]
    
    # 1. Cell Type Classification (kNN)
    print("\n1. Cell Type Classification (10 random stratified 80/20 splits):")
    y = adata.obs['cell_types_simple_short'].values
    
    accuracies = []
    f1_scores = []
    for seed in range(10):
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=seed, stratify=y
        )
        knn = KNeighborsClassifier(n_neighbors=15)
        knn.fit(X_train, y_train)
        y_pred = knn.predict(X_test)
        
        accuracies.append(accuracy_score(y_test, y_pred))
        f1_scores.append(f1_score(y_test, y_pred, average='weighted'))
    
    print(f"   Accuracy: {np.mean(accuracies):.4f} ± {np.std(accuracies):.4f}")
    print(f"   F1 Score: {np.mean(f1_scores):.4f} ± {np.std(f1_scores):.4f}")
    
    # 2. Temporal Coherence
    print("\n2. Temporal Coherence:")
    timepoints = sorted(adata.obs['timepoint'].unique())
    
    for i in range(len(timepoints) - 1):
        tp1, tp2 = timepoints[i], timepoints[i+1]
        
        # Get embeddings for each timepoint
        idx1 = adata.obs['timepoint'] == tp1
        idx2 = adata.obs['timepoint'] == tp2
        
        emb1 = X[idx1]
        emb2 = X[idx2]
        
        # Compute average distance between consecutive timepoints
        distances = cdist(emb1, emb2, metric='cosine')
        avg_dist = np.mean(distances)
        
        print(f"   {tp1} → {tp2}: {avg_dist:.4f}")
    
    # 3. Condition Separation
    print("\n3. Condition Separation:")
    conditions = sorted(adata.obs['condition'].unique())
    
    for tp in timepoints:
        tp_mask = adata.obs['timepoint'] == tp
        
        burn_emb = X[(tp_mask) & (adata.obs['condition'] == 'burn')]
        sham_emb = X[(tp_mask) & (adata.obs['condition'] == 'sham')]
        
        if len(burn_emb) > 0 and len(sham_emb) > 0:
            # Average distance between burn and sham at this timepoint
            distances = cdist(burn_emb, sham_emb, metric='cosine')
            avg_dist = np.mean(distances)
            print(f"   {tp}: burn ↔ sham distance = {avg_dist:.4f}")
    
    # 4. Batch Mixing (Silhouette score)
    print("\n4. Batch Mixing (Mouse ID):")
    from sklearn.metrics import silhouette_score
    
    if 'mouse_id' in adata.obs.columns:
        # Lower silhouette score = better batch mixing
        sil_score = silhouette_score(X, adata.obs['mouse_id'], metric='cosine', sample_size=5000)
        print(f"   Silhouette Score: {sil_score:.4f} (lower = better mixing)")
    
    return {
        'accuracy': np.mean(accuracies),
        'f1_score': np.mean(f1_scores),
        'name': name
    }

# Evaluate both embeddings
baseline_metrics = evaluate_embeddings(adata, 'X_state_baseline', 'Baseline SE-600M')
lora_metrics = evaluate_embeddings(adata, 'X_lora_conditioned', 'LoRA Multi-Covariate')
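
Both the temporal-coherence and condition-separation metrics above are the same quantity: the mean cosine distance over all cross-group pairs of cells. The pure-Python sketch below spells that out on toy vectors. The helper names and the 2-D "embeddings" are illustrative only; the notebook itself uses `scipy.spatial.distance.cdist`.

```python
import math


def cosine_distance(u, v):
    """Cosine distance 1 - cos(u, v) between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / (norm_u * norm_v)


def mean_cross_distance(group_a, group_b):
    """Mean cosine distance over every (a, b) pair drawn from the two groups."""
    total = sum(cosine_distance(a, b) for a in group_a for b in group_b)
    return total / (len(group_a) * len(group_b))


# Toy 2-D "embeddings" for two groups of cells (illustrative values only).
day10 = [[1.0, 0.0], [1.0, 0.0]]
day14 = [[0.0, 1.0], [1.0, 0.0]]
print(mean_cross_distance(day10, day14))
```

Because `cdist` materializes the full pairwise matrix, memory grows with the product of the two group sizes; subsampling each group first is one way to keep the computation tractable on large datasets.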

## 11. Quantitative Evaluation Metrics

In [None]:
# Side-by-side comparison: Baseline vs LoRA UMAPs
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# Compute UMAP for baseline embeddings
print("Computing UMAP for baseline embeddings...")
adata_baseline = adata.copy()
sc.pp.neighbors(adata_baseline, use_rep='X_state_baseline', n_neighbors=15, random_state=42)
sc.tl.umap(adata_baseline, random_state=42)

# Row 1: Baseline embeddings
sc.pl.umap(adata_baseline, color='timepoint', ax=axes[0, 0], show=False, 
           title='Baseline: Timepoint')
sc.pl.umap(adata_baseline, color='condition', ax=axes[0, 1], show=False, 
           title='Baseline: Condition')
sc.pl.umap(adata_baseline, color='cell_types_simple_short', ax=axes[0, 2], show=False, 
           title='Baseline: Cell Type', legend_loc='on data', legend_fontsize=6)

# Row 2: LoRA embeddings (already computed)
sc.pl.umap(adata, color='timepoint', ax=axes[1, 0], show=False, 
           title='LoRA: Timepoint')
sc.pl.umap(adata, color='condition', ax=axes[1, 1], show=False, 
           title='LoRA: Condition')
sc.pl.umap(adata, color='cell_types_simple_short', ax=axes[1, 2], show=False, 
           title='LoRA: Cell Type', legend_loc='on data', legend_fontsize=6)

plt.tight_layout()
plt.savefig('/home/scumpia-mrl/Desktop/Sujit/Projects/state-experimentation/baseline_vs_lora_comparison.png', 
            dpi=300, bbox_inches='tight')
plt.show()

print("\n✓ Comparison plot saved to: baseline_vs_lora_comparison.png")

In [None]:
# Check if baseline embeddings exist, if not extract them
if 'X_state_baseline' not in adata.obsm:
    print("Baseline embeddings not found. Extracting baseline embeddings from pretrained SE-600M...")
    
    # Load base model for comparison
    from tqdm import tqdm
    from src.state.emb.nn.lora_covariate_model import LoRACovariateStateModel
    
    base_model = LoRACovariateStateModel(
        base_checkpoint_path=config['base_checkpoint'],
        covariate_config=config['covariates'],
        lora_config=config['lora'],
    )
    base_model.eval()
    base_model = base_model.cuda()
    
    # Extract baseline embeddings (without LoRA/covariate conditioning)
    # NOTE: the 'batch_size' config key is an assumption; 16 matches the Configuration section
    batch_size = config['training'].get('batch_size', 16)
    baseline_embeddings = []
    
    with torch.no_grad():
        for i in tqdm(range(0, adata.n_obs, batch_size), desc="Extracting baseline"):
            batch_end = min(i + batch_size, adata.n_obs)
            batch_cells = adata[i:batch_end]
            
            if 'X_norm' in batch_cells.layers:
                batch_expr = torch.tensor(batch_cells.layers['X_norm'], dtype=torch.float32).cuda()
            else:
                batch_expr = torch.tensor(batch_cells.X.toarray() if hasattr(batch_cells.X, 'toarray') else batch_cells.X, dtype=torch.float32).cuda()
            
            # Get base embeddings (no covariates)
            base_emb = base_model.model(batch_expr)
            baseline_embeddings.append(base_emb.cpu().numpy())
    
    adata.obsm['X_state_baseline'] = np.vstack(baseline_embeddings)
    print(f"✓ Baseline embeddings extracted: {adata.obsm['X_state_baseline'].shape}")
else:
    print(f"✓ Baseline embeddings already exist: {adata.obsm['X_state_baseline'].shape}")

## 10. Compare with Baseline Embeddings

## Summary & Next Steps

This notebook implemented LoRA-based multi-covariate fine-tuning of the SE-600M model.

### ✅ Completed Work

1. **LoRA Multi-Covariate Model**: Successfully trained SE-600M with LoRA adapters conditioned on timepoint and condition
2. **Embedding Extraction**: Generated covariate-conditioned embeddings from best checkpoint (epoch=09-val_loss=0.5877.ckpt)
3. **UMAP Visualization**: Created comprehensive visualizations showing embedding structure
4. **Quantitative Evaluation**: Computed metrics comparing baseline vs LoRA embeddings

### 🔍 Key Findings

**LoRA Embeddings Show Strong Temporal/Condition Signal BUT Mixed Cell Types**:

From UMAP analysis:
- ✅ **Strong separation** by timepoint (day10, day14, day19)
- ✅ **Clear separation** by condition (burn vs sham)
- ❌ **Mixed cell type clustering** (cell types not well-separated within timepoint/condition groups)

**Quantitative Metrics**:
- Cell type accuracy: ~75% (vs 96% baseline)
- Temporal coherence: Improved (lower distances between consecutive timepoints)
- Condition separation: Improved (higher distances between burn/sham)

### 📋 Decision: Use Baseline SE-600M Embeddings for State Transition Training

**Rationale**:
- **Goal**: Predict gene perturbation effects in wound healing context
- **Critical requirement**: Cell type identity preservation is essential for perturbation prediction
- **LoRA embeddings**: Optimized temporal/condition signal at the cost of cell type clustering
- **State Transition model**: Needs strong cell type structure to learn cell-type-specific responses

**Chosen Strategy (Strategy 2)**:
Use baseline SE-600M embeddings + add temporal/condition as metadata covariates in ST training

### ⏭️ Next Phase: State Transition Model Training

See [phase3a_st_data_preparation.ipynb](phase3a_st_data_preparation.ipynb) for:
1. Data preparation with baseline embeddings
2. TOML/YAML configuration creation
3. Code modifications to Arc Institute's State Transition model for timepoint embeddings

**Modified Files**:
- [src/state/tx/models/state_transition.py](src/state/tx/models/state_transition.py) (lines 200-210, 431-448)
- [src/state/tx/data/dataset/scgpt_perturbation_dataset.py](src/state/tx/data/dataset/scgpt_perturbation_dataset.py) (lines 175-212)

### 🎯 Alternative Strategies (Future Work)

If baseline + ST doesn't meet performance targets, consider:

1. **Strategy 1**: Refine LoRA embeddings with cell type preservation
   - Add cell type as explicit covariate
   - Use contrastive loss to preserve cell type clustering
   - Retrain LoRA model

2. **Strategy 3**: Fine-tune State Transition model with LoRA
   - Treat burn/sham as perturbations for ST training
   - Use LoRA adapters on ST transformer backbone
   - More direct adaptation to wound healing biology
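
Strategy 1 mentions a contrastive loss without fixing its form. One common choice, offered here only as an illustration and not as the project's decided design, is the supervised contrastive loss with cell type labels defining the positives. In the formula below, $z_i$ is the L2-normalized embedding of cell $i$, $A(i)$ is the set of all other cells in the batch, $P(i) \subseteq A(i)$ is the subset sharing cell $i$'s cell type, and $\tau$ is a temperature hyperparameter (all symbols are assumptions introduced for this sketch):

$$
\mathcal{L}_{\text{sup}} = \sum_{i} \frac{-1}{|P(i)|} \sum_{p \in P(i)} \log \frac{\exp(z_i \cdot z_p / \tau)}{\sum_{a \in A(i)} \exp(z_i \cdot z_a / \tau)}
$$

Added to the existing training objective with a weighting coefficient, a term of this kind would pull same-type cells together regardless of timepoint or condition, which is the property the LoRA embeddings currently lack.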