# Unseen BoneMarrow - Leave-One-Out Cross-Validation

This notebook demonstrates leave-one-out (LOO) cross-validation for evaluating LSD's ability to predict cell fates for unseen cell populations.

**Experiment:**
1. Train LSD on all clusters except one
2. Test on the full dataset including the held-out cluster
3. Repeat for each cluster
4. Evaluate fate prediction accuracy for held-out cells (the fraction whose predicted fate equals the held-out cluster's label)

**Data Requirements:**
- BoneMarrow dataset: `../../data/BoneMarrow/normalized_before_low.h5ad`

## Setup

In [None]:
import os
from pathlib import Path
import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch

# Import sclsd components
from sclsd import LSD, LSDConfig, set_all_seeds, clear_pyro_state

# Seed once up front; the same SEED is reused for every fold further down
SEED = 42
set_all_seeds(SEED)

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Folder that receives the per-fold checkpoints, figures and summary file
model_dir = Path("./loo_models")
model_dir.mkdir(parents=True, exist_ok=True)

## Load Data

In [None]:
# Read the BoneMarrow AnnData file from disk
data_path = "../../data/BoneMarrow/normalized_before_low.h5ad"
main_adata = sc.read_h5ad(data_path)

# Report the dataset dimensions, then the number of cells per cluster
n_obs, n_vars = main_adata.n_obs, main_adata.n_vars
print(f"Loaded {n_obs:,} cells x {n_vars:,} genes")
print(f"\nCluster distribution:")
print(main_adata.obs['clusters'].value_counts())

## Configure Leave-One-Out Experiment

In [None]:
# Clusters withheld one at a time; each entry defines one fold
CLUSTERS_TO_LEAVE_OUT = [
    'DCs',
    'Ery_1',
    'HSC_1',
    'Mono_1',
    'Precursors',
    'HSC_2',
    'Mono_2',
    'Ery_2',
]

# Key of the prior pseudotime, handed to `set_prior_transition` in the training loop
PRIOR_TIME_KEY = "palantir_pseudotime"

# List every fold together with the number of cells it withholds
print(f"Will perform LOO cross-validation for {len(CLUSTERS_TO_LEAVE_OUT)} clusters:")
cluster_labels = main_adata.obs['clusters']
for cluster in CLUSTERS_TO_LEAVE_OUT:
    n_cells = cluster_labels.eq(cluster).sum()
    print(f"  - {cluster}: {n_cells} cells")

## Leave-One-Out Training Loop

In [None]:
# One LSD configuration shared by all folds
cfg = LSDConfig()
cfg.optimizer.kl_schedule.af = 2

# Random-walk settings; the walk sampler is seeded with the notebook SEED
cfg.walks.batch_size = 256
cfg.walks.path_len = 10
cfg.walks.num_walks = 4096
cfg.walks.random_state = SEED

# One row per cell of the full dataset; each fold appends one fate column
results_df = pd.DataFrame(index=main_adata.obs_names)

# CSV that the training loop rewrites after every fold
out_path = Path("loo_cell_fates.csv")

In [None]:
# Run LOO cross-validation: one complete train -> predict cycle per held-out
# cluster. Each fold adds a `fate_excluding_<cluster>` column to `results_df`
# and rewrites the CSV, so finished folds survive an interrupted run.
for excluded in CLUSTERS_TO_LEAVE_OUT:
    print(f"\n{'='*60}")
    print(f"Training with cluster EXCLUDED: {excluded}")
    print(f"{'='*60}")

    # Boolean mask over main_adata's cells marking the held-out cluster.
    # Computed once and reused for the training split and the final report.
    held_out_mask = main_adata.obs['clusters'] == excluded

    # Create training set (exclude target cluster)
    train_mask = ~held_out_mask
    train_adata = main_adata[train_mask].copy()

    print(f"Training on {train_adata.n_obs} cells (excluded {(~train_mask).sum()} {excluded} cells)")

    # Recompute PCA and the KNN graph on the training cells only
    sc.pp.pca(train_adata)
    sc.pp.neighbors(train_adata)

    # Re-seed and reset Pyro state so every fold starts from the same conditions
    set_all_seeds(SEED)
    clear_pyro_state()

    lsd_model = LSD(
        adata=train_adata,
        config=cfg,
        device=device,
        lib_size_key="librarysize",
        raw_count_key="raw"
    )

    # Configure prior and random walks
    lsd_model.set_prior_transition(prior_time_key=PRIOR_TIME_KEY)
    lsd_model.prepare_walks()

    # Train model; checkpoints go to a per-fold sub-directory of model_dir
    clear_pyro_state()
    save_dir = model_dir / f"{excluded}_excluded"
    lsd_model.train(
        num_epochs=200,
        save_dir=str(save_dir),
        save_interval=50
    )

    # Test on the full dataset (including the held-out cluster), restricted
    # to the genes present in the training object
    test_adata = main_adata[:, train_adata.var_names].copy()
    lsd_model.set_adata(test_adata)
    final_adata = lsd_model.get_adata()

    # Predict cell fates
    dyn_adata = lsd_model.get_cell_fates(
        final_adata,
        time_range=10,
        batch_size=2700,
        cluster_key="clusters"
    )

    # Store fate predictions, aligned to the full dataset's cell order
    fate_col = f"fate_excluding_{excluded}"
    fate_series = pd.Series(
        dyn_adata.obs["fate"].astype(str).values,
        index=dyn_adata.obs_names,
        name=fate_col
    )
    results_df[fate_col] = fate_series.reindex(results_df.index)

    # Save intermediate results
    results_df.to_csv(out_path)

    # Report results for the held-out cluster.
    # Read from results_df rather than fate_series: results_df shares
    # main_adata's cell order, so the positional mask is always valid, whereas
    # fate_series carries dyn_adata's index and a label-based boolean lookup
    # fails if the two indexes differ.
    held_out_fates = results_df.loc[held_out_mask.to_numpy(), fate_col]
    print(f"\nFate predictions for held-out {excluded} cells:")
    print(held_out_fates.value_counts())

print(f"\n{'='*60}")
print("LOO cross-validation complete!")
print(f"Results saved to: {out_path}")
print(f"{'='*60}")

## Analyze LOO Results

In [None]:
# Load results.
# Every column is read as text: cell names and fate labels are identifiers,
# and letting read_csv infer dtypes would turn purely numeric ones into
# numbers, breaking the label-based `.loc` lookups and string comparisons
# performed in the analysis cells.
results_df = pd.read_csv(out_path, dtype=str)

# The first CSV column is the cell index written by `to_csv`; restore it.
first_col = results_df.columns[0]
results_df = results_df.set_index(first_col)
if first_col.startswith("Unnamed:"):
    # An index saved without a name gets a placeholder header from pandas
    results_df.index.name = None

print(f"Results shape: {results_df.shape}")
print(f"\nColumns:")
for col in results_df.columns:
    print(f"  - {col}")

In [None]:
# Per-fold accuracy: share of held-out cells whose predicted fate equals the
# label of the cluster that was withheld during training
accuracy_results = []

for excluded in CLUSTERS_TO_LEAVE_OUT:
    fate_col = f"fate_excluding_{excluded}"
    if fate_col not in results_df.columns:
        # No predictions stored for this fold
        continue

    # Names of the cells that belong to the withheld cluster
    held_out_mask = main_adata.obs['clusters'].eq(excluded)
    held_out_indices = main_adata.obs_names[held_out_mask]

    # Fates predicted for exactly those cells
    predictions = results_df.loc[held_out_indices, fate_col]

    # A prediction counts as correct when it names the withheld cluster itself
    total = len(predictions)
    correct = predictions.eq(excluded).sum()
    if total > 0:
        accuracy = correct / total
    else:
        accuracy = 0

    accuracy_results.append(
        dict(
            excluded_cluster=excluded,
            n_cells=total,
            correct_predictions=correct,
            accuracy=accuracy,
        )
    )

    print(f"{excluded}: {accuracy:.1%} accuracy ({correct}/{total} correct)")

accuracy_df = pd.DataFrame(accuracy_results)

In [None]:
# Create summary heatmap of fate predictions for held-out clusters.
# Rows are the withheld clusters, columns the predicted fates, and each value
# is the fraction of that cluster's held-out cells assigned to the fate.
confusion_data = []

for excluded in CLUSTERS_TO_LEAVE_OUT:
    fate_col = f"fate_excluding_{excluded}"
    if fate_col not in results_df.columns:
        # No predictions stored for this fold
        continue

    held_out_mask = main_adata.obs['clusters'] == excluded
    held_out_indices = main_adata.obs_names[held_out_mask]
    predictions = results_df.loc[held_out_indices, fate_col]

    # value_counts tallies every predicted fate in a single pass and skips
    # missing predictions, so no zero-count NaN entries are produced.
    # Proportions stay relative to all held-out cells of the fold.
    n_held_out = len(predictions)
    for pred_cluster, count in predictions.value_counts().items():
        confusion_data.append({
            'held_out': excluded,
            'predicted': pred_cluster,
            'count': count,
            'proportion': count / n_held_out
        })

confusion_df = pd.DataFrame(confusion_data)

# Pivot for heatmap (skipped entirely when no fold produced predictions)
if len(confusion_df) > 0:
    confusion_matrix = confusion_df.pivot_table(
        index='held_out',
        columns='predicted',
        values='proportion',
        fill_value=0
    )

    plt.figure(figsize=(12, 8))
    sns.heatmap(
        confusion_matrix,
        annot=True,
        fmt='.2f',
        cmap='Blues',
        linewidths=0.5
    )
    plt.title('LOO Fate Prediction: Held-out Cluster vs Predicted Fate', fontsize=14)
    plt.xlabel('Predicted Fate', fontsize=12)
    plt.ylabel('Held-out Cluster', fontsize=12)
    plt.tight_layout()
    plt.savefig(model_dir / 'loo_confusion_matrix.png', dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
# Plot accuracy by cluster: one bar per fold plus a dashed line at the mean.
# Skipped when no fold produced an accuracy value.
if not accuracy_df.empty:
    fig, ax = plt.subplots(figsize=(10, 6))

    ax.bar(
        accuracy_df['excluded_cluster'],
        accuracy_df['accuracy'],
        color='steelblue',
        edgecolor='black'
    )

    # Compute the mean once and use it for both the line and its legend label
    mean_accuracy = accuracy_df['accuracy'].mean()
    ax.axhline(y=mean_accuracy, color='red', linestyle='--',
               label=f'Mean: {mean_accuracy:.1%}')

    ax.set_xlabel('Held-out Cluster', fontsize=12)
    ax.set_ylabel('Accuracy', fontsize=12)
    ax.set_title('LOO Cross-Validation: Fate Prediction Accuracy', fontsize=14)
    ax.set_ylim(0, 1)
    ax.legend()

    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(model_dir / 'loo_accuracy_by_cluster.png', dpi=150, bbox_inches='tight')
    plt.show()

## Save Summary Results

In [None]:
import json

# Collect the experiment summary. The accuracy fields fall back to None / []
# when no fold produced results.
has_results = len(accuracy_df) > 0
summary = {
    "experiment": "Leave-One-Out Cross-Validation",
    "dataset": "BoneMarrow",
    "n_cells": int(main_adata.n_obs),
    "n_folds": len(CLUSTERS_TO_LEAVE_OUT),
    "clusters_tested": CLUSTERS_TO_LEAVE_OUT,
    "mean_accuracy": float(accuracy_df['accuracy'].mean()) if has_results else None,
    "per_cluster_accuracy": accuracy_df.to_dict('records') if has_results else []
}

with open(model_dir / "loo_summary.json", "w") as f:
    json.dump(summary, f, indent=2)

print(f"Summary saved to {model_dir / 'loo_summary.json'}")
print(f"\n{'='*60}")
print("LOO Cross-Validation Complete!")
# Compare against None explicitly: a mean accuracy of exactly 0.0 is a valid
# result and must be reported, not shown as "N/A".
mean_accuracy = summary['mean_accuracy']
print(f"Mean accuracy across all folds: {mean_accuracy:.1%}" if mean_accuracy is not None else "N/A")
print(f"{'='*60}")