# Phase 2: LoRA Multi-Covariate Fine-Tuning of SE-600M

This notebook implements parameter-efficient fine-tuning of the STATE SE-600M embedding model using LoRA (Low-Rank Adaptation) adapters with multi-covariate conditioning.

## Approach

1. **Load Pretrained SE-600M**: Load the 600M parameter transformer model
2. **Freeze Base Model**: Keep all pretrained weights frozen
3. **Add LoRA Adapters**: Add low-rank trainable adapters to attention layers
4. **Add Covariate Encoders**: Create embeddings for timepoint + condition
5. **Condition Embeddings**: Combine base embeddings with covariate information
6. **Fine-Tune**: Train only LoRA + covariate parameters (~1-5% of total params)
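
Steps 2 and 3 can be illustrated with a minimal LoRA wrapper around a single linear layer. This is a generic sketch of the technique, not the code in `LoRACovariateStateModel`. The class name `LoRALinear`, the `alpha` value, and the small layer sizes are assumptions chosen for illustration.

```python
import math

import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Frozen nn.Linear plus a trainable low-rank update (illustrative only)."""

    def __init__(self, base: nn.Linear, r: int = 16, alpha: float = 32.0, dropout: float = 0.0):
        super().__init__()
        self.base = base
        # Step 2: freeze the pretrained weights.
        for p in self.base.parameters():
            p.requires_grad_(False)
        # Step 3: low-rank factors. B starts at zero, so the wrapped layer
        # initially reproduces the base layer exactly.
        self.lora_A = nn.Parameter(torch.empty(r, base.in_features))
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        self.scaling = alpha / r
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        update = self.dropout(x) @ self.lora_A.T @ self.lora_B.T
        return self.base(x) + self.scaling * update


torch.manual_seed(0)
layer = LoRALinear(nn.Linear(64, 64), r=4, alpha=8.0)  # toy sizes, not SE-600M's
x = torch.randn(2, 64)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
total = sum(p.numel() for p in layer.parameters())
print(f"trainable {trainable} of {total} parameters")
```

Because `lora_B` is initialized to zero, fine-tuning starts from the pretrained model's behavior and only gradually departs from it.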

## Key Differences from CPA Approach (Previous Incorrect Attempt)

- ✅ **LoRA Fine-Tuning**: Works in embedding space, not perturbation prediction space
- ✅ **SE-600M**: Adapts the foundation model itself through LoRA adapters, rather than training a separate downstream task model
- ✅ **Parameter Efficient**: Only trains ~1-5% of parameters vs. training entire CPA model
- ✅ **Embedding Conditioning**: Covariates directly influence cell embeddings

## Configuration

- **Base Model**: SE-600M (600M parameters, 16 transformer layers)
- **LoRA Rank**: 16 (low-rank dimension)
- **Covariates**: timepoint (3 categories) + condition (2 categories)
- **Fusion**: Concatenation + MLP (512 → 256 → 2048)
- **Training**: 2x RTX 5000 Ada, DDP, batch size 16
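
As a rough sanity check on the trainable-parameter budget, the arithmetic below estimates the adapter size from the numbers in this notebook. It assumes square 2048 × 2048 Q and V projections with separate weights, and it counts the conditioning projection from the architecture summary (Linear 4096 → 2048 with bias, plus LayerNorm). The covariate encoder is left out because its input dimensions are not stated here, so treat the result as a lower bound.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """One LoRA adapter: A is (rank x d_in), B is (d_out x rank)."""
    return rank * (d_in + d_out)


HIDDEN_DIM = 2048          # hidden size quoted in the architecture summary
NUM_LAYERS = 16
RANK = 16
TARGETS_PER_LAYER = 2      # Q and V projections (assumed unfused and square)
BASE_PARAMS = 600_000_000  # nominal size of SE-600M

lora_total = lora_param_count(HIDDEN_DIM, HIDDEN_DIM, RANK) * TARGETS_PER_LAYER * NUM_LAYERS

# Conditioning projection: Linear(4096 -> 2048) weights + bias, then LayerNorm(2048).
projection_total = (2 * HIDDEN_DIM * HIDDEN_DIM + HIDDEN_DIM) + 2 * HIDDEN_DIM

trainable_estimate = lora_total + projection_total
fraction = trainable_estimate / BASE_PARAMS

print(f"LoRA adapters:           {lora_total:,}")
print(f"Conditioning projection: {projection_total:,}")
print(f"Estimated trainable:     {trainable_estimate:,} ({fraction:.2%} of base)")
```

Under these assumptions the adapters alone are well below 1% of the base model; most of the trainable budget comes from the conditioning projection.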

## 1. Environment Setup

In [None]:
import sys
import os
from pathlib import Path
import yaml

import torch
import anndata as ad
import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Check GPU availability
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
        print(f"    Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB")

## 2. Load and Validate Data

In [None]:
# Load burn/sham dataset
data_path = "/home/scumpia-mrl/Desktop/Sujit/Projects/state-experimentation/burn_sham_data/burn_sham_processed.h5ad"
adata = ad.read_h5ad(data_path)

print(f"Dataset shape: {adata.shape[0]} cells x {adata.shape[1]} genes")
print(f"\nObservations (metadata columns): {adata.obs.columns.tolist()}")
print(f"\nVariables (gene info): {adata.var.columns.tolist()}")

In [None]:
# Validate covariate columns
required_cols = ['condition', 'timepoint', 'cell_types_simple_short', 'mouse_id']

for col in required_cols:
    if col in adata.obs.columns:
        unique_vals = adata.obs[col].unique()
        print(f"✓ '{col}': {len(unique_vals)} unique values")
        print(f"  Values: {unique_vals}")
        print(f"  Distribution:\n{adata.obs[col].value_counts()}\n")
    else:
        print(f"✗ '{col}' NOT FOUND")
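
Embedding layers need integer indices rather than string labels, so the covariate columns have to be encoded before training. The sketch below shows one way to build a stable label-to-index mapping using only the standard library. The label values (`burn`, `sham`, `t1` to `t3`) and the helper names are hypothetical; the real values come from `adata.obs`.

```python
def build_vocab(labels):
    """Map each distinct label to an integer index, sorted for determinism."""
    return {label: idx for idx, label in enumerate(sorted(set(labels)))}


def encode(labels, vocab):
    """Convert labels to indices, failing loudly on unseen labels."""
    unknown = sorted(set(labels) - set(vocab))
    if unknown:
        raise ValueError(f"Labels not in vocabulary: {unknown}")
    return [vocab[label] for label in labels]


# Hypothetical per-cell labels standing in for adata.obs columns.
conditions = ["burn", "sham", "burn", "sham"]
timepoints = ["t1", "t3", "t2", "t1"]

condition_vocab = build_vocab(conditions)
timepoint_vocab = build_vocab(timepoints)
condition_ids = encode(conditions, condition_vocab)
timepoint_ids = encode(timepoints, timepoint_vocab)

print(condition_vocab, condition_ids)
print(timepoint_vocab, timepoint_ids)
```

Building the vocabulary once and saving it alongside the checkpoint keeps the index assignment consistent between training and later embedding extraction.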

In [None]:
# Visualize data distribution
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Condition distribution
adata.obs['condition'].value_counts().plot(kind='bar', ax=axes[0])
axes[0].set_title('Condition Distribution')
axes[0].set_xlabel('Condition')
axes[0].set_ylabel('Number of Cells')

# Timepoint distribution
adata.obs['timepoint'].value_counts().plot(kind='bar', ax=axes[1])
axes[1].set_title('Timepoint Distribution')
axes[1].set_xlabel('Timepoint')
axes[1].set_ylabel('Number of Cells')

# Cell type distribution
cell_type_counts = adata.obs['cell_types_simple_short'].value_counts().head(10)
cell_type_counts.plot(kind='barh', ax=axes[2])
axes[2].set_title('Top 10 Cell Types')
axes[2].set_xlabel('Number of Cells')

plt.tight_layout()
plt.show()

## 3. Load Configuration

In [None]:
# Load LoRA config
config_path = "configs/lora_multicov_config.yaml"
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

print("Configuration:")
print(yaml.dump(config, default_flow_style=False))
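
The cells that follow read a specific set of keys from `config`, so it can help to check for them right after loading. The sketch below lists the keys this notebook accesses and reports any that are missing. The example values (checkpoint path, `lora_alpha`, dropout, module names, embedding sizes, learning rate, warmup steps) are placeholders, not the contents of `lora_multicov_config.yaml`.

```python
# Dotted paths for every key the notebook reads from the config.
REQUIRED_PATHS = [
    "base_checkpoint",
    "lora.r",
    "lora.lora_alpha",
    "lora.lora_dropout",
    "lora.target_modules",
    "covariates.covariates",
    "covariates.combination.mlp_hidden_dims",
    "covariates.combination.mlp_output_dim",
    "training.learning_rate",
    "training.warmup_steps",
]


def missing_config_keys(config, required=REQUIRED_PATHS):
    """Return the dotted paths that cannot be resolved in a nested dict."""
    missing = []
    for path in required:
        node = config
        for key in path.split("."):
            if not isinstance(node, dict) or key not in node:
                missing.append(path)
                break
            node = node[key]
    return missing


# Placeholder values for illustration only.
example_config = {
    "base_checkpoint": "path/to/checkpoint.ckpt",
    "lora": {"r": 16, "lora_alpha": 32, "lora_dropout": 0.1,
             "target_modules": ["q_proj", "v_proj"]},
    "covariates": {
        "covariates": [
            {"name": "timepoint", "type": "categorical", "num_categories": 3, "embed_dim": 256},
            {"name": "condition", "type": "categorical", "num_categories": 2, "embed_dim": 256},
        ],
        "combination": {"mlp_hidden_dims": [512, 256], "mlp_output_dim": 2048},
    },
    "training": {"learning_rate": 1e-4, "warmup_steps": 500},
}

print("Missing keys:", missing_config_keys(example_config))
```

Calling `missing_config_keys(config)` on the loaded YAML and raising if the list is non-empty would surface a malformed config before the 600M-parameter checkpoint is loaded.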

## 4. Initialize LoRA Model

In [None]:
from src.state.emb.nn.lora_covariate_model import LoRACovariateStateModel

# Initialize model
print("Loading LoRA model...")
model = LoRACovariateStateModel(
    base_checkpoint_path=config['base_checkpoint'],
    covariate_config=config['covariates'],
    lora_config=config['lora'],
    learning_rate=config['training']['learning_rate'],
    warmup_steps=config['training']['warmup_steps'],
)

# Print trainable parameters
print("\nTrainable Parameters:")
model.print_trainable_parameters()
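
The model takes both a `learning_rate` and a `warmup_steps` argument. The notebook does not say what shape the warmup takes, so the function below assumes a common choice: a linear ramp up to the base rate, then a constant rate. The name `warmup_lr` is hypothetical.

```python
def warmup_lr(step: int, base_lr: float, warmup_steps: int) -> float:
    """Linear warmup to base_lr over warmup_steps, then constant (assumed shape)."""
    if warmup_steps <= 0:
        return base_lr
    return base_lr * min(1.0, (step + 1) / warmup_steps)


# Illustrative values, not taken from the config file.
schedule = [warmup_lr(step, base_lr=1e-4, warmup_steps=100) for step in (0, 49, 99, 500)]
print(schedule)
```

A warmup phase limits the size of early updates to the newly initialized adapter and covariate parameters, which start from random or zero values while the base model is already trained.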

## 5. Model Architecture Summary

In [None]:
# Print model architecture
print("\n" + "="*80)
print("LoRA Multi-Covariate Model Architecture")
print("="*80)

print("\n1. BASE MODEL (Frozen):")
print("   - SE-600M Transformer: 16 layers, 16 heads, 2048 hidden dim")
print("   - Token Encoder: Linear(5120 → 2048) + LayerNorm + SiLU")
print("   - Transformer Encoder: 16x FlashTransformerEncoderLayer")
print("   - Decoder: SkipBlock + Linear(2048 → 2048)")
print("   - Status: ❄️ FROZEN (all 600M parameters)")

print("\n2. LoRA ADAPTERS (Trainable):")
print("   - Target: Attention Q, V projections")
print(f"   - Rank: {config['lora']['r']}")
print(f"   - Alpha: {config['lora']['lora_alpha']}")
print(f"   - Dropout: {config['lora']['lora_dropout']}")
print(f"   - Applied to: {len(config['lora']['target_modules'])} projection types × 16 layers")

print("\n3. COVARIATE ENCODER (Trainable):")
for cov in config['covariates']['covariates']:
    print(f"   - {cov['name']}: {cov['type']} ({cov.get('num_categories', 'N/A')} categories) → {cov.get('embed_dim', 'N/A')} dim")
print(f"   - Combination MLP: {config['covariates']['combination']['mlp_hidden_dims']} → {config['covariates']['combination']['mlp_output_dim']}")

print("\n4. CONDITIONING PROJECTION (Trainable):")
print("   - Input: Concat(base_embedding, covariate_embedding) = 4096 dim")
print("   - Output: Conditioned embedding = 2048 dim")
print("   - Architecture: Linear + LayerNorm + SiLU")

print("\n" + "="*80)
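
Parts 3 and 4 of the summary can be sketched as a small PyTorch module: per-covariate embeddings are concatenated and passed through an MLP, then fused with the base embedding by a Linear + LayerNorm + SiLU projection. The class name `CovariateConditioner`, the MLP activation, and the toy dimensions are assumptions; the module inside `LoRACovariateStateModel` may differ.

```python
import torch
import torch.nn as nn


class CovariateConditioner(nn.Module):
    """Illustrative covariate encoder plus conditioning projection."""

    def __init__(self, num_categories, embed_dim=256, hidden_dim=256, out_dim=2048):
        super().__init__()
        # One embedding table per categorical covariate.
        self.embeddings = nn.ModuleList(
            [nn.Embedding(n, embed_dim) for n in num_categories]
        )
        # Combination MLP: concatenated covariate embeddings -> out_dim.
        self.combine = nn.Sequential(
            nn.Linear(embed_dim * len(num_categories), hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, out_dim),
        )
        # Conditioning projection: Concat(base, covariate) -> out_dim.
        self.project = nn.Sequential(
            nn.Linear(2 * out_dim, out_dim),
            nn.LayerNorm(out_dim),
            nn.SiLU(),
        )

    def forward(self, base_embedding, covariate_ids):
        # covariate_ids: (batch, num_covariates) integer tensor.
        parts = [emb(covariate_ids[:, i]) for i, emb in enumerate(self.embeddings)]
        covariate_embedding = self.combine(torch.cat(parts, dim=-1))
        fused = torch.cat([base_embedding, covariate_embedding], dim=-1)
        return self.project(fused)


torch.manual_seed(0)
# Toy sizes for a quick check: 3 timepoints, 2 conditions, 32-dim embeddings.
conditioner = CovariateConditioner([3, 2], embed_dim=8, hidden_dim=16, out_dim=32)
base_embedding = torch.randn(4, 32)
covariate_ids = torch.tensor([[0, 0], [1, 1], [2, 0], [0, 1]])
conditioned = conditioner(base_embedding, covariate_ids)
print(conditioned.shape)
```

With the default `out_dim=2048`, the concatenated input to the projection is 4096-dimensional and the output is 2048-dimensional, matching the sizes printed in the summary.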

## 6. Training

In [None]:
# NOTE: Training is resource-intensive and should be run via the training script
# This notebook demonstrates the setup and validation

print("To start training, run:")
print("\n" + "="*80)
print("python train_lora_multicov.py --config configs/lora_multicov_config.yaml")
print("="*80)
print("\nExpected training time: 4-6 hours on 2x RTX 5000 Ada")
print("\nMonitor training with TensorBoard:")
print("tensorboard --logdir=/home/scumpia-mrl/state_models/burn_sham_lora_multicov")

## 7. Load Trained Model (After Training)

In [None]:
# Load trained checkpoint
# checkpoint_path = "/home/scumpia-mrl/state_models/burn_sham_lora_multicov/checkpoints/last.ckpt"
# trained_model = LoRACovariateStateModel.load_from_checkpoint(checkpoint_path)
# trained_model.eval()
# print("Trained model loaded successfully!")

## 8. Extract Covariate-Conditioned Embeddings

In [None]:
# TODO: Implement embedding extraction
# This will:
# 1. Load trained model
# 2. Process each cell with its covariates
# 3. Extract conditioned embeddings
# 4. Save to h5ad file

print("Embedding extraction to be implemented after training")

## 9. Evaluation & Comparison

In [None]:
# Compare baseline vs LoRA embeddings
# - Cell type classification accuracy (kNN)
# - Batch correction (silhouette scores)
# - Temporal coherence
# - UMAP visualization

print("Evaluation to be implemented after training")
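
For the kNN cell type classification metric, a leave-one-out accuracy can be computed directly from embeddings and labels. The sketch below is a dependency-free illustration on toy 2-D points; the function name `knn_loo_accuracy` and the labels are hypothetical. It runs in quadratic time, so a library implementation would be more practical on the full dataset.

```python
import math
from collections import Counter


def knn_loo_accuracy(embeddings, labels, k=3):
    """Leave-one-out kNN accuracy: predict each point from its k nearest others."""
    correct = 0
    for i, query in enumerate(embeddings):
        # Distances to every other point, nearest first.
        neighbors = sorted(
            (math.dist(query, other), labels[j])
            for j, other in enumerate(embeddings)
            if j != i
        )[:k]
        predicted = Counter(label for _, label in neighbors).most_common(1)[0][0]
        correct += predicted == labels[i]
    return correct / len(embeddings)


# Toy embeddings: two well-separated clusters with hypothetical cell type labels.
toy_embeddings = [(0, 0), (0, 1), (1, 0), (10, 10), (10, 11), (11, 10)]
toy_labels = ["A", "A", "A", "B", "B", "B"]
accuracy = knn_loo_accuracy(toy_embeddings, toy_labels, k=3)
print(f"Leave-one-out kNN accuracy: {accuracy:.2f}")
```

Running the same function on baseline and LoRA-conditioned embeddings with identical labels and `k` gives a like-for-like comparison of how well each embedding space separates cell types.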

## Summary

This notebook sets up LoRA-based multi-covariate fine-tuning of the SE-600M model.

### Key Differences from Previous CPA Approach:

| Aspect | CPA Approach (❌ Wrong) | LoRA Approach (✅ Correct) |
|--------|------------------------|---------------------------|
| **Model** | Trains downstream CPA task model | Fine-tunes SE-600M foundation model |
| **Space** | Perturbation prediction space | Embedding space |
| **Parameters** | Trains all CPA parameters (~20M) | Trains only LoRA + covariates (~1-5%) |
| **Output** | Predicts perturbed gene expression | Produces covariate-conditioned embeddings |
| **Use Case** | Specific to perturbation prediction | General-purpose embeddings for any downstream task |

### Next Steps:

1. Run training script: `python train_lora_multicov.py --config configs/lora_multicov_config.yaml`
2. Monitor with TensorBoard
3. Extract embeddings from trained model
4. Evaluate and compare with baseline
5. Use conditioned embeddings for downstream analysis