# BoneMarrow Dataset - Postprocessing and Evaluation

This notebook demonstrates postprocessing of LSD results, including:
- Velocity projection using CellRank
- Cross-Boundary Direction Correctness (CBDir) and In-Cluster Coherence (ICCoh) evaluation
- Streamline visualization
- Cell fate analysis
- Temporal dynamics of cell state (Δz)

**Prerequisites:**
- Complete the `train.ipynb` notebook first
- Have the trained model checkpoint (`lsd_model_epoch0200.pth`) saved in `./bonemarrow_model/`

## Setup

In [None]:
import os
import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch

# Import lsdpy components
from sclsd import LSD, LSDConfig, set_all_seeds, clear_pyro_state
from sclsd.analysis.metrics import (
    cross_boundary_correctness,
    inner_cluster_coh,
    evaluate,
)

# CellRank for velocity projection
from cellrank.kernels import ConnectivityKernel
from cellrank.kernels.utils import TmatProjection

# Seed the random number generators through sclsd's helper so reruns are
# reproducible.
SEED = 42
set_all_seeds(SEED)

# Use CUDA when PyTorch can see a GPU, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

## Load Trained Model and Results

In [None]:
# Input AnnData file and the directory that holds the trained checkpoint.
data_path = "../../data/BoneMarrow/normalized_before_low.h5ad"
model_dir = "./bonemarrow_model"

# The original data the model was trained on.
adata = sc.read(data_path)

# Rebuild the LSD wrapper with the default config, pointing it at the
# "librarysize" and "raw" keys of the data.
cfg = LSDConfig()
lsd = LSD(adata, cfg, device=device, lib_size_key="librarysize", raw_count_key="raw")

# Restore the trained weights from this checkpoint file inside model_dir.
model_path = "lsd_model_epoch0200.pth"
lsd.load(model_dir, model_path)

# AnnData carrying the LSD outputs (transition matrix, potential, entropy,
# pseudotime).
result = lsd.get_adata()

print(f"Loaded {result.n_obs} cells x {result.n_vars} genes")
for what, names in (
    ("obs columns", result.obs.columns.tolist()),
    ("obsm keys", list(result.obsm.keys())),
    ("obsp keys", list(result.obsp.keys())),
):
    print(f"\nAvailable {what}: {names}")

## Predict Cell Fates

In [None]:
# Predict cell fates by ODE propagation
dyn_adata = lsd.get_cell_fates(result, time_range=5, cluster_key="clusters")

# Give each fate the color of the cluster with the same name.
# Scanpy pairs ``uns["<key>_colors"]`` with the categories of ``obs[<key>]``
# by position. Copying ``clusters_colors`` verbatim is only correct if
# ``fate`` has exactly the same categories in the same order. If the
# predicted fates are a subset of the clusters (e.g. terminal states only),
# every color would shift, so the colors are matched by name instead.
# Fates with no matching cluster get a neutral grey.
cluster_palette = dict(
    zip(dyn_adata.obs["clusters"].cat.categories, dyn_adata.uns["clusters_colors"])
)
dyn_adata.obs["fate"] = dyn_adata.obs["fate"].astype("category")
dyn_adata.uns["fate_colors"] = [
    cluster_palette.get(fate_name, "#808080")
    for fate_name in dyn_adata.obs["fate"].cat.categories
]

## Visualize LSD Outputs

In [None]:
# LSD per-cell outputs to plot, one panel each, two panels per row.
cols = ["potential", "lsd_pseudotime", "fate", "entropy"]

# Use X_umap if available, otherwise fall back to X_tsne. The basis is
# resolved before printing so the banner names the embedding actually
# plotted, rather than always claiming UMAP.
basis = "X_umap" if "X_umap" in dyn_adata.obsm else "X_tsne"
embedding_name = "UMAP" if basis == "X_umap" else "t-SNE"

print("=" * 60)
print(f"Plotting LSD outputs on {embedding_name} projection")
print("=" * 60)

