WORKFLOW 06: CellRank Integration for Lineage Analysis
=======================================================

This workflow demonstrates how to integrate PEACH archetypes with CellRank for lineage analysis:
1. Set up CellRank with a ConnectivityKernel (no RNA velocity required)
2. Use archetype assignments as terminal states
3. Compute fate probabilities with the GPCCA estimator
4. Compute lineage pseudotimes
5. Compute lineage drivers (genes whose expression correlates with fate probabilities)
6. Compute transition frequencies between archetypes
7. Visualize fate probabilities in archetypal space

NOTE: This workflow uses the ConnectivityKernel, which does NOT require RNA velocity.
CellRank can work with just a k-NN graph built from PCA coordinates.
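As a rough illustration of what a connectivity-based kernel does, the sketch below builds a toy k-NN graph from random "PCA" coordinates, symmetrizes it, and row-normalizes it into a Markov transition matrix. It is a simplified stand-in: it uses binary k-NN edges, whereas the real kernel works on the weighted connectivities that `sc.pp.neighbors` stores and may apply additional normalization. The helper `knn_transition_matrix` is hypothetical, not a CellRank or PEACH function.

```python
import numpy as np


def knn_transition_matrix(coords: np.ndarray, k: int) -> np.ndarray:
    """Row-stochastic transition matrix from a symmetrized binary k-NN graph.

    Simplified illustration only: real pipelines use weighted connectivities.
    """
    n = coords.shape[0]
    # Pairwise Euclidean distances between cells in the reduced space
    diffs = coords[:, None, :] - coords[None, :, :]
    dist = np.sqrt((diffs ** 2).sum(axis=-1))
    np.fill_diagonal(dist, np.inf)  # exclude self-neighbors

    # Connect each cell to its k nearest neighbors
    adj = np.zeros((n, n))
    nearest = np.argsort(dist, axis=1)[:, :k]
    rows = np.repeat(np.arange(n), k)
    adj[rows, nearest.ravel()] = 1.0

    # Symmetrize: keep an edge if either cell lists the other as a neighbor
    adj = np.maximum(adj, adj.T)

    # Row-normalize so each row is a probability distribution over neighbors
    return adj / adj.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
pca_coords = rng.normal(size=(50, 11))  # 50 toy cells, 11 PCs
T = knn_transition_matrix(pca_coords, k=5)
print(T.shape, np.allclose(T.sum(axis=1), 1.0))
```

Because the symmetrized graph is undirected, this random walk has no built-in direction: with a connectivity-based kernel, the direction of the inferred fates comes from which cells are declared terminal.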

Requirements:
- peach
- scanpy
- cellrank >= 2.0

In [1]:
import scanpy as sc
import numpy as np
import peach as pc
from pathlib import Path



## Configuration

In [4]:
# Data path: pre-trained dataset (helsinki_trained.h5ad) used for demonstration
data_path = Path("data/helsinki_trained.h5ad")
output_dir = Path("tests")
output_dir.mkdir(exist_ok=True)

# Training parameters
n_archetypes = 5
hidden_dims = [256, 128, 64]
n_epochs = 50
seed = 42

## Step 1: Prepare Data with Model and Assignments

In [5]:
print("=" * 70)
print("WORKFLOW 06: CellRank Integration for Lineage Analysis")
print("=" * 70)

print("\nStep 1: Preparing data...")
adata = sc.read_h5ad(data_path)
print(f"  Shape: {adata.n_obs:,} cells x {adata.n_vars:,} genes")

# Ensure PCA exists
pca_key = 'X_pca' if 'X_pca' in adata.obsm else 'X_PCA'
if pca_key not in adata.obsm:
    print("  Running PCA...")
    sc.tl.pca(adata, n_comps=13)
    pca_key = 'X_pca'
print(f"  PCA key: {pca_key}")

# Check if we need to train - require both distances AND weights to skip training
if 'archetype_distances' in adata.obsm and 'cell_archetype_weights' in adata.obsm:
    print("  Using existing archetypal model from data file")
    n_archetypes = adata.obsm['archetype_distances'].shape[1]
else:
    # Train model (needed to extract weights)
    print(f"  Training model ({n_archetypes} archetypes)...")
    results = pc.tl.train_archetypal(
        adata,
        n_archetypes=n_archetypes,
        n_epochs=n_epochs,
        model_config={'hidden_dims': hidden_dims},
        seed=seed,
        device='cpu',
    )
    # Compute archetype coordinates after training
    print("  Computing archetype coordinates...")
    pc.tl.archetypal_coordinates(adata)

    # Extract archetype weights using the trained model
    print("  Extracting archetype weights...")
    pc.tl.extract_archetype_weights(adata, model=results['model'])

# Assign cells to archetypes
if 'archetypes' not in adata.obs:
    print("  Assigning cells to archetypes...")
    pc.tl.assign_archetypes(adata)

print(f"  Archetype distribution:")
print(adata.obs['archetypes'].value_counts().to_string(header=False))

WORKFLOW 06: CellRank Integration for Lineage Analysis

Step 1: Preparing data...
  Shape: 8,806 cells x 32,847 genes
  PCA key: X_pca
  Using existing archetypal model from data file
  Archetype distribution:
no_archetype    3849
archetype_3      869
archetype_5      863
archetype_4      833
archetype_2      831
archetype_1      830
archetype_0      731


## Step 2: Setup CellRank with ConnectivityKernel

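CellRank can work without RNA velocity by using a ConnectivityKernel built from the k-NN graph.
`pc.tl.setup_cellrank` turns the archetype assignments in `adata.obs['archetypes']` into terminal states, keeping only the high-purity cells of each archetype, and then uses the GPCCA estimator to compute every cell's fate probabilities toward those states.
In this run, terminal states are defined for archetype_1 through archetype_5; cells labeled archetype_0 or no_archetype are not used as terminal states.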
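To make the fate-probability step concrete, the sketch below computes absorption probabilities on a toy Markov chain, assuming terminal cells have already been chosen. With terminal cells made absorbing, the probabilities X of ending in each lineage solve (I - Q) X = R, where Q holds transient-to-transient and R transient-to-terminal transitions. CellRank solves this kind of system with iterative solvers such as GMRES on large sparse matrices; the dense `np.linalg.solve` here is only for illustration, and `fate_probabilities` is a hypothetical helper, not the PEACH or CellRank function.

```python
import numpy as np


def fate_probabilities(T: np.ndarray, terminal: dict[str, list[int]]) -> tuple[np.ndarray, list[str]]:
    """Absorption probabilities toward each terminal set of a Markov chain.

    T        : row-stochastic transition matrix (n x n)
    terminal : lineage name -> indices of its terminal cells
    """
    n = T.shape[0]
    names = list(terminal)
    term_idx = np.concatenate([np.asarray(terminal[k]) for k in names])
    trans_idx = np.setdiff1d(np.arange(n), term_idx)

    Q = T[np.ix_(trans_idx, trans_idx)]  # transient -> transient
    # transient -> lineage, summed over that lineage's terminal cells
    R = np.column_stack([T[np.ix_(trans_idx, terminal[k])].sum(axis=1) for k in names])

    # Solve (I - Q) X = R for the absorption probabilities of transient cells
    X = np.linalg.solve(np.eye(len(trans_idx)) - Q, R)

    probs = np.zeros((n, len(names)))
    probs[trans_idx] = X
    for j, k in enumerate(names):
        probs[terminal[k], j] = 1.0  # terminal cells are already committed
    return probs, names


# Toy chain: 5 cells on a line, each stepping to its neighbors with equal probability
T = np.zeros((5, 5))
for i in range(5):
    nbrs = [j for j in (i - 1, i + 1) if 0 <= j < 5]
    T[i, nbrs] = 1.0 / len(nbrs)

probs, names = fate_probabilities(T, {"A": [0], "B": [4]})
print(names)
print(np.round(probs, 2))
```

In the toy chain each row sums to 1 because every walk eventually reaches one of the two terminal sets; the columns play the same role as the lineage columns of `adata.obsm['fate_probabilities']`.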

In [6]:
print("\n" + "-" * 70)
print("Step 2: CellRank Setup (ConnectivityKernel)")
print("-" * 70)

# Compute neighbors if not present
n_pcs = min(11, adata.obsm[pca_key].shape[1])
if 'neighbors' not in adata.uns:
    print(f"  Computing neighbors (n_pcs={n_pcs})...")
    sc.pp.neighbors(adata, n_neighbors=30, n_pcs=n_pcs, use_rep=pca_key)

# Setup CellRank using archetype assignments as terminal states
print("  Setting up CellRank with archetype terminal states...")
ck, g = pc.tl.setup_cellrank(
    adata,
    terminal_obs_key='archetypes',  # Use archetypes as terminal states
    n_neighbors=30,
    n_pcs=n_pcs,
    compute_paga=False,
    tol=1e-7,  # Convergence tolerance for the GMRES solver
    verbose=True
)

print("  CellRank kernel initialized")
print(f"  GPCCA estimator stored in adata.uns['cellrank_gpcca']")
print(f"  Fate probabilities shape: {adata.obsm['fate_probabilities'].shape}")
print(f"  Lineage names: {adata.uns['lineage_names']}")


----------------------------------------------------------------------
Step 2: CellRank Setup (ConnectivityKernel)
----------------------------------------------------------------------
  Computing neighbors (n_pcs=11)...
  Setting up CellRank with archetype terminal states...
CellRank Workflow Setup

1. [OK] Neighbors already computed

2. Computing UMAP...

3. Defining high-purity cells (threshold=0.8)...
archetype_1: 166 high-purity cells (threshold=0.158)
archetype_2: 167 high-purity cells (threshold=0.240)
archetype_3: 174 high-purity cells (threshold=0.663)
archetype_4: 167 high-purity cells (threshold=0.205)
archetype_5: 173 high-purity cells (threshold=0.101)

4. Skipping PAGA (compute_paga=False)

5. Building connectivity kernel...
   [OK] Kernel shape: (8806, 8806)

6. Setting terminal states...
   archetype_1: 166 terminal cells
   archetype_2: 167 terminal cells
   archetype_3: 174 terminal cells
   archetype_4: 167 terminal cells
   archetype_5: 173 terminal cells
   [OK] Total terminal cells: 847

7. Computing fate probabilities (solver='gmres')...
Defaulting to `'gmres'` solver.

   [OK] Fate probabilities: (8806, 5)
   Lineages: ['archetype_1', 'archetype_2', 'archetype_3', 'archetype_4', 'archetype_5']

[OK] CellRank workflow complete

Stored in AnnData:
  adata.obs['terminal_states']
  adata.obsm['fate_probabilities']
  adata.uns['lineage_names']
  CellRank kernel initialized
  GPCCA estimator stored in adata.uns['cellrank_gpcca']
  Fate probabilities shape: (8806, 5)
  Lineage names: ['archetype_1', 'archetype_2', 'archetype_3', 'archetype_4', 'archetype_5']


## Step 3: Compute Lineage Pseudotimes

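For each lineage, `pc.tl.compute_lineage_pseudotimes` stores a column `pseudotime_to_<lineage>` in `adata.obs` that represents each cell's progression toward that terminal state (archetype). In this run, every lineage's pseudotime spans [0, 1].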
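PEACH's exact pseudotime definition is not shown in this workflow. As one illustrative way to build a [0, 1] progression score toward a single lineage, the sketch computes the expected number of random-walk steps needed to reach that lineage's terminal cells (a mean hitting time) and rescales it so terminal cells score 1 and the most distant cell scores 0. The helper `hitting_time_pseudotime` is hypothetical and is not the method behind `pc.tl.compute_lineage_pseudotimes`.

```python
import numpy as np


def hitting_time_pseudotime(T: np.ndarray, target: list[int]) -> np.ndarray:
    """Hypothetical [0, 1] progression score toward a set of target cells.

    The mean hitting time h satisfies h_i = 0 for target cells and
    h_i = 1 + sum_j T_ij h_j otherwise, i.e. (I - T_SS) h_S = 1 on the
    non-target cells S.
    """
    n = T.shape[0]
    others = np.setdiff1d(np.arange(n), target)
    T_ss = T[np.ix_(others, others)]
    h = np.zeros(n)
    h[others] = np.linalg.solve(np.eye(len(others)) - T_ss, np.ones(len(others)))
    # Invert and rescale: target cells -> 1.0, the most distant cell -> 0.0
    return 1.0 - h / h.max()


# Toy chain: 5 cells on a line, each stepping to its neighbors with equal probability
T = np.zeros((5, 5))
for i in range(5):
    nbrs = [j for j in (i - 1, i + 1) if 0 <= j < 5]
    T[i, nbrs] = 1.0 / len(nbrs)

pt = hitting_time_pseudotime(T, target=[4])
print(np.round(pt, 4))
```

A score built this way measures random-walk distance to the lineage's terminal cells rather than developmental time, so it is best read as a relative ordering of cells.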

In [7]:
print("\n" + "-" * 70)
print("Step 3: Lineage Pseudotimes")
print("-" * 70)

print("  Computing pseudotimes for each lineage...")
pc.tl.compute_lineage_pseudotimes(adata)

# Report pseudotime stats
for lineage in adata.uns['lineage_names']:
    pt_key = f'pseudotime_to_{lineage}'
    if pt_key in adata.obs.columns:
        pt_vals = adata.obs[pt_key].dropna()
        print(f"    {lineage}: {len(pt_vals):,} cells, "
              f"range [{pt_vals.min():.3f}, {pt_vals.max():.3f}]")


----------------------------------------------------------------------
Step 3: Lineage Pseudotimes
----------------------------------------------------------------------
  Computing pseudotimes for each lineage...
Computing pseudotimes for 5 lineages...
  archetype_1: stored as 'pseudotime_to_archetype_1'
  archetype_2: stored as 'pseudotime_to_archetype_2'
  archetype_3: stored as 'pseudotime_to_archetype_3'
  archetype_4: stored as 'pseudotime_to_archetype_4'
  archetype_5: stored as 'pseudotime_to_archetype_5'
[OK] Pseudotimes computed
    archetype_1: 8,806 cells, range [0.000, 1.000]
    archetype_2: 8,806 cells, range [0.000, 1.000]
    archetype_3: 8,806 cells, range [0.000, 1.000]
    archetype_4: 8,806 cells, range [0.000, 1.000]
    archetype_5: 8,806 cells, range [0.000, 1.000]


## Step 4: Compute Lineage Drivers

Lineage drivers are genes whose expression correlates with fate probability toward a terminal state.
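A minimal sketch of correlation-based driver ranking, assuming a dense cells x genes expression matrix and a vector of fate probabilities toward one lineage. It computes the Pearson correlation of every gene with the fate probability in one vectorized pass and converts each r to the t statistic t = r * sqrt((n - 2) / (1 - r^2)) used in the usual significance test. The function `rank_drivers` and the toy data are hypothetical; PEACH's implementation may differ, for example in how it handles sparse matrices or ranks negative correlations.

```python
import numpy as np


def rank_drivers(X, fate, genes, n_genes=10):
    """Rank genes by Pearson correlation with a fate-probability vector.

    X     : dense (cells x genes) expression matrix
    fate  : (cells,) fate probabilities toward one lineage
    genes : gene names, one per column of X
    """
    n = X.shape[0]
    Xc = X - X.mean(axis=0)   # center each gene
    fc = fate - fate.mean()   # center the fate probabilities
    denom = np.sqrt((Xc ** 2).sum(axis=0) * (fc ** 2).sum())
    with np.errstate(invalid="ignore", divide="ignore"):
        r = (Xc.T @ fc) / denom
        r = np.clip(np.nan_to_num(r), -1.0, 1.0)  # constant genes -> r = 0
        # t statistic for H0: r = 0, with n - 2 degrees of freedom
        t = r * np.sqrt((n - 2) / (1.0 - r ** 2))
    order = np.argsort(-r)[:n_genes]  # strongest positive correlations first
    return [(genes[i], float(r[i]), float(t[i])) for i in order]


rng = np.random.default_rng(1)
fate = rng.uniform(size=200)
X = np.column_stack([
    2 * fate + rng.normal(scale=0.1, size=200),  # tracks the fate probability
    rng.normal(size=200),                        # unrelated gene
    -fate + rng.normal(scale=0.1, size=200),     # anti-correlated gene
    np.ones(200),                                # constant gene
])
top = rank_drivers(X, fate, ["geneA", "geneB", "geneC", "geneD"], n_genes=4)
for gene, r, t in top:
    print(f"{gene:6} r={r:+.3f} t={t:+.1f}")
```

With n = 8,806 cells, a correlation of about 0.5 already gives t of roughly 54, whose two-sided p-value is far below the smallest positive double (about 1e-308). This is why the driver table below prints p=0.00e+00 for every gene; at this sample size the correlation itself is the more informative ranking criterion.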

In [8]:
print("\n" + "-" * 70)
print("Step 4: Lineage Drivers")
print("-" * 70)

# Pick a target lineage (first archetype)
target_lineage = adata.uns['lineage_names'][0]
print(f"  Computing drivers for lineage: {target_lineage}")

drivers = pc.tl.compute_lineage_drivers(
    adata,
    lineage=target_lineage,
    n_genes=50,
    method='correlation'
)

if drivers is not None and len(drivers) > 0:
    print(f"  Top 10 driver genes for {target_lineage}:")
    for i, row in drivers.head(10).iterrows():
        print(f"    {row['gene']:15} corr={row['correlation']:+.3f} p={row['pvalue']:.2e}")
else:
    print("  No drivers computed (check gene expression data)")


----------------------------------------------------------------------
Step 4: Lineage Drivers
----------------------------------------------------------------------
  Computing drivers for lineage: archetype_1
  Top 10 driver genes for archetype_1:
    RARRES1         corr=+0.626 p=0.00e+00
    RPS27           corr=+0.583 p=0.00e+00
    PDZK1IP1        corr=+0.582 p=0.00e+00
    RPS6            corr=+0.554 p=0.00e+00
    CDKN2A          corr=+0.539 p=0.00e+00
    SOD2            corr=+0.539 p=0.00e+00
    LTF             corr=+0.510 p=0.00e+00
    STXBP6          corr=+0.507 p=0.00e+00
    NFIB            corr=+0.507 p=0.00e+00
    CHI3L1          corr=+0.504 p=0.00e+00


## Step 5: Compute Transition Frequencies

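Transition frequencies summarize how often cells transition between archetypes (row archetype -> column archetype). The values reported below are raw counts, not probabilities; dividing each row by its total converts them to empirical transition probabilities.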
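The sketch below row-normalizes the counts from this run's output, assuming rows are source archetypes and columns are targets, as the "(archetype -> archetype)" label suggests. With the DataFrame returned by `pc.tl.compute_transition_frequencies`, the same operation in pandas would be `transitions.div(transitions.sum(axis=1), axis=0)`.

```python
import numpy as np

labels = ["archetype_1", "archetype_2", "archetype_3", "archetype_4", "archetype_5"]
# Raw counts copied from the Step 5 output (assumed rows: source, columns: target)
counts = np.array([
    [0, 0, 0, 303, 9],
    [0, 0, 0, 21, 349],
    [0, 0, 2118, 0, 0],
    [0, 243, 0, 0, 0],
    [1426, 0, 0, 0, 0],
], dtype=float)

row_totals = counts.sum(axis=1, keepdims=True)
# Guard against empty rows (none here) to avoid division by zero
probs = np.divide(counts, row_totals, out=np.zeros_like(counts), where=row_totals > 0)

for name, row in zip(labels, probs):
    top = int(np.argmax(row))
    print(f"{name} -> {labels[top]}: {row[top]:.3f}")
```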

In [9]:
print("\n" + "-" * 70)
print("Step 5: Transition Frequencies")
print("-" * 70)

transitions = pc.tl.compute_transition_frequencies(adata)
print(f"  Transition matrix shape: {transitions.shape}")
print("  Transition frequencies (archetype -> archetype):")
print(transitions.to_string())


----------------------------------------------------------------------
Step 5: Transition Frequencies
----------------------------------------------------------------------
  Transition matrix shape: (5, 5)
  Transition frequencies (archetype -> archetype):
             archetype_1  archetype_2  archetype_3  archetype_4  archetype_5
archetype_1            0            0            0          303            9
archetype_2            0            0            0           21          349
archetype_3            0            0         2118            0            0
archetype_4            0          243            0            0            0
archetype_5         1426            0            0            0            0


## Step 6: Visualization

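Visualize fate probabilities in the 3D archetypal space. The cell below colors cells by their fate probability toward `target_lineage`; the `pseudotime_to_<lineage>` columns in `adata.obs` can be passed to `color_by` in the same way.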

In [10]:
print("\n" + "-" * 70)
print("Step 6: Visualization")
print("-" * 70)

# Add fate probability as obs column for plotting
fate_col = f'fate_to_{target_lineage}'
# list() also works when lineage_names was reloaded from .h5ad as an array
lineage_idx = list(adata.uns['lineage_names']).index(target_lineage)
adata.obs[fate_col] = adata.obsm['fate_probabilities'][:, lineage_idx]

# 3D archetypal space colored by fate probability
print(f"  Creating 3D archetypal space visualization (colored by {fate_col})...")
fig = pc.pl.archetypal_space(
    adata,
    color_by=fate_col,
    color_scale='viridis',
    title=f'Fate Probability toward {target_lineage}'
)

if fig is not None:
    output_path = output_dir / 'workflow_06_archetypal_space_fate.html'
    fig.write_html(str(output_path))
    print(f"    Saved: {output_path}")


----------------------------------------------------------------------
Step 6: Visualization
----------------------------------------------------------------------
  Creating 3D archetypal space visualization (colored by fate_to_archetype_1)...
    Saved: tests/workflow_06_archetypal_space_fate.html


## Summary

In [11]:
print("\n" + "=" * 70)
print("WORKFLOW 06 COMPLETE")
print("=" * 70)

print("\nKey Outputs:")
print(f"  adata.obs['archetypes']          - Archetype assignments")
print(f"  adata.obsm['fate_probabilities'] - Fate probs ({adata.obsm['fate_probabilities'].shape})")
print(f"  adata.uns['lineage_names']       - {adata.uns['lineage_names']}")
print(f"  adata.uns['cellrank_gpcca']      - GPCCA estimator (for downstream)")

print("\nPseudotime keys in adata.obs:")
pt_keys = [k for k in adata.obs.columns if k.startswith('pseudotime_')]
for key in pt_keys:
    print(f"  {key}")

print("\nCellRank Functions Used:")
print("  pc.tl.setup_cellrank(adata, terminal_obs_key='archetypes')")
print("  pc.tl.compute_lineage_pseudotimes(adata)")
print("  pc.tl.compute_lineage_drivers(adata, lineage=..., method='correlation')")
print("  pc.tl.compute_transition_frequencies(adata)")

print("\nNext workflows:")
print("  08_visualization: Comprehensive Visualization")
print("=" * 70)


WORKFLOW 06 COMPLETE

Key Outputs:
  adata.obs['archetypes']          - Archetype assignments
  adata.obsm['fate_probabilities'] - Fate probs ((8806, 5))
  adata.uns['lineage_names']       - ['archetype_1', 'archetype_2', 'archetype_3', 'archetype_4', 'archetype_5']
  adata.uns['cellrank_gpcca']      - GPCCA estimator (for downstream)

Pseudotime keys in adata.obs:
  pseudotime_to_archetype_1
  pseudotime_to_archetype_2
  pseudotime_to_archetype_3
  pseudotime_to_archetype_4
  pseudotime_to_archetype_5

CellRank Functions Used:
  pc.tl.setup_cellrank(adata, terminal_obs_key='archetypes')
  pc.tl.compute_lineage_pseudotimes(adata)
  pc.tl.compute_lineage_drivers(adata, lineage=..., method='correlation')
  pc.tl.compute_transition_frequencies(adata)

Next workflows:
  08_visualization: Comprehensive Visualization
