# LSD Tutorial: Prior Pseudotime Training

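This notebook demonstrates how to train an LSD model on a dataset with a prior pseudotime. It covers data preprocessing, LSD model initialization, random-walk generation, and model training. It then covers the downstream analyses: cell-fate prediction, in-silico gene perturbation (GATA2 knockout), and gene-expression dynamics along LSD pseudotime.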

**Note:** This notebook downloads the bone marrow dataset via scVelo (`scv.datasets.bonemarrow()`), making it self-contained for testing.

In [None]:
import os
import scanpy as sc
import scvelo as scv
from sclsd import LSD, LSDConfig
from sclsd.utils import plot_random_walks
import pyro
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.sparse as sp

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Seed the global RNGs up front. pyro.set_rng_seed seeds Python's `random`,
# NumPy and PyTorch. This makes the steps before training reproducible under
# Restart & Run All, not just the training loop. Examples are LSD model
# construction and `prepare_walks`, assuming they draw from these global RNGs.
SEED = 42
pyro.set_rng_seed(SEED)

In [None]:
# Read the dataset (downloads automatically)
adata = scv.datasets.bonemarrow()

# Store the raw counts
adata.layers["raw"] = adata.X.copy()

# Preprocess: this log-normalized X is only used for HVG selection below
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

# Select the top 2000 HVGs. This subsets adata in place, including the "raw" layer.
# NOTE(review): scanpy recommends sc.pp.highly_variable_genes over
# filter_genes_dispersion; kept as-is for parity with the reference outputs.
sc.pp.filter_genes_dispersion(adata, n_top_genes=2000)
# NOTE(review): at this point adata.X is log-normalized, so `min_counts` is a
# threshold on summed log values, not raw counts. Confirm this is intended
# before changing it; doing so would alter the retained cells.
sc.pp.filter_cells(adata, min_counts=50)

# Rebuild X from the raw counts: normalize to target sum 1e4, then log-transform
adata.X = adata.layers["raw"].copy()
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

# Per-cell library size over the retained genes. For a sparse layer,
# `.sum(axis=1)` returns an (n_cells, 1) np.matrix. Flatten it so obs receives
# a plain 1-D column (a no-op for a dense layer).
adata.obs["librarysize"] = np.asarray(adata.layers["raw"].sum(axis=1)).ravel()

# KNN graph
sc.pp.pca(adata)
sc.pp.neighbors(adata)

In [None]:
# Quick look at the data: annotated clusters and the prior (Palantir) pseudotime on t-SNE
inspect_keys = ["clusters", "palantir_pseudotime"]
sc.pl.tsne(adata, color=inspect_keys)

In [None]:
# Build the configuration, override a few hyperparameters, then create the model
cfg = LSDConfig()

# These values reproduce the original tutorial exactly
cfg.walks.num_walks, cfg.walks.path_len = 1024, 20
cfg.optimizer.adam.lr = 2e-3

lsd = LSD(
    adata,
    cfg,
    raw_count_key="raw",         # layer holding the counts stored before normalization
    lib_size_key="librarysize",  # obs column with the per-cell library size
    device=device,
)

In [None]:
# Optional prior knowledge about the cell hierarchy (not required for training).
# Bone marrow hierarchy, expressed as parent cluster -> list of child clusters.
# NOTE(review): "Mega" appears only as a child and has no entry of its own,
# unlike the other childless clusters. "CLP" has no parent. Confirm both are
# intended for lsd.set_phylogeny.
phylogeny = {
    "HSC_1": ["HSC_2", "Ery_1", "Mega"],
    "Ery_1": ["Ery_2"],
    "HSC_2": ["Precursors"],
    "Precursors": ["Mono_1", "Mono_2", "DCs"],
}

# Terminal states have no children
terminal_states = ["Ery_2", "DCs", "Mono_1", "Mono_2", "CLP"]
phylogeny.update({state: [] for state in terminal_states})

# The hierarchy is used to refine random-walk generation
lsd.set_phylogeny(phylogeny, cluster_key="clusters")

In [None]:
# Step 1: derive a cell-by-cell transition matrix from the prior pseudotime.
# If a precomputed transition matrix is available instead
# (e.g. adata.obsp["transition_matrix"]), pass it directly:
#   lsd.set_prior_transition(prior_transition=adata.obsp["transition_matrix"])
lsd.set_prior_transition(prior_time_key="palantir_pseudotime")

# Step 2: generate the random walks
lsd.prepare_walks()

# Step 3: overlay the walks on the t-SNE embedding as a sanity check
plot_random_walks(lsd.adata, lsd.walks, "X_tsne")

In [None]:
# Train the LSD model, reseeding right before training
train_seed = 42
pyro.set_rng_seed(train_seed)
num_epochs = 120

# Checkpoint directory (relative to the current working directory)
save_dir = "../../models/tutorial_model"
os.makedirs(save_dir, exist_ok=True)

train_kwargs = dict(
    num_epochs=num_epochs,
    save_dir=save_dir,
    save_interval=50,         # checkpoint every 50 epochs
    random_state=train_seed,
)
lsd.train(**train_kwargs)

In [None]:
# Collect the trained model's per-cell outputs
final_adata = lsd.get_adata()

# Overlay the same quantities on t-SNE and on the differentiation-state space
colors = ["clusters", "potential", "lsd_pseudotime", "entropy"]
embeddings = ("X_tsne", "X_diff_state")

for basis in embeddings:
    sc.pl.embedding(final_adata, color=colors, basis=basis)

# Flow fields on the same two embeddings
for basis in embeddings:
    lsd.stream_lines(embedding=basis, size=50)

In [None]:
# Predict each cell's fate with the trained LSD model
dyn_adata = lsd.get_cell_fates(
    final_adata,
    time_range=10,
    cluster_key="clusters",
)

In [None]:
# Visualize predicted cell fates, coloring each fate like the cluster it names.
# scanpy pairs `uns["<key>_colors"]` with the column's categories by position.
# Copying `clusters_colors` verbatim is therefore only correct when "fate" has
# exactly the categories of "clusters", in the same order. Build the palette by
# name instead; the result is identical in that case.
dyn_adata.obs["fate"] = dyn_adata.obs["fate"].astype("category")
cluster_palette = dict(zip(
    dyn_adata.obs["clusters"].astype("category").cat.categories,
    dyn_adata.uns["clusters_colors"],
))
dyn_adata.uns["fate_colors"] = [
    cluster_palette.get(fate, "lightgray")  # fallback for labels that are not cluster names
    for fate in dyn_adata.obs["fate"].cat.categories
]

sc.pl.embedding(dyn_adata, color="fate", basis="X_tsne")
sc.pl.embedding(dyn_adata, color="fate", basis="X_diff_state")

In [None]:
# Gene perturbation
# For gene perturbation you need to set three parameters:
# 1- name of gene, 2- number of perturbation iterations, 3- level of perturbation (e.g. for KO level=0)

# Initial conditions: expression profiles of the HSC_1 / HSC_2 cells only
idx = np.flatnonzero(dyn_adata.obs["clusters"].isin(["HSC_1", "HSC_2"]).to_numpy())
assert idx.size > 0, "No HSC_1/HSC_2 cells found in dyn_adata.obs['clusters']"

# Slice the rows first, then densify only if needed. This avoids materializing
# the full matrix and works whether dyn_adata.X is sparse or already dense.
X_hsc = dyn_adata.X[idx]
X_hsc = X_hsc.toarray() if sp.issparse(X_hsc) else np.asarray(X_hsc)
IC = torch.from_numpy(X_hsc).to(device)

# We KO GATA2 only in HSC cells
# See https://www.nature.com/articles/s41587-019-0068-4 for role of GATA2 as a driver of erythroid lineage
# Perturbation analysis returns predicted cell fate in perturbed and unperturbed scenarios
perturbed_fates, unperturbed_fates = lsd.perturb(
    adata=dyn_adata,
    x=IC,
    gene_name="GATA2",
    # NOTE(review): keyword spelled as in the original call; presumably this is
    # lsd.perturb's parameter name, so check the API before "fixing" it.
    pertubation_level=0,
    max_perturbations=10,
    cluster_key="clusters",
    batch_size=4096,
)

In [None]:
# Show the predicted change in frequency of each fate after GATA2 KO.
# One row per perturbed HSC cell, in the same order as `idx` from the perturbation step.
fate_df = pd.DataFrame({
    "cell_id": dyn_adata.obs_names[idx],
    "unperturbed_fate": np.asarray(unperturbed_fates, dtype=object),
    "perturbed_fate": np.asarray(perturbed_fates, dtype=object),
    "lsd_pseudotime": dyn_adata.obs["lsd_pseudotime"].iloc[idx].values,
    "cluster": dyn_adata.obs["clusters"].iloc[idx].astype(str).values,
})

# Fate-change indicator
fate_df["changed"] = fate_df["unperturbed_fate"] != fate_df["perturbed_fate"]

# Counts before and after perturbation; a fate absent from one scenario counts as 0
fate_counts = pd.concat(
    [
        fate_df["unperturbed_fate"].value_counts().rename("unperturbed"),
        fate_df["perturbed_fate"].value_counts().rename("perturbed"),
    ],
    axis=1,
).fillna(0)

# Normalize each scenario to fractions, then take the net effect of the perturbation
fate_freq = fate_counts.div(fate_counts.sum(axis=0), axis=1)
fate_freq["delta"] = fate_freq["perturbed"] - fate_freq["unperturbed"]

promoted = fate_freq[fate_freq["delta"] > 0].sort_values("delta", ascending=False)
inhibited = fate_freq[fate_freq["delta"] < 0].sort_values("delta")

print("Promoted fates:")
print(promoted[["delta"]])
print("\nInhibited fates:")
print(inhibited[["delta"]])

fig, ax = plt.subplots(figsize=(6, 4))
fate_freq["delta"].sort_values().plot.barh(ax=ax, color="steelblue")
ax.axvline(0, color="k", lw=1)
ax.set_xlabel("Δ fate frequency (perturbed − unperturbed)")
ax.set_title("Fates promoted / inhibited by GATA2 perturbation")
fig.tight_layout()
plt.show()

In [None]:
# Expression dynamics of GATA2 along LSD pseudotime, per lineage
# --- Settings ---
gene = "GATA2"
pt_key = "lsd_pseudotime"
cluster_key = "fate"  # change if your cluster labels live in a different obs column
n_bins = 20
# "lymphoid" is mapped below but, as in the original analysis, not plotted
plot_lineages = ["erythroid", "myeloid"]

# Map fine clusters -> coarse lineage
lineage_map = {
    **{k: "erythroid" for k in ["Ery_1", "Ery_2", "Mega"]},
    **{k: "myeloid"   for k in ["Precursors", "Mono_1", "Mono_2", "DCs"]},
    **{k: "lymphoid"  for k in ["CLP"]},
}
keep = list(lineage_map.keys())

assert gene in dyn_adata.var_names, f"{gene} is not among the retained genes"

# --- Subset + pull expression ---
adata_sub = dyn_adata[dyn_adata.obs[cluster_key].isin(keep)].copy()
X = adata_sub[:, gene].X
expr = X.toarray().ravel() if sp.issparse(X) else np.asarray(X).ravel()

pt = adata_sub.obs[pt_key].to_numpy(dtype=float)
lineage = adata_sub.obs[cluster_key].map(lineage_map).to_numpy(dtype=object)

# Drop cells without a pseudotime value
m = ~np.isnan(pt)
expr, pt, lineage = expr[m], pt[m], lineage[m]
assert pt.size > 0, f"No cells with a pseudotime in the selected '{cluster_key}' groups"

# --- Bin pseudotime and aggregate mean +/- SE per lineage ---
edges = np.linspace(pt.min(), pt.max(), n_bins + 1)
# np.digitize is 1-based and puts pt.max() past the last edge; shift and clip
# so every cell lands in a bin in [0, n_bins - 1].
bin_id = np.clip(np.digitize(pt, edges) - 1, 0, n_bins - 1)
centers = 0.5 * (edges[:-1] + edges[1:])

expr_df = pd.DataFrame({"lineage": lineage, "bin": bin_id, "expr": expr})
agg = (
    expr_df.groupby(["lineage", "bin"])
    .agg(mean=("expr", "mean"), se=("expr", "sem"))
    .fillna({"se": 0.0})  # single-cell bins have an undefined SE; draw them with zero width
    .reset_index()
)
agg["pt"] = centers[agg["bin"].to_numpy()]

# --- Plot ---
fig, ax = plt.subplots(figsize=(7, 5))
for lin in plot_lineages:
    sub = agg[agg["lineage"] == lin].sort_values("pt")
    (line,) = ax.plot(sub["pt"], sub["mean"], lw=2, label=lin)
    # Tie the SE band's color to its line rather than to the color cycle
    ax.fill_between(
        sub["pt"], sub["mean"] - sub["se"], sub["mean"] + sub["se"],
        color=line.get_color(), alpha=0.2,
    )

ax.set_xlabel("LSD pseudotime")
ax.set_ylabel(f"{gene} expression")
ax.set_title(f"{gene} expression across pseudotime by lineage")
ax.legend(frameon=False)
fig.tight_layout()
plt.show()

## Save outputs for parity testing

The cell below saves key outputs for comparison with results produced by the LSD main branch.

In [None]:
# Persist key per-cell outputs so they can be compared against reference results
output_dir = "../../tests/fixtures/reference_data/tutorial"
os.makedirs(output_dir, exist_ok=True)

# File name -> array (obs columns as plain arrays, obsm embeddings as stored)
parity_arrays = {
    "lsdpy_pseudotime.npy": final_adata.obs["lsd_pseudotime"].values,
    "lsdpy_potential.npy": final_adata.obs["potential"].values,
    "lsdpy_entropy.npy": final_adata.obs["entropy"].values,
    "lsdpy_cell_state.npy": final_adata.obsm["X_cell_state"],
    "lsdpy_diff_state.npy": final_adata.obsm["X_diff_state"],
}

for fname, arr in parity_arrays.items():
    np.save(os.path.join(output_dir, fname), arr)

print(f"Saved outputs to {output_dir}")
for fname, arr in parity_arrays.items():
    print(f"  - {fname}: shape {arr.shape}")