sc.pl.embedding(dyn_adata, color=cols, basis=basis, ncols=2)

In [None]:
print("=" * 60)
print("Plotting LSD outputs on differentiation state projection")
print("=" * 60)
sc.pl.embedding(dyn_adata, color=cols, basis="X_diff_state", ncols=2)

## Plot Velocity Streamlines

In [None]:
print("=" * 60)
print("Plotting LSD streamlines on UMAP projection")
print("=" * 60)
lsd.stream_lines(embedding=basis, size=60)

In [None]:
print("=" * 60)
print("Plotting LSD streamlines on differentiation state projection")
print("=" * 60)
lsd.stream_lines(embedding="X_diff_state", size=60)

## Temporal Dynamics of Cell State (Δz)

In [None]:
# Δz = z(t+1) - z(t): the step-to-step change of every latent component along
# the ODE solution ``lsd.z_sol``. Each step is summarized over axis 1
# (presumably cells) as mean ± one standard deviation, one figure per
# component.
Z = lsd.z_sol.detach().cpu().numpy()
dZ = np.diff(Z, axis=0)
# NOTE(review): the dZ steps are spread evenly over [0, 5], the same span as
# the ``time_range=5`` passed to ``get_cell_fates``; confirm z_sol uses that grid.
t = np.linspace(0, 5, dZ.shape[0])

n_components = Z.shape[-1]
for comp in range(n_components):
    step = dZ[:, :, comp]
    mu = step.mean(axis=1)
    sigma = step.std(axis=1)

    fig, ax = plt.subplots(figsize=(8, 6))
    # Dashed zero line marks "no change" between consecutive steps.
    ax.axhline(0, linestyle="--", linewidth=1.0, color="#666666", alpha=0.8, zorder=0)
    ax.plot(t, mu, linewidth=2)
    ax.fill_between(t, mu - sigma, mu + sigma, alpha=0.25, linewidth=0)

    ax.set_title(f"Mean Δz - Component {comp + 1}", fontsize=12, pad=10)
    ax.set_xlabel("ODE Time Unit", fontsize=11)
    ax.set_ylabel("Δz value", fontsize=11)
    ax.grid(True, alpha=0.3, linestyle="--")

    # Open-frame look: hide the top/right spines and thin the other two.
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.spines["left"].set_linewidth(0.8)
    ax.spines["bottom"].set_linewidth(0.8)
    ax.tick_params(axis="both", labelsize=10)

    fig.subplots_adjust(top=0.88, left=0.12, right=0.98, bottom=0.15)
    plt.show()

## Cross-Boundary Direction Correctness (CBDir)

In [None]:
# Directed (source, target) cluster pairs of the BoneMarrow lineage used for
# CBDir scoring, written one per line as "source -> target". Order is kept,
# since it determines the order of the printed and saved scores.
_EDGE_SPEC = """
HSC_1 -> HSC_2
HSC_2 -> Precursors
HSC_1 -> Ery_1
Ery_1 -> Ery_2
Precursors -> Mono_1
Precursors -> Mono_2
Precursors -> DCs
"""
cluster_edges = [
    tuple(name.strip() for name in line.split("->"))
    for line in _EDGE_SPEC.strip().splitlines()
]

print(f"Defined {len(cluster_edges)} cluster edges for evaluation")

In [None]:
# Project LSD's transition matrix onto a 2-D embedding with CellRank. The
# resulting per-cell velocity vectors (stored under ``velocity_key``) are what
# the CBDir / ICCoh metrics later in this notebook score.
kernel = ConnectivityKernel(result)
# Use LSD's own transition matrix rather than computing one from the
# connectivities.
kernel.transition_matrix = result.obsp["transitions"]

# Project onto the available embedding. The velocity key is derived from the
# embedding key (X_umap -> velocity_umap, X_tsne -> velocity_tsne), so the two
# cannot disagree.
emb_key = "X_umap" if "X_umap" in result.obsm else "X_tsne"
velocity_key = f"velocity_{emb_key[len('X_'):]}"

