# Pancreas Dataset - Postprocessing and Evaluation

This notebook demonstrates postprocessing of LSD results, including:
- Velocity projection using CellRank
- Cross-Boundary Direction Correctness (CBDir) evaluation
- In-Cluster Coherence (ICCoh) evaluation
- Streamline visualization
- Cell fate analysis
- Comparison of LSD pseudotime with the prior pseudotime

**Prerequisites:**
- Complete the `train.ipynb` notebook first
- Have a trained model saved in `./pancreas_model/`

## Setup

In [None]:
import os
import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt
import torch

# Import LSD components from the sclsd package
from sclsd import LSD, LSDConfig, set_all_seeds, clear_pyro_state
from sclsd.analysis.metrics import (
    cross_boundary_correctness,
    inner_cluster_coh,
    evaluate,
)

# CellRank for velocity projection
from cellrank.kernels import ConnectivityKernel

# Set random seed
SEED = 42
set_all_seeds(SEED)

# Configure device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Load Trained Model and Results

In [None]:
# Path to model directory
model_dir = "./pancreas_model"

# Load result AnnData
result = sc.read(os.path.join(model_dir, "result_adata.h5ad"))
print(f"Loaded {result.n_obs} cells x {result.n_vars} genes")
print(f"\nAvailable obs columns: {result.obs.columns.tolist()}")
print(f"\nAvailable obsm keys: {list(result.obsm.keys())}")
print(f"\nAvailable obsp keys: {list(result.obsp.keys())}")

## Project Velocity onto UMAP using CellRank

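We use CellRank's `ConnectivityKernel` as a container for the precomputed LSD transition matrix (`result.obsp["transitions"]`), then project those transition probabilities onto the UMAP embedding as streamlines.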
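
To make the idea of projecting transition probabilities concrete, the sketch below computes each cell's expected displacement in the embedding under its row of the transition matrix. It is a simplified illustration on toy data; CellRank's own projection may differ in details such as normalization of the displacement vectors, and `project_transitions` is a hypothetical helper name, not part of CellRank or sclsd.

```python
import numpy as np


def project_transitions(T, X_emb):
    """Expected displacement of each cell in the embedding.

    T     : (n, n) row-stochastic transition matrix
    X_emb : (n, d) embedding coordinates
    Hypothetical helper for illustration only.
    """
    T = np.asarray(T, dtype=float)
    X_emb = np.asarray(X_emb, dtype=float)
    # sum_j T_ij * (x_j - x_i) equals (T @ X)_i - x_i when rows of T sum to 1
    return T @ X_emb - X_emb


# Toy chain of three cells along the x-axis
X_toy = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])
T_toy = np.array([
    [0.0, 1.0, 0.0],  # cell 0 moves to cell 1
    [0.0, 0.0, 1.0],  # cell 1 moves to cell 2
    [0.0, 0.0, 1.0],  # cell 2 stays where it is
])
V_toy = project_transitions(T_toy, X_toy)
print(V_toy)
```

In this toy chain the first two cells get an arrow pointing right and the last cell gets a zero vector, since it only transitions to itself.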

In [None]:
# Create ConnectivityKernel with precomputed transitions
ck = ConnectivityKernel(result)
ck.transition_matrix = result.obsp["transitions"]

print("ConnectivityKernel created with transition matrix")

In [None]:
# Project velocity onto UMAP
ck.plot_projection(
    basis="umap",
    recompute=True,
    color="clusters",
    legend_loc="right",
    title="Velocity Streamlines",
)

In [None]:
# Project velocity colored by pseudotime
ck.plot_projection(
    basis="umap",
    recompute=False,
    color="lsd_pseudotime",
    title="Velocity Streamlines - Pseudotime",
)

## Define Cluster Edges for Evaluation

Cluster edges define the expected transition directions for CBDir evaluation.
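
Because edges refer to clusters by name, a quick sanity check is to confirm that every name in the edge list appears among the cluster labels. The sketch below uses made-up toy labels; `missing_edge_clusters` is a hypothetical helper, and in this notebook one would pass `result.obs["clusters"].cat.categories` (assuming the column is categorical) as `categories`.

```python
def missing_edge_clusters(edges, categories):
    """Return cluster names used in `edges` that are absent from `categories`."""
    known = set(categories)
    used = {name for edge in edges for name in edge}
    return sorted(used - known)


# Toy labels for illustration; note the deliberate typo "Ngn3 High"
toy_categories = ["Ductal", "Ngn3 low", "Ngn3 high"]
toy_edges = [("Ductal", "Ngn3 low"), ("Ngn3 low", "Ngn3 High")]

missing = missing_edge_clusters(toy_edges, toy_categories)
print(missing)  # ['Ngn3 High']
```

An empty list means every edge endpoint matches an existing cluster label.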

In [None]:
# Define cluster edges for Cross-Boundary Direction Correctness
cluster_edges = [
    ("Prlf. Ductal", "Ductal"),
    ("Ductal", "Ngn3 low"),
    ("Ngn3 low", "Ngn3 high"),
    ("Ngn3 high", "Fev+"),
    ("Ngn3 high", "Epsilon"),
    ("Fev+", "Fev+ Alpha"),
    ("Fev+", "Fev+ Beta"),
    ("Fev+ Alpha", "Alpha"),
    ("Fev+ Beta", "Beta"),
    ("Fev+ Delta", "Delta"),
]

print(f"Defined {len(cluster_edges)} cluster edges for evaluation")

## Compute Cross-Boundary Direction Correctness (CBDir)

CBDir measures how well the inferred velocities point from source to target clusters at cluster boundaries.
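
As a rough illustration of the idea, the sketch below scores one edge on toy data: for each source-cluster cell that has neighbors in the target cluster, it averages the cosine similarity between the cell's velocity and the displacement toward those neighbors. This is a minimal sketch under those assumptions, not the `sclsd` implementation; `cbdir_for_edge` and the explicit neighbor lists are hypothetical.

```python
import numpy as np


def cbdir_for_edge(X_emb, V_emb, labels, neighbors, source, target):
    """Mean cosine similarity between velocity and displacement toward
    neighbors in `target`, over cells in `source` that have such neighbors.

    Hypothetical helper for illustration only.
    """
    scores = []
    for i in np.flatnonzero(labels == source):
        nbrs = [j for j in neighbors[i] if labels[j] == target]
        if not nbrs:
            continue  # cell i is not on the source/target boundary
        disp = X_emb[nbrs] - X_emb[i]  # (k, d) displacements to boundary neighbors
        denom = np.linalg.norm(disp, axis=1) * np.linalg.norm(V_emb[i])
        cos = (disp @ V_emb[i]) / np.maximum(denom, 1e-12)
        scores.append(cos.mean())
    return float(np.mean(scores)) if scores else float("nan")


# Toy data: two cells in cluster "A", one in cluster "B"
X_toy = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
V_toy = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, -1.0]])
labels_toy = np.array(["A", "B", "A"])
neighbors_toy = [[1, 2], [0], [0, 1]]  # neighbor indices per cell

score_ab = cbdir_for_edge(X_toy, V_toy, labels_toy, neighbors_toy, "A", "B")
print(round(score_ab, 3))
```

A score near 1 means velocities at the boundary point toward the target cluster, a score near 0 means they are unrelated to it, and a negative score means they point away.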

In [None]:
# This notebook expects the projected velocity in result.obsm["velocity_umap"].
# plot_projection stores its result under a different default key unless
# key_added is given, so compute and store it explicitly if it is missing.
if "velocity_umap" not in result.obsm:
    ck.plot_projection(
        basis="umap",
        key_added="velocity_umap",
        recompute=True,
        color="clusters",
    )

print(f"Velocity embedding shape: {result.obsm['velocity_umap'].shape}")

In [None]:
# Compute CBDir scores
cbdir_scores, mean_cbdir = cross_boundary_correctness(
    result,
    k_cluster="clusters",
    k_velocity="velocity",
    cluster_edges=cluster_edges,
    x_emb="X_umap",
)

print("\nCross-Boundary Direction Correctness Scores:")
print("=" * 50)
for edge, score in cbdir_scores.items():
    print(f"  {edge[0]:15} -> {edge[1]:15}: {score:.3f}")
print("=" * 50)
print(f"  Overall Mean CBDir: {mean_cbdir:.3f}")
print("  Expected (from LSD paper): ~0.487")

## Compute In-Cluster Coherence (ICCoh)

ICCoh measures how consistently the velocity vectors of neighboring cells within the same cluster point in the same direction.
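
The sketch below illustrates one way to compute such a score on toy data: for each cell, average the cosine similarity between its velocity and the velocities of its neighbors in the same cluster, then average over the cluster. It is a minimal sketch, not the `sclsd` implementation; `iccoh_for_cluster` and the explicit neighbor lists are hypothetical.

```python
import numpy as np


def iccoh_for_cluster(V, labels, neighbors, cluster):
    """Mean cosine similarity between each cell's velocity and the
    velocities of its same-cluster neighbors.

    Hypothetical helper for illustration only.
    """
    scores = []
    for i in np.flatnonzero(labels == cluster):
        same = [j for j in neighbors[i] if labels[j] == cluster]
        if not same:
            continue  # no neighbors inside the cluster
        denom = np.linalg.norm(V[same], axis=1) * np.linalg.norm(V[i])
        cos = (V[same] @ V[i]) / np.maximum(denom, 1e-12)
        scores.append(cos.mean())
    return float(np.mean(scores)) if scores else float("nan")


# Toy data: cluster "A" has parallel velocities, cluster "B" orthogonal ones
V_toy = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
labels_toy = np.array(["A", "A", "B", "B"])
neighbors_toy = [[1, 2], [0, 3], [3, 0], [2, 1]]  # neighbor indices per cell

coh_a = iccoh_for_cluster(V_toy, labels_toy, neighbors_toy, "A")
coh_b = iccoh_for_cluster(V_toy, labels_toy, neighbors_toy, "B")
print(coh_a, coh_b)
```

Here cluster "A" is perfectly coherent, while the orthogonal velocities in cluster "B" give a coherence of zero.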

In [None]:
# Compute In-Cluster Coherence scores
iccoh_scores, mean_iccoh = inner_cluster_coh(
    result,
    k_cluster="clusters",
    k_velocity="velocity",
)

print("\nIn-Cluster Coherence Scores:")
print("=" * 50)
for cluster, score in iccoh_scores.items():
    print(f"  {cluster:20}: {score:.3f}")
print("=" * 50)
print(f"  Overall Mean ICCoh: {mean_iccoh:.3f}")

## Full Evaluation Report

In [None]:
# Run full evaluation
eval_results = evaluate(
    result,
    cluster_edges=cluster_edges,
    k_cluster="clusters",
    k_velocity="velocity",
    x_emb="X_umap",
    verbose=True,
)

## Visualize Results

In [None]:
# Create figure with multiple panels
fig, axes = plt.subplots(2, 2, figsize=(14, 12))

# Panel 1: Clusters
sc.pl.umap(result, color='clusters', ax=axes[0, 0], show=False, title='Cell Types')

# Panel 2: LSD Pseudotime
sc.pl.umap(result, color='lsd_pseudotime', ax=axes[0, 1], show=False, title='LSD Pseudotime')

# Panel 3: Potential
sc.pl.umap(result, color='potential', ax=axes[1, 0], show=False, title='Waddington Potential')

# Panel 4: Entropy
sc.pl.umap(result, color='entropy', ax=axes[1, 1], show=False, title='Differentiation Entropy')

plt.tight_layout()
plt.savefig(os.path.join(model_dir, "result_summary.png"), dpi=150, bbox_inches='tight')
plt.show()

## Analyze Differentiation State Space

In [None]:
# Plot cells in differentiation state space (B_loc)
fig, ax = plt.subplots(figsize=(8, 6))

# Get B_loc coordinates
B_loc = result.obsm["X_diff_state"]

# Get cluster colors
clusters = result.obs["clusters"].astype("category")
cluster_colors = dict(zip(
    clusters.cat.categories,
    result.uns.get("clusters_colors", plt.cm.tab20.colors[:len(clusters.cat.categories)])
))

# Scatter plot
for cluster in clusters.cat.categories:
    mask = clusters == cluster
    ax.scatter(
        B_loc[mask, 0],
        B_loc[mask, 1],
        c=[cluster_colors[cluster]],
        label=cluster,
        alpha=0.6,
        s=5,
    )

ax.set_xlabel("B1 (Differentiation State)")
ax.set_ylabel("B2 (Differentiation State)")
ax.set_title("Differentiation State Space")
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
plt.tight_layout()
plt.savefig(os.path.join(model_dir, "diff_state_space.png"), dpi=150, bbox_inches='tight')
plt.show()

## Compare with Prior Pseudotime

In [None]:
# Correlation between LSD pseudotime and prior pseudotime
from scipy.stats import pearsonr, spearmanr

prior_time = result.obs["prior_time"].values
lsd_time = result.obs["lsd_pseudotime"].values

pearson_r, pearson_p = pearsonr(prior_time, lsd_time)
spearman_r, spearman_p = spearmanr(prior_time, lsd_time)

print("Correlation between Prior and LSD Pseudotime:")
print(f"  Pearson r:  {pearson_r:.3f} (p={pearson_p:.2e})")
print(f"  Spearman r: {spearman_r:.3f} (p={spearman_p:.2e})")

In [None]:
# Scatter plot of pseudotime comparison
fig, ax = plt.subplots(figsize=(6, 6))

ax.scatter(prior_time, lsd_time, c=result.obs["potential"], cmap='viridis', s=2, alpha=0.5)
ax.set_xlabel("Prior Pseudotime")
ax.set_ylabel("LSD Pseudotime")
ax.set_title(f"Pseudotime Comparison (Spearman r={spearman_r:.3f})")

# Add identity line (meaningful only if both pseudotimes are scaled to [0, 1])
ax.plot([0, 1], [0, 1], 'r--', alpha=0.5, label='Identity')
ax.legend()

plt.tight_layout()
plt.savefig(os.path.join(model_dir, "pseudotime_comparison.png"), dpi=150, bbox_inches='tight')
plt.show()

## Save Evaluation Results

In [None]:
import json

# Save evaluation metrics
# Cast scores to float so NumPy scalar types (e.g. float32) serialize to JSON
eval_summary = {
    "cbdir": {
        "scores": {str(k): float(v) for k, v in cbdir_scores.items()},
        "mean": float(mean_cbdir),
    },
    "iccoh": {
        "scores": {str(k): float(v) for k, v in iccoh_scores.items()},
        "mean": float(mean_iccoh),
    },
    "pseudotime_correlation": {
        "pearson_r": float(pearson_r),
        "spearman_r": float(spearman_r),
    },
}

with open(os.path.join(model_dir, "evaluation_metrics.json"), "w") as f:
    json.dump(eval_summary, f, indent=2)

print(f"Evaluation metrics saved to {model_dir}/evaluation_metrics.json")

## Expected Results Summary

Based on the original LSD paper and LSD-main-branch implementation, the expected results for the Pancreas dataset are:

| Metric (mean or cluster edge) | Expected CBDir |
|-------------------------------|----------------|
| Mean CBDir | ~0.487 |
| (Prlf. Ductal, Ductal) | ~0.172 |
| (Ductal, Ngn3 low) | ~0.381 |
| (Ngn3 low, Ngn3 high) | ~0.305 |
| (Ngn3 high, Fev+) | ~0.398 |
| (Ngn3 high, Epsilon) | ~0.715 |
| (Fev+, Fev+ Alpha) | ~0.603 |
| (Fev+, Fev+ Beta) | ~0.463 |
| (Fev+ Alpha, Alpha) | ~0.617 |
| (Fev+ Beta, Beta) | ~0.583 |
| (Fev+ Delta, Delta) | ~0.633 |

If your results match these values within tolerance, the lsdpy implementation is at parity with the original LSD implementation.
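
The comparison can be sketched as follows. The expected values are copied from the table above; the absolute tolerance of 0.05 is a hypothetical choice (the source does not specify one), `compare_to_expected` is a hypothetical helper, and the observed scores are made-up toy numbers. The sketch assumes scores are keyed by `(source, target)` tuples, as `cbdir_scores` appears to be in this notebook.

```python
# Expected per-edge CBDir values, copied from the table above
expected_cbdir = {
    ("Prlf. Ductal", "Ductal"): 0.172,
    ("Ductal", "Ngn3 low"): 0.381,
    ("Ngn3 low", "Ngn3 high"): 0.305,
    ("Ngn3 high", "Fev+"): 0.398,
    ("Ngn3 high", "Epsilon"): 0.715,
    ("Fev+", "Fev+ Alpha"): 0.603,
    ("Fev+", "Fev+ Beta"): 0.463,
    ("Fev+ Alpha", "Alpha"): 0.617,
    ("Fev+ Beta", "Beta"): 0.583,
    ("Fev+ Delta", "Delta"): 0.633,
}


def compare_to_expected(observed, expected, tol=0.05):
    """Return {edge: (observed, expected, within_tol)} for each expected edge.

    `tol` is a hypothetical absolute tolerance, not a value from the source.
    """
    report = {}
    for edge, exp in expected.items():
        obs = observed.get(edge)
        ok = obs is not None and abs(obs - exp) <= tol
        report[edge] = (obs, exp, ok)
    return report


# Toy observed scores (made up for illustration, not real results)
toy_observed = {
    ("Prlf. Ductal", "Ductal"): 0.180,
    ("Ductal", "Ngn3 low"): 0.500,
}
toy_expected = {edge: expected_cbdir[edge] for edge in toy_observed}
toy_report = compare_to_expected(toy_observed, toy_expected)

for edge, (obs, exp, ok) in toy_report.items():
    print(f"{edge}: observed={obs:.3f} expected={exp:.3f} {'OK' if ok else 'MISMATCH'}")

# Mean of the expected per-edge values
expected_mean = sum(expected_cbdir.values()) / len(expected_cbdir)
print(round(expected_mean, 3))
```

In the notebook itself, `cbdir_scores` would take the place of `toy_observed` and `expected_cbdir` the place of `toy_expected`. The ten per-edge values average to 0.487, consistent with the mean CBDir listed in the table.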