# Lab 7: Trajectory Differential Expression

**Module 7** - Finding Genes That Change Along Pseudotime

## Objectives
- Identify trajectory-associated genes
- Fit smoothed expression models
- Categorize gene patterns
- Visualize gene dynamics


In [None]:
import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats

sc.settings.set_figure_params(dpi=100, facecolor='white')

# Load data with pseudotime
# Prefer the scVelo pancreas dataset; fall back to scanpy's processed PBMC 3k
# data when scVelo is not installed or its dataset cannot be loaded.
# Catch Exception rather than using a bare `except:` so that Ctrl-C
# (KeyboardInterrupt) and SystemExit are not swallowed.
try:
    import scvelo as scv
    adata = scv.datasets.pancreas()
except Exception as err:
    print(f"scVelo pancreas dataset unavailable ({type(err).__name__}: {err}); "
          "falling back to pbmc3k_processed")
    adata = sc.datasets.pbmc3k_processed()

# Compute pseudotime
sc.pp.neighbors(adata)
sc.tl.diffmap(adata)
# NOTE(review): the root is hard-coded to cell 0, which is an arbitrary
# choice; for a real analysis pick a biologically meaningful root cell.
adata.uns['iroot'] = 0
sc.tl.dpt(adata)

print(f"Pseudotime range: {adata.obs['dpt_pseudotime'].min():.3f} - {adata.obs['dpt_pseudotime'].max():.3f}")


In [None]:
# Simple trajectory DE: correlate gene expression with pseudotime
def trajectory_de(adata, n_genes=100):
    """Find genes correlated with pseudotime.

    Computes the Spearman rank correlation between
    ``adata.obs['dpt_pseudotime']`` and each of the first ``n_genes``
    columns of ``adata.X``, using only cells whose pseudotime is finite.

    Parameters
    ----------
    adata
        Object exposing ``obs['dpt_pseudotime']`` (with a ``.values``
        array), ``X`` (dense array or a sparse matrix providing
        ``toarray``) and ``n_vars``.
    n_genes
        Number of leading genes (columns of ``X``) to test; capped at
        ``adata.n_vars``. ``None`` tests every gene.

    Returns
    -------
    tuple of numpy.ndarray
        ``(correlations, pvalues)``; entry ``i`` belongs to column ``i`` of
        ``adata.X``. Entries are NaN where ``spearmanr`` cannot define a
        correlation (for example a gene with constant expression).
    """
    pseudotime = adata.obs['dpt_pseudotime'].values

    # Remove cells with undefined (NaN) or infinite pseudotime
    valid = np.isfinite(pseudotime)
    pt_valid = pseudotime[valid]

    if n_genes is None:
        n_tested = adata.n_vars
    else:
        n_tested = max(0, min(n_genes, adata.n_vars))

    # Slice the tested columns *before* densifying so a large sparse matrix
    # is never expanded for genes that are not looked at.
    X = adata.X[:, :n_tested][valid]
    if hasattr(X, 'toarray'):
        X = X.toarray()

    correlations = np.empty(n_tested)
    pvalues = np.empty(n_tested)
    for i in range(n_tested):
        correlations[i], pvalues[i] = stats.spearmanr(pt_valid, X[:, i])

    return correlations, pvalues

# Run the correlation scan with the default limit (at most the first 100
# genes); entry i of corrs/pvals corresponds to adata.var_names[i].
corrs, pvals = trajectory_de(adata)
print(f"Computed correlations for {len(corrs)} genes")


In [None]:
# Top positively and negatively correlated genes
n_top = 10

# Rank only genes with a defined correlation. Spearman's r is NaN for genes
# with constant expression, and np.argsort places NaN last, so unfiltered
# NaN entries would be reported as the strongest "increasing" genes.
ranked_idx = np.flatnonzero(np.isfinite(corrs))
ranked_idx = ranked_idx[np.argsort(corrs[ranked_idx])]  # ascending by r
top_pos_idx = ranked_idx[::-1][:n_top]
top_neg_idx = ranked_idx[:n_top]

print("Top genes INCREASING along pseudotime:")
for i in top_pos_idx:
    print(f"  {adata.var_names[i]}: r={corrs[i]:.3f}")

print("\nTop genes DECREASING along pseudotime:")
for i in top_neg_idx:
    print(f"  {adata.var_names[i]}: r={corrs[i]:.3f}")


In [None]:
# Scatter expression against pseudotime for the strongest trends:
# top row = three most increasing genes, bottom row = three most decreasing.
fig, axes = plt.subplots(2, 3, figsize=(12, 8))

pt = adata.obs['dpt_pseudotime'].values
for row, gene_indices in enumerate((top_pos_idx[:3], top_neg_idx[:3])):
    for col, idx in enumerate(gene_indices):
        gene = adata.var_names[idx]
        ax = axes[row, col]

        # Pull this gene's column and densify it if it is stored sparse
        gene_X = adata[:, gene].X
        if hasattr(gene_X, 'toarray'):
            gene_X = gene_X.toarray()
        expr = gene_X.flatten()

        ax.scatter(pt, expr, alpha=0.3, s=5)
        ax.set_xlabel('Pseudotime')
        ax.set_ylabel('Expression')
        ax.set_title(f'{gene} (r={corrs[idx]:.3f})')

plt.tight_layout()
plt.show()
