# LSDpy Quickstart Guide

This notebook demonstrates the basic workflow for using LSDpy to infer cell differentiation trajectories from single-cell RNA-seq data.

## Installation

If you haven't installed LSDpy yet (note that the package is imported as `sclsd` in the code below):

```bash
pip install lsdpy
```

Or install from source, running the following from the root of a cloned copy of the repository:

```bash
pip install -e .
```

In [None]:
# Import required packages
import numpy as np
import torch
import scanpy as sc
import matplotlib.pyplot as plt

# Import LSDpy
from sclsd import (
    LSD,
    LSDConfig,
    prepare_data_dict,
    set_all_seeds,
)

# Fix the seed once so that repeated runs of the notebook are reproducible
SEED = 42
set_all_seeds(SEED)

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

## 1. Load and Preprocess Data

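LSDpy works with AnnData objects. For demonstration purposes we generate a small synthetic dataset here; to analyse your own data, load it with `sc.read(...)` instead.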

In [None]:
# To work with real data, read it from disk instead of simulating it:
# adata = sc.read("path/to/your/data.h5ad")

# --- Simulated stand-in dataset ---
np.random.seed(SEED)
n_cells, n_genes = 1000, 500

# Negative-binomial counts: one row per cell, one column per gene
counts = np.random.negative_binomial(n=5, p=0.3, size=(n_cells, n_genes))
adata = sc.AnnData(X=counts.astype(np.float32))
adata.obs_names = [f"Cell_{i}" for i in range(n_cells)]
adata.var_names = [f"Gene_{i}" for i in range(n_genes)]

# Assign every cell a randomly drawn cluster label, stored as a categorical
cluster_labels = np.random.choice(["Stem", "Prog", "Mature"], size=n_cells)
adata.obs["clusters"] = cluster_labels
adata.obs["clusters"] = adata.obs["clusters"].astype("category")

# Quick summary of what was generated
cluster_sizes = adata.obs["clusters"].value_counts().to_dict()
print(f"Data shape: {adata.shape}")
print(f"Clusters: {cluster_sizes}")

In [None]:
# Remember the cell/gene order that `counts` was built with *before*
# preprocessing, so the raw counts can be re-aligned by name afterwards.
raw_obs_names = adata.obs_names.copy()
raw_var_names = adata.var_names.copy()

# Prepare data dictionary (preprocessing)
data_dict = prepare_data_dict(
    adata,
    n_top_genes=500,      # Number of HVGs to use
    normalize=True,        # Normalize counts
    log=True,              # Log transform
    n_pcs=50,              # Number of PCs
    n_neighbors=15,        # Neighbors for graph
)

# Add raw layer for LSD training.
# Preprocessing may keep only a subset of genes (HVG selection) and is not
# guaranteed to keep them in their original order, so slicing the first
# `n_vars` columns of `counts` could attach counts to the wrong genes.
# Look the rows/columns up by name instead.
processed = data_dict["adata"]
cell_idx = raw_obs_names.get_indexer(processed.obs_names)
gene_idx = raw_var_names.get_indexer(processed.var_names)
if (cell_idx < 0).any() or (gene_idx < 0).any():
    # get_indexer marks names it cannot find with -1
    raise ValueError(
        "Processed AnnData contains cells or genes that are not present in "
        "the original count matrix; cannot build the 'raw' layer."
    )
processed.layers["raw"] = counts[np.ix_(cell_idx, gene_idx)]

print(f"Processed data shape: {data_dict['adata'].shape}")

## 2. Compute Prior Pseudotime

LSD requires a prior pseudotime to guide initial random walks. We can use diffusion pseudotime or other methods.

In [None]:
# Work on the preprocessed AnnData from here on
adata = data_dict["adata"]

# Choose the root: the first cell labelled "Stem", or cell 0 if there is none
is_stem = (adata.obs["clusters"] == "Stem").to_numpy()
stem_cells = np.flatnonzero(is_stem)
if stem_cells.size > 0:
    root_cell = stem_cells[0]
else:
    root_cell = 0

# Diffusion map first, then diffusion pseudotime measured from the root cell
sc.tl.diffmap(adata)
adata.uns["iroot"] = root_cell
sc.tl.dpt(adata)

# LSD is pointed at this column via prior_time_key="prior_time" later on
adata.obs["prior_time"] = adata.obs["dpt_pseudotime"]

# Keep the dictionary entry pointing at the annotated object
data_dict["adata"] = adata

t_min = adata.obs["prior_time"].min()
t_max = adata.obs["prior_time"].max()
print(f"Prior pseudotime range: [{t_min:.3f}, {t_max:.3f}]")

## 3. Configure and Create LSD Model

In [None]:
# Start from the default configuration and override selected fields
cfg = LSDConfig()

# --- Model architecture ---
cfg.model.z_dim = 10       # Latent cell state dimension
cfg.model.B_dim = 2        # Differentiation state dimension
cfg.model.V_coeff = 0.0    # Potential regularization

# --- Random walks ---
cfg.walks.path_len = 30     # Steps per walk
cfg.walks.num_walks = 5000  # Number of training walks
cfg.walks.batch_size = 128  # Batch size

# --- Optimizer ---
cfg.optimizer.adam.lr = 1e-3
cfg.optimizer.adam.T_0 = 30

# --- KL annealing ---
cfg.optimizer.kl_schedule.min_af = 0.0
cfg.optimizer.kl_schedule.max_af = 1.0
cfg.optimizer.kl_schedule.max_epoch = 30

# Echo the most important settings back to the reader
print("Configuration:")
for label, value in (
    ("z_dim", cfg.model.z_dim),
    ("B_dim", cfg.model.B_dim),
    ("path_len", cfg.walks.path_len),
    ("num_walks", cfg.walks.num_walks),
):
    print(f"  {label}: {value}")

In [None]:
# Re-seed right before construction so the model is built reproducibly
set_all_seeds(SEED)

# Keyword arguments for the model: target device plus the names under which
# the library size and the raw counts are looked up
model_kwargs = {
    "device": device,
    "lib_size_key": "librarysize",
    "raw_count_key": "raw",
}
lsd = LSD(data_dict["adata"], cfg, **model_kwargs)

print("LSD model created successfully")

## 4. Set Prior Transition and Generate Walks

In [None]:
# Build the prior transition matrix from the pseudotime stored in obs["prior_time"]
lsd.set_prior_transition(prior_time_key="prior_time")

# Sample the random walks used for training
lsd.prepare_walks(n_trajectories=cfg.walks.num_walks)

walk_shape = lsd.walks.shape
print(f"Generated {walk_shape[0]} random walks")
print(f"Walk shape: {walk_shape}")

## 5. Train the Model

In [None]:
# Training options, collected in one place so they are easy to tweak
train_kwargs = {
    "num_epochs": 50,        # Number of training epochs
    "save_dir": None,        # Set to a path to save checkpoints
    "save_interval": 25,     # Save every N epochs
    "plot_loss": True,       # Plot loss curves
    "random_state": SEED,    # Random seed for reproducibility
}
lsd.train(**train_kwargs)

## 6. Extract Results

In [None]:
# Fetch the AnnData object annotated with the LSD outputs
result = lsd.get_adata()

# List the fields this tutorial expects to find on the returned object
summary_lines = (
    "Results added to AnnData:",
    "  - obs['lsd_pseudotime']: LSD-inferred pseudotime",
    "  - obs['potential']: Waddington potential values",
    "  - obs['entropy']: Cell state entropy",
    "  - obsm['cell_rep']: Latent cell state (z)",
    "  - obsm['diff_rep']: Differentiation state (B)",
    "  - obsp['transitions']: Cell-cell transition probabilities",
)
for line in summary_lines:
    print(line)

In [None]:
# Visualize pseudotime, potential and cluster labels on a UMAP embedding.
# sc.pl.umap needs coordinates in obsm["X_umap"]; no earlier cell in this
# notebook computes them explicitly, so build the embedding here if it is
# missing (this relies on the neighbor graph from preprocessing).
if "X_umap" not in result.obsm:
    sc.tl.umap(result)

sc.pl.umap(result, color=["lsd_pseudotime", "potential", "clusters"], ncols=3)

## 7. Cell Fate Prediction

In [None]:
# Propagate cells through the potential landscape to predict their fates
fate_kwargs = {
    "adata": result,
    "time_range": 5.0,          # Time range for propagation
    "dt": 0.5,                  # Time step
    "cluster_key": "clusters",  # Cluster annotations
    "return_paths": True,       # Store full trajectories
}
fate_result = lsd.get_cell_fates(**fate_kwargs)

# Number of cells assigned to each predicted fate
fate_counts = fate_result.obs["fate"].value_counts().to_dict()
print(f"Predicted fates: {fate_counts}")

In [None]:
# Visualize the original clusters next to the predicted fates.
# sc.pl.umap needs coordinates in obsm["X_umap"]; compute the embedding
# first if this object does not carry one yet.
if "X_umap" not in fate_result.obsm:
    sc.tl.umap(fate_result)

sc.pl.umap(fate_result, color=["clusters", "fate"], ncols=2)

## 8. Save and Load Model

In [None]:
# The lines below are left commented out so that running the notebook does
# not write to disk. Uncomment them to persist and restore a trained model.

# Save model
# lsd.save(dir_path="./checkpoints", file_name="lsd_model.pth")

# Load model (create new LSD instance first)
# NOTE(review): the model above was constructed with explicit lib_size_key /
# raw_count_key arguments; presumably the same arguments should be passed
# here unless they match the defaults — confirm against the LSD constructor.
# lsd_loaded = LSD(data_dict["adata"], cfg, device=device)
# lsd_loaded.load(dir_path="./checkpoints", file_name="lsd_model.pth")

## Summary

This notebook demonstrated the basic LSDpy workflow:

1. **Data Preparation**: Load and preprocess single-cell data
2. **Prior Pseudotime**: Compute initial pseudotime estimates
3. **Configuration**: Set model hyperparameters
4. **Random Walks**: Generate training trajectories
5. **Training**: Train the neural ODE model
6. **Results**: Extract pseudotime, potential, and latent representations
7. **Cell Fates**: Predict terminal cell states

For more advanced usage, see the other tutorial notebooks.