proj = TmatProjection(kernel, basis=emb_key)
proj.project(
    key_added=velocity_key,
    recompute=True,
    connectivities=result.obsp["connectivities"],
)

# Store, for every cell, the column indices of its non-zero entries in the kNN
# distance graph (its neighbors) under uns["neighbors"]["indices"]. The
# CellRank projection above has already run without this. It is stored for
# the downstream metrics, which presumably read it.
# NOTE(review): confirm against sclsd.analysis.metrics.
indices = result.obsp["distances"].tolil().rows
result.uns.setdefault("neighbors", {})["indices"] = indices

print(f"Velocity projected onto {emb_key}")

In [None]:
# Compute Cross-Boundary Direction Correctness (CBDir) for every directed
# cluster edge in ``cluster_edges``, using the velocities projected onto
# ``emb_key``. The second return value is not used below; the overall
# mean is recomputed from the per-edge scores.
all_scores, metrics = cross_boundary_correctness(
    result,
    k_cluster="clusters",
    k_velocity=velocity_key,
    cluster_edges=cluster_edges,
    x_emb=emb_key,
)

# Display results
print("\nCross-boundary correctness scores:")
print("=" * 50)
for edge, score in all_scores.items():
    print(f"  {edge[0]:15} -> {edge[1]:15}: {score:.3f}")

# Summary statistic. np.mean([]) would warn and return NaN, so the "no scores"
# case is made explicit instead.
average_all = float(np.mean(list(all_scores.values()))) if all_scores else float("nan")
print("=" * 50)
print(f"\nOverall average CBDir: {average_all:.3f}")
print("Expected (from LSD paper): ~0.594")

## In-Cluster Coherence (ICCoh)

In [None]:
# In-Cluster Coherence (ICCoh) of the projected velocities: one score per
# cluster plus their overall mean.
iccoh_scores, mean_iccoh = inner_cluster_coh(result, k_cluster="clusters", k_velocity=velocity_key)

divider = "=" * 50
report_lines = ["\nIn-Cluster Coherence Scores:", divider]
report_lines += [f"  {cluster:20}: {score:.3f}" for cluster, score in iccoh_scores.items()]
report_lines += [divider, f"  Overall Mean ICCoh: {mean_iccoh:.3f}"]
print("\n".join(report_lines))

## Fate Probability Heatmap

In [None]:
# Row-normalized crosstab: for each starting cluster, the fraction of its cells
# assigned to each predicted fate. The heatmap is also written into model_dir.
fate_df = pd.crosstab(
    index=dyn_adata.obs["clusters"],
    columns=dyn_adata.obs["fate"],
    normalize="index",
)

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(fate_df, ax=ax, cmap="viridis", annot=True, fmt=".2f")
ax.set(title="Fate Probability Heatmap", xlabel="Predicted Fate", ylabel="Starting Cluster")
fig.tight_layout()
fig.savefig(os.path.join(model_dir, "fate_heatmap.png"), dpi=150, bbox_inches="tight")
plt.show()

## Entropy vs Pseudotime

In [None]:
# Scatter entropy against LSD pseudotime, colored by cluster, and annotate the
# Pearson correlation ``r`` (which is also saved with the evaluation metrics).
clusters = dyn_adata.obs["clusters"]

# Cluster colors are stored in category order, so zipping pairs each category
# with its own color.
categories = clusters.cat.categories
cluster_colors = dyn_adata.uns["clusters_colors"]
palette = dict(zip(categories, cluster_colors))

# Extract x, y as float arrays
x = np.asarray(dyn_adata.obs["lsd_pseudotime"], dtype=float)
y = np.asarray(dyn_adata.obs["entropy"], dtype=float)

