# Lab 9D: CellRank 2 — Kernel Combination & Multiview Analysis

**Module 12** — Combining Data Views for Robust Fate Mapping

## Objectives
- Combine multiple CellRank kernels to leverage complementary data views
- Compare single-kernel vs combined-kernel fate probabilities
- Evaluate kernel agreement and identify biologically meaningful disagreements
- Build a complete multiview CellRank 2 workflow

## The Central Idea
Different kernels capture different aspects of cellular dynamics:
- **VelocityKernel**: local RNA splicing dynamics
- **CytoTRACEKernel**: global developmental potential
- **PseudotimeKernel**: computational ordering
- **RealTimeKernel**: experimental time

**Combining** them can produce a more robust estimate of cell fate than any single view, provided each contributing view is reasonably reliable.

## Reference
- Weiler & Theis (2026) *Nature Protocols* — Fig. 5: Kernel combination
- Weiler et al. (2024) "CellRank 2" *Nat. Methods* 21:1196-1205

---

## 1. Setup & Data Preparation

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scanpy as sc
import scvelo as scv
import cellrank as cr
from cellrank.kernels import (
    VelocityKernel,
    ConnectivityKernel,
    CytoTRACEKernel,
    PseudotimeKernel,
)
from cellrank.estimators import GPCCA

sc.settings.set_figure_params(dpi=100, facecolor='white')
cr.settings.verbosity = 2

print(f"scanpy:   {sc.__version__}")
print(f"scvelo:   {scv.__version__}")
print(f"cellrank: {cr.__version__}")

In [None]:
# Load the pancreas dataset with velocity data
adata = scv.datasets.pancreas()

# scVelo preprocessing + velocity
scv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=2000)
scv.pp.moments(adata, n_pcs=30, n_neighbors=30)
scv.tl.recover_dynamics(adata, n_jobs=4)
scv.tl.velocity(adata, mode='dynamical')
scv.tl.velocity_graph(adata)

# Compute DPT for PseudotimeKernel
sc.tl.diffmap(adata, n_comps=15)
# Set root in Ductal cells
ductal_mask = adata.obs['clusters'] == 'Ductal'
if ductal_mask.sum() > 0:
    adata.uns['iroot'] = np.where(ductal_mask)[0][0]
else:
    adata.uns['iroot'] = 0
sc.tl.dpt(adata)

print(f"Cells: {adata.n_obs}, Genes: {adata.n_vars}")
print("Available data: velocity + DPT pseudotime + gene counts (for CytoTRACE)")

## 2. Build Individual Kernels

First, compute each kernel's transition matrix independently.

In [None]:
# Kernel 1: VelocityKernel
vk = VelocityKernel(adata)
vk.compute_transition_matrix()
print(f"VelocityKernel: {vk.transition_matrix.shape}")

# Kernel 2: ConnectivityKernel (transcriptomic similarity)
ck = ConnectivityKernel(adata)
ck.compute_transition_matrix()
print(f"ConnectivityKernel: {ck.transition_matrix.shape}")

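# Kernel 3: CytoTRACEKernel
# compute_cytotrace() must run first: it derives the CytoTRACE score and stores
# adata.obs['ct_pseudotime'], which the transition matrix and the plots in Section 6 rely on
ctk = CytoTRACEKernel(adata).compute_cytotrace()
ctk.compute_transition_matrix()
print(f"CytoTRACEKernel: {ctk.transition_matrix.shape}")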

# Kernel 4: PseudotimeKernel
ptk = PseudotimeKernel(adata, time_key='dpt_pseudotime')
ptk.compute_transition_matrix()
print(f"PseudotimeKernel: {ptk.transition_matrix.shape}")

## 3. Kernel Combination: Weighted Addition

CellRank 2 supports **kernel arithmetic**: you can add kernels with weights.
The combined transition matrix is a weighted sum of individual transition matrices.

```python
combined = w1 * kernel1 + w2 * kernel2 + ...  # weights should sum to 1; CellRank rescales them if they do not
```
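
To make the weighted sum concrete, here is a minimal NumPy sketch on toy matrices rather than CellRank objects. The helper `combine_transition_matrices` is a hypothetical name introduced for illustration; CellRank performs the equivalent step internally when kernels are added. The rescaling of weights is an assumption about how kernel arithmetic treats constants that do not sum to 1, so print the combined kernel to confirm the weights it actually uses.

```python
import numpy as np

def combine_transition_matrices(matrices, weights):
    """Weighted sum of row-stochastic matrices, with weights rescaled to sum to 1."""
    w = np.asarray(weights, dtype=float)
    if np.any(w < 0) or w.sum() == 0:
        raise ValueError("weights must be non-negative and not all zero")
    w = w / w.sum()  # assumption: rescale rather than reject weights that do not sum to 1
    return sum(wi * np.asarray(T, dtype=float) for wi, T in zip(w, matrices))

# Two toy 3-state transition matrices (every row sums to 1)
T_a = np.array([[0.0, 1.0, 0.0],
                [0.0, 0.5, 0.5],
                [0.0, 0.0, 1.0]])
T_b = np.array([[0.5, 0.5, 0.0],
                [0.25, 0.5, 0.25],
                [0.0, 0.5, 0.5]])

T_combined = combine_transition_matrices([T_a, T_b], weights=[0.8, 0.2])
print(T_combined)
print("Row sums:", T_combined.sum(axis=1))
```

Because the rescaled weights are non-negative and sum to 1, every row of the result still sums to 1, so the combination is again a valid transition matrix.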

**How to choose weights?**
- Start with equal weights
- Give higher weight to more reliable data views
- Use biological knowledge and quality metrics

In [None]:
# Combination 1: Velocity + Connectivity (classic CellRank 1 approach)
combo_vc = 0.8 * vk + 0.2 * ck

# Combination 2: Velocity + CytoTRACE (two independent directional signals)
combo_vct = 0.6 * vk + 0.4 * ctk

# Combination 3: Three-way combination
combo_all3 = 0.5 * vk + 0.3 * ctk + 0.2 * ck

print("Three kernel combinations ready:")
print(f"  combo_vc:   0.8*Velocity + 0.2*Connectivity")
print(f"  combo_vct:  0.6*Velocity + 0.4*CytoTRACE")
print(f"  combo_all3: 0.5*Velocity + 0.3*CytoTRACE + 0.2*Connectivity")

## 4. Run GPCCA on Each Combination

Now we run the same downstream analysis (macrostates, terminal states,
fate probabilities) on each combination and compare.
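
Fate probabilities in CellRank are described as absorption probabilities of the Markov chain defined by the transition matrix. As a rough illustration only, the sketch below computes them for a four-state toy chain with NumPy. The function `absorption_probabilities` and the toy matrix are assumptions made for this example. Each terminal state is a single state here, whereas CellRank works with groups of cells and much larger sparse matrices.

```python
import numpy as np

def absorption_probabilities(T, terminal):
    """Probability of ending in each terminal state, starting from every transient state."""
    T = np.asarray(T, dtype=float)
    terminal = np.asarray(terminal)
    transient = np.setdiff1d(np.arange(T.shape[0]), terminal)
    Q = T[np.ix_(transient, transient)]  # transient -> transient block
    R = T[np.ix_(transient, terminal)]   # transient -> terminal block
    # Solve (I - Q) B = R; row i of B holds the fate probabilities of transient[i]
    B = np.linalg.solve(np.eye(len(transient)) - Q, R)
    return transient, B

# Toy chain: states 0 and 3 are terminal (absorbing), states 1 and 2 are transient
T = np.array([[1.0, 0.0, 0.0, 0.0],
              [0.5, 0.0, 0.5, 0.0],
              [0.0, 0.5, 0.0, 0.5],
              [0.0, 0.0, 0.0, 1.0]])

transient, fate = absorption_probabilities(T, terminal=[0, 3])
for state, probs in zip(transient, fate):
    print(f"state {state}: P(fate 0) = {probs[0]:.3f}, P(fate 3) = {probs[1]:.3f}")
```

Each row of `fate` sums to 1, which is why fate probabilities can be read as a soft assignment of every cell to the terminal states.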

In [None]:
def run_cellrank_pipeline(kernel, name, n_states=5):
    """Run full CellRank pipeline on a kernel and return results."""
    g = GPCCA(kernel)
    g.compute_schur(n_components=20)
    g.compute_macrostates(n_states=n_states, cluster_key='clusters')
    g.set_terminal_states()
    g.compute_fate_probabilities()
    
    print(f"\n--- {name} ---")
    print(f"Terminal states: {g.terminal_states.cat.categories.tolist()}")
    print(f"Lineages: {g.fate_probabilities.names.tolist()}")
    
    return g

# Run on individual kernels
g_vk = run_cellrank_pipeline(vk, 'VelocityKernel only')
g_ctk = run_cellrank_pipeline(ctk, 'CytoTRACEKernel only')
g_ptk = run_cellrank_pipeline(ptk, 'PseudotimeKernel only')

In [None]:
# Run on combined kernels
g_vc = run_cellrank_pipeline(combo_vc, 'Velocity + Connectivity')
g_vct = run_cellrank_pipeline(combo_vct, 'Velocity + CytoTRACE')
g_all3 = run_cellrank_pipeline(combo_all3, 'Velocity + CytoTRACE + Connectivity')

## 5. Compare Fate Probabilities Across Approaches

In [None]:
# Side-by-side comparison of fate probabilities for a shared lineage
results = {
    'VelocityKernel': g_vk,
    'CytoTRACEKernel': g_ctk,
    'PseudotimeKernel': g_ptk,
    'Velocity+CytoTRACE': g_vct,
    'Three-way combo': g_all3,
}

# Find common lineages across all approaches
all_lineages = [set(r.fate_probabilities.names.tolist()) for r in results.values()]
common_lineages = set.intersection(*all_lineages) if all_lineages else set()

print(f"Common lineages across all approaches: {common_lineages}")

if common_lineages:
    lineage = sorted(common_lineages)[0]  # sorted for a reproducible choice; set order is arbitrary
    print(f"\nComparing fate probabilities for: {lineage}")
    
    n_approaches = len(results)
    fig, axes = plt.subplots(1, n_approaches, figsize=(5 * n_approaches, 4))
    
    for idx, (name, g_result) in enumerate(results.items()):
        fate_vals = g_result.fate_probabilities[lineage].X.flatten()
        col_name = f'fate_{name.replace(" ", "_").replace("+", "")}'
        adata.obs[col_name] = fate_vals
        sc.pl.umap(adata, color=col_name, ax=axes[idx], show=False,
                   title=name, vmin=0, vmax=1)
    
    plt.tight_layout()
    plt.show()
else:
    print("No common lineages — terminal states differ across approaches.")
    print("This itself is an important finding! Different views see different endpoints.")

In [None]:
# Quantitative comparison: pairwise correlation of fate probabilities
from scipy.stats import spearmanr
import pandas as pd

if common_lineages:
    lineage = sorted(common_lineages)[0]  # sorted for a reproducible choice; set order is arbitrary
    
    # Extract fate probabilities for this lineage from each approach
    fate_dict = {}
    for name, g_result in results.items():
        fate_dict[name] = g_result.fate_probabilities[lineage].X.flatten()
    
    # Pairwise Spearman correlations
    names = list(fate_dict.keys())
    corr_matrix = np.zeros((len(names), len(names)))
    
    for i, n1 in enumerate(names):
        for j, n2 in enumerate(names):
            corr, _ = spearmanr(fate_dict[n1], fate_dict[n2])
            corr_matrix[i, j] = corr
    
    corr_df = pd.DataFrame(corr_matrix, index=names, columns=names)
    
    # Heatmap
    fig, ax = plt.subplots(figsize=(8, 6))
    im = ax.imshow(corr_matrix, cmap='RdBu_r', vmin=-1, vmax=1)
    ax.set_xticks(range(len(names)))
    ax.set_xticklabels(names, rotation=45, ha='right')
    ax.set_yticks(range(len(names)))
    ax.set_yticklabels(names)
    plt.colorbar(im, label='Spearman correlation')
    ax.set_title(f'Fate Probability Correlation ({lineage})')
    
    # Annotate cells
    for i in range(len(names)):
        for j in range(len(names)):
            ax.text(j, i, f'{corr_matrix[i,j]:.2f}', ha='center', va='center',
                    fontsize=9, color='white' if abs(corr_matrix[i,j]) > 0.5 else 'black')
    
    plt.tight_layout()
    plt.show()
    
    print(f"\nCorrelation matrix for lineage '{lineage}':")
    print(corr_df.to_string(float_format='%.3f'))

## 6. Interpreting Kernel (Dis)Agreement

### High agreement between kernels
- **Strengthens confidence** in the trajectory direction and fate assignments
- Different data views converge on the same biology

### Disagreement between kernels
- Can reveal **interesting biology** (e.g., a population that appears terminally differentiated
  by CytoTRACE but has active splicing suggesting further transitions)
- May indicate **data quality issues** in one view (e.g., poor velocity quality)
- May highlight **limitations of a specific assumption** (e.g., CytoTRACE assumption
  fails for some cell types)

### What to report
1. Results from each kernel individually
2. Combined kernel results
3. Agreement/disagreement analysis
4. Biological interpretation of any disagreements
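
One possible way to condense the agreement analysis for a pair of kernels into a few reportable numbers is sketched below. The helper `summarize_agreement`, the synthetic input vectors, and the 0.3 cutoff for calling a cell "discordant" are all assumptions chosen for illustration, not values taken from CellRank or the references.

```python
import numpy as np
from scipy.stats import spearmanr

def summarize_agreement(fate_a, fate_b, threshold=0.3):
    """Rank correlation plus the share of cells whose fate probabilities differ strongly."""
    fate_a = np.asarray(fate_a, dtype=float)
    fate_b = np.asarray(fate_b, dtype=float)
    diff = np.abs(fate_a - fate_b)  # per-cell absolute difference
    rho, _ = spearmanr(fate_a, fate_b)
    return {
        "spearman": float(rho),
        "mean_abs_diff": float(diff.mean()),
        "frac_discordant": float((diff > threshold).mean()),
    }

# Synthetic fate probabilities for five cells under two kernels
fate_kernel_a = np.array([0.9, 0.8, 0.5, 0.2, 0.1])
fate_kernel_b = np.array([0.8, 0.9, 0.1, 0.3, 0.2])

summary = summarize_agreement(fate_kernel_a, fate_kernel_b)
print(summary)
```

Reporting the rank correlation together with the discordant fraction separates global agreement from a small, localized group of cells where the views diverge.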

In [None]:
# Identify cells with high disagreement between two kernels
if common_lineages:
    lineage = sorted(common_lineages)[0]  # sorted for a reproducible choice; set order is arbitrary
    
    fate_vk = g_vk.fate_probabilities[lineage].X.flatten()
    fate_ctk = g_ctk.fate_probabilities[lineage].X.flatten()
    
    # Absolute difference
    disagreement = np.abs(fate_vk - fate_ctk)
    adata.obs['kernel_disagreement'] = disagreement
    
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    sc.pl.umap(adata, color='kernel_disagreement', ax=axes[0], show=False,
               title=f'|VelocityKernel - CytoTRACEKernel|\nfor {lineage}',
               cmap='Reds')
    sc.pl.umap(adata, color='clusters', ax=axes[1], show=False,
               title='Cell Types')
    sc.pl.umap(adata, color='ct_pseudotime', ax=axes[2], show=False,
               title='CytoTRACE Pseudotime')
    plt.tight_layout()
    plt.show()
    
    # Which clusters have the highest disagreement?
    # observed=True skips empty categories and avoids the pandas deprecation warning
    disagree_by_cluster = adata.obs.groupby('clusters', observed=True)['kernel_disagreement'].mean()
    print("Mean kernel disagreement by cluster:")
    print(disagree_by_cluster.sort_values(ascending=False).to_string())

## 7. Complete Multiview Workflow Summary

Here's the recommended workflow for a thorough CellRank 2 analysis:

In [None]:
# ============================================================
# COMPLETE CellRank 2 MULTIVIEW WORKFLOW
# ============================================================

# Step 1: Run each available kernel individually
print("="*60)
print("STEP 1: Individual kernels")
print("="*60)

# VelocityKernel (if velocity data available)
vk = VelocityKernel(adata).compute_transition_matrix()

# CytoTRACEKernel (needs only expression data; compute_cytotrace() must run first)
ctk = CytoTRACEKernel(adata).compute_cytotrace().compute_transition_matrix()

# PseudotimeKernel (if pseudotime computed)
ptk = PseudotimeKernel(adata, time_key='dpt_pseudotime').compute_transition_matrix()

# ConnectivityKernel (always available, for smoothing)
ck = ConnectivityKernel(adata).compute_transition_matrix()

print("All kernels computed.")

# Step 2: Evaluate each kernel's fate mapping
print("\n" + "="*60)
print("STEP 2: Individual kernel fate mapping")
print("="*60)

for kernel, name in [(vk, 'Velocity'), (ctk, 'CytoTRACE'), (ptk, 'Pseudotime')]:
    g_tmp = GPCCA(kernel)
    g_tmp.compute_schur(n_components=15)
    g_tmp.compute_macrostates(n_states=5, cluster_key='clusters')
    g_tmp.set_terminal_states()
    print(f"  {name}: terminal = {g_tmp.terminal_states.cat.categories.tolist()}")

# Step 3: Combine and run final analysis
print("\n" + "="*60)
print("STEP 3: Combined kernel (recommended final analysis)")
print("="*60)

final_kernel = 0.5 * vk + 0.3 * ctk + 0.2 * ck
g_final = GPCCA(final_kernel)
g_final.compute_schur(n_components=20)
g_final.compute_macrostates(n_states=5, cluster_key='clusters')
g_final.set_terminal_states()
g_final.compute_fate_probabilities()

print(f"  Combined terminal: {g_final.terminal_states.cat.categories.tolist()}")
print(f"  Fate probabilities: {g_final.fate_probabilities.shape}")

# Step 4: Driver genes from the combined analysis
print("\n" + "="*60)
print("STEP 4: Driver genes")
print("="*60)

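for lin in g_final.terminal_states.cat.categories:
    # CellRank 2 returns the driver table directly, so no return_drivers flag is passed
    drivers = g_final.compute_lineage_drivers(lineages=lin)
    # Rank genes by their correlation with this lineage's fate probabilities
    top5 = drivers.sort_values(f"{lin}_corr", ascending=False).head(5).index.tolist()
    print(f"  {lin}: {top5}")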

print("\nDone! Full multiview CellRank 2 analysis complete.")

## 8. Exercises

### Exercise 9D.1: Weight Sensitivity Analysis
Create a grid of weight combinations for VelocityKernel + CytoTRACEKernel:
- (1.0, 0.0), (0.8, 0.2), (0.6, 0.4), (0.4, 0.6), (0.2, 0.8), (0.0, 1.0)
Pick one terminal state and plot how its fate probability changes across weights.

### Exercise 9D.2: Kernel Selection Rationale
For each scenario below, recommend which kernel(s) to use and why:
1. Human iPSC → cardiomyocyte differentiation (no velocity data)
2. Mouse hematopoiesis with 4-day time course
3. Tumor samples with no time information and questionable velocity
4. Metabolic labeling experiment in neuronal differentiation

### Exercise 9D.3: Publication-Ready Report
Write a methods section (200 words) describing your multiview CellRank 2 analysis.
Include: which kernels you used, why, how you combined them, and how you validated.

### Exercise 9D.4: Driver Gene Comparison
Compare the top 20 driver genes from:
1. VelocityKernel alone
2. CytoTRACEKernel alone
3. Combined kernel
How much overlap is there? Are the combined driver genes more biologically interpretable?
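
As a starting point for quantifying the overlap, a small helper along the following lines could be used. The function name `top_k_overlap` and the placeholder gene names are hypothetical; in practice the ranked lists would come from the index of each driver table after sorting by correlation.

```python
def top_k_overlap(ranked_a, ranked_b, k=20):
    """Shared genes and Jaccard index between the top-k entries of two ranked gene lists."""
    top_a, top_b = set(ranked_a[:k]), set(ranked_b[:k])
    shared = top_a & top_b
    union = top_a | top_b
    jaccard = len(shared) / len(union) if union else 0.0
    return sorted(shared), jaccard

# Placeholder rankings (most strongly correlated gene first)
velocity_drivers = ["GeneA", "GeneB", "GeneC", "GeneD"]
combined_drivers = ["GeneB", "GeneA", "GeneE", "GeneF"]

shared, jaccard = top_k_overlap(velocity_drivers, combined_drivers, k=3)
print(f"Shared genes: {shared}, Jaccard index: {jaccard:.2f}")
```

A Jaccard index near 1 means the two analyses nominate almost the same genes, while a value near 0 means the driver lists are largely distinct and merit closer inspection.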

---

## Key Takeaways

1. **CellRank 2's key innovation**: combine any data views via kernel arithmetic
2. Combined kernels are often more robust than any single kernel, because errors in one view can be offset by the others
3. Kernel (dis)agreement is informative — both agreement and disagreement are valuable
4. The same downstream API (GPCCA, fate probabilities, driver genes) works for any kernel
5. Always report individual kernel results alongside combined results
6. Weight selection should be guided by data quality and biological knowledge

---

## What's Next?

You've completed the CellRank 2 module! Next:
- **Lab 10**: Trajectory differential expression + CellRank driver genes
- **Lab 11**: Gene expression trend visualization
- **Lab 12**: Method comparison (classical + CellRank approaches)
- **Assignment 4**: Complete multiview CellRank 2 analysis on your chosen dataset