# Pancreas Dataset - LSD Training

This notebook demonstrates training the Latent State Dynamics (LSD) model on the Pancreas endocrinogenesis dataset.

**Dataset Information:**
- ~16,822 cells
- ~5,000 highly variable genes
- Complex multi-level phylogeny from ductal cells to endocrine cell types
- Expected CBDir (Cross-Boundary Direction Correctness): ~0.487

## Setup

In [None]:
import os
import scanpy as sc
import numpy as np
import torch

# Import sclsd components
from sclsd import LSD, LSDConfig, set_all_seeds, clear_pyro_state

# Set random seed for reproducibility
SEED = 42
clear_pyro_state()
set_all_seeds(SEED)

# Configure device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Load Data

The Pancreas dataset should be preprocessed with:
- Log-normalized counts in `adata.X`
- Raw counts in `adata.layers['raw']`
- Library size in `adata.obs['librarysize']`
- Prior pseudotime in `adata.obs['prior_time']`
- Cell cluster annotations in `adata.obs['clusters']`
- A precomputed neighbors graph (for example from `sc.pp.neighbors`)
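
Before loading a large file into the model, it can help to confirm that these fields are present. The sketch below is a hypothetical helper (`missing_fields` is not part of sclsd) that only relies on `in` checks, so it works on an AnnData object as well as on the small stand-in used here. The neighbors check assumes scanpy's usual storage locations, `adata.uns['neighbors']` and `adata.obsp['connectivities']`.

```python
from types import SimpleNamespace

# Field names come from the list above. The neighbors keys follow scanpy's
# convention and are an assumption about how the graph was stored.
REQUIRED_OBS = ("librarysize", "prior_time", "clusters")
REQUIRED_LAYERS = ("raw",)


def missing_fields(adata):
    """Return readable names of the expected inputs that are absent (hypothetical helper)."""
    missing = []
    for key in REQUIRED_OBS:
        if key not in adata.obs:
            missing.append(f"obs['{key}']")
    for key in REQUIRED_LAYERS:
        if key not in adata.layers:
            missing.append(f"layers['{key}']")
    if "neighbors" not in adata.uns or "connectivities" not in adata.obsp:
        missing.append("neighbors graph (run sc.pp.neighbors)")
    return missing


# Stand-in object with the same attribute names as AnnData, used only for the demo.
demo = SimpleNamespace(
    obs={"librarysize": [], "clusters": []},
    layers={"raw": None},
    uns={},
    obsp={},
)
print(missing_fields(demo))
```

With a real dataset, call `missing_fields(adata)` right after `sc.read(...)`; an empty list means every expected field was found.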

In [None]:
# Path to preprocessed data
# Update this path to your data location
data_path = "../../data/Pancreas/adata_prior_time.h5ad"

# Load data
adata = sc.read(data_path)
print(f"Loaded {adata.n_obs} cells x {adata.n_vars} genes")
print(f"Clusters: {adata.obs['clusters'].unique().tolist()}")

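## Visualize Data

These plots assume a UMAP embedding is already stored in `adata.obsm['X_umap']`; if it is missing, compute one with `sc.tl.umap(adata)` first.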

In [None]:
# Plot UMAP colored by clusters
sc.pl.umap(adata, color='clusters', title='Pancreas - Cell Types')

In [None]:
# Plot UMAP colored by prior pseudotime
sc.pl.umap(adata, color='prior_time', title='Pancreas - Prior Pseudotime')

## Configure LSD Model

Key hyperparameters for Pancreas:
- `V_coeff = 5e-3`: Potential regularization coefficient
- `kl_schedule.af = 3`: KL annealing factor
- `lr = 2e-3`: Learning rate
- `batch_size = 256`: Batch size for training
- `path_len = 50`: Length of random walks

In [None]:
# Create configuration
cfg = LSDConfig()

# Model parameters
cfg.model.V_coeff = 5e-3
cfg.model.z_dim = 10  # Latent cell state dimension
cfg.model.B_dim = 2   # Differentiation state dimension

# Optimizer parameters
cfg.optimizer.kl_schedule.af = 3
cfg.optimizer.adam.lr = 2e-3

# Walk parameters
cfg.walks.batch_size = 256
cfg.walks.path_len = 50
cfg.walks.num_walks = 4096
cfg.walks.random_state = SEED

print("Configuration:")
print(f"  V_coeff: {cfg.model.V_coeff}")
print(f"  z_dim: {cfg.model.z_dim}")
print(f"  B_dim: {cfg.model.B_dim}")
print(f"  KL annealing factor: {cfg.optimizer.kl_schedule.af}")
print(f"  Learning rate: {cfg.optimizer.adam.lr}")
print(f"  Batch size: {cfg.walks.batch_size}")
print(f"  Path length: {cfg.walks.path_len}")
print(f"  Number of walks: {cfg.walks.num_walks}")

## Define Phylogeny

The pancreas phylogeny represents the differentiation hierarchy from proliferating ductal cells to mature endocrine cells:

```
Prlf. Ductal -> Ductal -> Ngn3 low -> Ngn3 high -> Fev+ -> Fev+ Alpha -> Alpha
                                                      |-> Fev+ Beta -> Beta
                                               |-> Fev+ Delta -> Delta
                                               |-> Epsilon
```

In [None]:
# Define phylogenetic tree
phylogeny = {
    "Prlf. Ductal": ["Ductal"],
    "Ductal": ["Ngn3 low"],
    "Ngn3 low": ["Ngn3 high"],
    "Ngn3 high": ["Fev+", "Fev+ Delta", "Epsilon"],
    "Fev+": ["Fev+ Alpha", "Fev+ Beta"],
    "Fev+ Delta": ["Delta"],
    "Fev+ Alpha": ["Alpha"],
    "Fev+ Beta": ["Beta"],
    "Epsilon": [],
    "Alpha": [],
    "Beta": [],
    "Delta": [],
}

print("Phylogeny defined with", len(phylogeny), "nodes")
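
A parent-to-children dictionary like this one is easy to mistype, so a quick structural check can be worthwhile before passing it to the model. The sketch below uses only the standard library; `find_roots` and `lineages` are hypothetical helpers, not sclsd functions. It reports the root, any child that was never declared as a key, and every root-to-leaf lineage. It assumes the mapping contains no cycles.

```python
def find_roots(tree):
    """Return nodes that never appear as a child of another node."""
    children = {child for kids in tree.values() for child in kids}
    return [node for node in tree if node not in children]


def lineages(tree, node, prefix=()):
    """Yield every path from `node` down to a leaf as a tuple of cluster names."""
    path = prefix + (node,)
    kids = tree.get(node, [])
    if not kids:
        yield path
    for kid in kids:
        yield from lineages(tree, kid, path)


# Same mapping as in the cell above, repeated so the sketch runs on its own.
phylogeny = {
    "Prlf. Ductal": ["Ductal"],
    "Ductal": ["Ngn3 low"],
    "Ngn3 low": ["Ngn3 high"],
    "Ngn3 high": ["Fev+", "Fev+ Delta", "Epsilon"],
    "Fev+": ["Fev+ Alpha", "Fev+ Beta"],
    "Fev+ Delta": ["Delta"],
    "Fev+ Alpha": ["Alpha"],
    "Fev+ Beta": ["Beta"],
    "Epsilon": [],
    "Alpha": [],
    "Beta": [],
    "Delta": [],
}

# Children that are referenced but never declared as keys point to typos.
undeclared = {c for kids in phylogeny.values() for c in kids} - set(phylogeny)
roots = find_roots(phylogeny)
paths = [p for root in roots for p in lineages(phylogeny, root)]

print("roots:", roots)
print("undeclared children:", sorted(undeclared))
for p in paths:
    print(" -> ".join(p))
```

A well-formed tree has exactly one root and no undeclared children. It is also worth comparing the keys against `adata.obs['clusters'].unique()`, since a label that differs only in spacing or case would not match any cells.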

## Initialize LSD Model

In [None]:
# Initialize LSD
lsd = LSD(adata, cfg, device=device)

# Set phylogeny constraints
lsd.set_phylogeny(phylogeny, cluster_key="clusters")

# Set prior transition matrix from pseudotime and phylogeny
lsd.set_prior_transition(prior_time_key="prior_time")

print("LSD model initialized")
print(f"  Cells after phylogeny filtering: {lsd.adata.n_obs}")
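
How `set_prior_transition` builds its matrix is not shown in this notebook. As an illustration of the general idea only, the sketch below shows one common way to turn a neighbors graph and a pseudotime into a directed transition matrix: drop edges that point backward in pseudotime, then row-normalize. The function `forward_transition_matrix` is hypothetical, and the sketch ignores the phylogeny constraint that the real method also uses.

```python
import numpy as np


def forward_transition_matrix(connectivities, pseudotime):
    """Row-stochastic matrix that keeps only edges going forward in pseudotime.

    Illustrative construction, not the sclsd implementation.
    """
    A = np.asarray(connectivities, dtype=float)
    t = np.asarray(pseudotime, dtype=float)

    # forward[i, j] is True when cell j is not earlier than cell i.
    forward = t[None, :] >= t[:, None]
    W = np.where(forward, A, 0.0)
    np.fill_diagonal(W, 0.0)

    # Cells without any forward neighbor become absorbing states (self-loop).
    dead_ends = np.flatnonzero(W.sum(axis=1) == 0)
    W[dead_ends, dead_ends] = 1.0

    return W / W.sum(axis=1, keepdims=True)


# Toy graph with four cells: edges 0-1, 1-2, 1-3, 2-3.
A = np.array(
    [
        [0, 1, 0, 0],
        [1, 0, 1, 1],
        [0, 1, 0, 1],
        [0, 1, 1, 0],
    ]
)
t = np.array([0.0, 0.3, 0.6, 1.0])

P = forward_transition_matrix(A, t)
print(P)
```

In the toy example, cell 1 can no longer step back to cell 0, and cell 3, the latest cell, keeps all of its probability mass on itself.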

## Prepare Random Walks

In [None]:
# Generate random walks on the cell transition graph
lsd.prepare_walks()

print(f"Generated {lsd.walks.shape[0]} walks of length {lsd.walks.shape[1]}")
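
To make the shape of `lsd.walks` concrete, here is a minimal NumPy sketch of sampling random walks from a row-stochastic transition matrix. It illustrates the concept only and is not how `prepare_walks` is implemented; `sample_walks` and its parameters are hypothetical names chosen to mirror `num_walks`, `path_len`, and `random_state` from the configuration.

```python
import numpy as np


def sample_walks(P, num_walks, path_len, random_state=0):
    """Sample walks of `path_len` cells each from transition matrix `P`.

    Returns an integer array of shape (num_walks, path_len) holding cell indices.
    """
    rng = np.random.default_rng(random_state)
    P = np.asarray(P, dtype=float)
    n_cells = P.shape[0]

    walks = np.empty((num_walks, path_len), dtype=np.int64)
    # Start each walk at a uniformly chosen cell.
    walks[:, 0] = rng.integers(0, n_cells, size=num_walks)

    for w in range(num_walks):
        for step in range(1, path_len):
            current = walks[w, step - 1]
            # The next cell is drawn from the current cell's transition row.
            walks[w, step] = rng.choice(n_cells, p=P[current])
    return walks


# Toy transition matrix over four cells; cell 3 is absorbing.
P = np.array(
    [
        [0.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 0.5, 0.5],
        [0.0, 0.0, 0.0, 1.0],
        [0.0, 0.0, 0.0, 1.0],
    ]
)

walks = sample_walks(P, num_walks=8, path_len=5, random_state=42)
print(walks.shape)
print(walks)
```

The double loop keeps the sketch readable; for thousands of walks over thousands of cells, a vectorized or sparse implementation would be preferable.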

## Train Model

Training runs for 100 epochs and saves a checkpoint to `model_dir` every 25 epochs.

In [None]:
# Create output directory for model checkpoints
model_dir = "./pancreas_model"
os.makedirs(model_dir, exist_ok=True)

# Train the model
lsd.train(
    num_epochs=100,
    save_dir=model_dir,
    save_interval=25,
    random_state=SEED,
)

## Extract Results

After training, extract the learned representations and inferred pseudotime.

In [None]:
# Get annotated AnnData with results
result = lsd.get_adata()

print("Results added to the returned AnnData (`result`):")
print("  result.obs['lsd_pseudotime']: Inferred pseudotime")
print("  result.obs['potential']: Waddington potential")
print("  result.obs['entropy']: Differentiation entropy")
print("  result.obsm['X_cell_state']: Latent cell state (z_loc)")
print("  result.obsm['X_diff_state']: Differentiation state (B_loc)")
print("  result.obsp['transitions']: Transition probability matrix")

## Visualize Results

In [None]:
# Plot LSD pseudotime
sc.pl.umap(result, color='lsd_pseudotime', title='LSD Pseudotime')

In [None]:
# Plot potential landscape
sc.pl.umap(result, color='potential', title='Waddington Potential')

In [None]:
# Plot differentiation entropy
sc.pl.umap(result, color='entropy', title='Differentiation Entropy')

## Save Results

In [None]:
# Save the final model
lsd.save(dir_path=model_dir, file_name="lsd_model_final.pth")

# Save the result AnnData
result.write(os.path.join(model_dir, "result_adata.h5ad"))

print(f"Model and results saved to {model_dir}")

## Next Steps

Proceed to the `postprocessing.ipynb` notebook to:
- Project velocity onto UMAP using CellRank
- Compute Cross-Boundary Direction Correctness (CBDir) scores
- Analyze cell fates and trajectories
- Create streamline visualizations
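
For orientation, the sketch below shows one common formulation of CBDir for a single cluster transition: for each cell in the source cluster, average the cosine similarity between its velocity and the displacement to each neighbor in the target cluster, then average over cells. This is an assumption about the metric's usual definition; the `postprocessing.ipynb` notebook may compute it differently, and `cbdir` and its arguments are hypothetical names.

```python
import numpy as np


def cbdir(positions, velocities, labels, neighbors, source, target):
    """Cross-boundary direction correctness for the transition source -> target.

    positions, velocities: arrays of shape (n_cells, n_dims) in the same space.
    neighbors: for each cell, a list of neighbor indices.
    Returns NaN when no source cell has a neighbor in the target cluster.
    """
    x = np.asarray(positions, dtype=float)
    v = np.asarray(velocities, dtype=float)
    labels = np.asarray(labels)

    per_cell = []
    for i in np.flatnonzero(labels == source):
        boundary = [j for j in neighbors[i] if labels[j] == target]
        if not boundary:
            continue
        displacement = x[boundary] - x[i]
        norms = np.linalg.norm(displacement, axis=1) * np.linalg.norm(v[i])
        valid = norms > 0  # skip zero-length velocities or displacements
        if not valid.any():
            continue
        cosines = displacement[valid] @ v[i] / norms[valid]
        per_cell.append(cosines.mean())

    return float(np.mean(per_cell)) if per_cell else float("nan")


# Toy example: two "A" cells on either side of one "B" cell, both moving right.
positions = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])
velocities = np.array([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]])
labels = ["A", "B", "A"]
neighbors = [[1], [0, 2], [1]]

score = cbdir(positions, velocities, labels, neighbors, source="A", target="B")
print(score)
```

Scores range from -1 to 1, with higher values meaning velocities point toward the expected next cluster. Positions and velocities must live in the same space, for example a UMAP embedding with projected velocities.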