# Correlate only the cells where both values are finite. np.corrcoef
# propagates NaN, so a single missing pseudotime/entropy value would otherwise
# turn the reported (and saved) correlation into NaN.
finite = np.isfinite(x) & np.isfinite(y)
if np.count_nonzero(finite) > 1:
    r = float(np.corrcoef(x[finite], y[finite])[0, 1])
else:
    r = float("nan")

fig, ax = plt.subplots(figsize=(8, 5))

# Scatter by cluster; points are rasterized so vector exports stay small.
cluster_vals = clusters.to_numpy()
for cat in categories:
    m = cluster_vals == cat
    if not np.any(m):
        continue
    ax.scatter(
        x[m],
        y[m],
        s=10,
        alpha=0.8,
        c=palette[cat],
        edgecolors="none",
        rasterized=True,
        label=str(cat),
    )

ax.set_xlabel("Pseudotime")
ax.set_ylabel("Entropy")
ax.set_title("Bone Marrow - Pseudotime vs Entropy")

# Pearson r in the top-right corner (omitted when it could not be computed).
if np.isfinite(r):
    ax.text(
        0.98,
        0.98,
        rf"$Pearson\ r = {r:.2f}$",
        transform=ax.transAxes,
        ha="right",
        va="top",
    )

# Open-frame look: hide the top/right spines and thin the other two.
for s in ["top", "right"]:
    ax.spines[s].set_visible(False)
for s in ["left", "bottom"]:
    ax.spines[s].set_linewidth(0.8)
ax.grid(True, alpha=0.3, linestyle="--")

# Legend outside the axes on the right, so it never covers points.
ax.legend(
    frameon=False,
    bbox_to_anchor=(1.02, 1.0),
    loc="upper left",
    borderaxespad=0.0,
)
fig.tight_layout()
fig.savefig(os.path.join(model_dir, "entropy_vs_pseudotime.png"), dpi=150, bbox_inches="tight")
plt.show()

## Save Evaluation Results

In [None]:
import json
import math


def _json_float(value):
    """Convert a (possibly NumPy) scalar to a JSON-safe builtin float.

    ``json`` cannot serialize ``np.float32`` scalars, and writes NaN/inf as the
    non-standard tokens ``NaN``/``Infinity``. Non-finite values are therefore
    returned as ``None``, which is written as JSON ``null``.
    """
    value = float(value)
    return value if math.isfinite(value) else None


# Save evaluation metrics. Per-edge and per-cluster scores may be NumPy
# scalars (e.g. float32) that ``json`` refuses to serialize, so every number
# goes through _json_float, not just the means.
eval_summary = {
    "cbdir": {
        "scores": {str(k): _json_float(v) for k, v in all_scores.items()},
        "mean": _json_float(average_all),
    },
    "iccoh": {
        "scores": {str(k): _json_float(v) for k, v in iccoh_scores.items()},
        "mean": _json_float(mean_iccoh),
    },
    "pseudotime_entropy_correlation": _json_float(r),
}

metrics_path = os.path.join(model_dir, "evaluation_metrics.json")
with open(metrics_path, "w", encoding="utf-8") as f:
    # allow_nan=False guarantees strictly valid JSON output.
    json.dump(eval_summary, f, indent=2, allow_nan=False)

print(f"Evaluation metrics saved to {metrics_path}")

## Expected Results Summary

Based on the original LSD paper and the reference LSD implementation (main branch), the expected CBDir results for the BoneMarrow dataset are:

| Metric | Expected Value |
|--------|----------------|
| Mean CBDir | ~0.594 |
| (HSC_1, HSC_2) | ~0.331 |
| (HSC_2, Precursors) | ~0.498 |
| (HSC_1, Ery_1) | ~0.685 |
| (Ery_1, Ery_2) | ~0.834 |
| (Precursors, Mono_1) | ~0.555 |
| (Precursors, Mono_2) | ~0.458 |
| (Precursors, DCs) | ~0.799 |

If your results match these values within a small tolerance (small run-to-run differences are expected from training stochasticity), the `sclsd` (lsdpy) implementation reproduces the results of the original LSD implementation.