WORKFLOW 05: Gene & Pathway Enrichment Analysis
================================================

This workflow demonstrates how to identify genes and pathways associated with each archetype:
1. Compute gene associations (differential expression per archetype)
2. Compute pathway associations (enrichment analysis)

The gene_associations function returns a DataFrame with columns:
- gene: Gene identifier
- archetype: Which archetype (0, 1, 2, ...)
- n_archetype_cells: Number of cells in this archetype
- n_other_cells: Number of cells in other archetypes
- mean_archetype: Mean expression in archetype cells
- mean_other: Mean expression in other cells
- log_fold_change: Log fold change (archetype vs others)
- statistic: Test statistic
- pvalue: Raw p-value
- fdr_pvalue: FDR-adjusted p-value
- significant: Boolean significance flag
- direction: 'up' or 'down'
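
A table with these columns can be filtered with ordinary pandas operations. The sketch below builds a small stand-in DataFrame (the gene names and values are made up, and the rules used here to derive `significant` and `direction` are assumptions, not necessarily how peach computes them) and keeps the significant, up-regulated genes per archetype.

```python
import pandas as pd

# Toy stand-in for the gene_associations output; all values are invented.
toy = pd.DataFrame({
    "gene": ["G1", "G2", "G3", "G1", "G2", "G3"],
    "archetype": [0, 0, 0, 1, 1, 1],
    "log_fold_change": [1.8, -0.9, 0.2, -1.1, 2.4, 0.1],
    "fdr_pvalue": [1e-6, 3e-4, 0.40, 2e-3, 1e-8, 0.75],
})

# Assumed definitions for illustration only.
toy["significant"] = toy["fdr_pvalue"] < 0.05
toy["direction"] = toy["log_fold_change"].map(lambda x: "up" if x > 0 else "down")

# Significant, up-regulated genes, strongest effect first within each archetype.
up = (
    toy[toy["significant"] & (toy["direction"] == "up")]
    .sort_values(["archetype", "log_fold_change"], ascending=[True, False])
    .reset_index(drop=True)
)
print(up[["gene", "archetype", "log_fold_change", "fdr_pvalue"]])
```

The same filter should carry over to the real `gene_assoc` table as long as its column names match the list above.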

Example usage:
    python WORKFLOW_05.py

Requirements:
    - peach
    - scanpy
    - Trained model with archetype assignments (from WORKFLOW_04)

In [12]:
import scanpy as sc
import peach as pc
from pathlib import Path

## Configuration

In [None]:
# Data path (expanduser() resolves "~", which pathlib does not expand on its own)
data_path = Path("~/data/hsc_10k.h5ad").expanduser()

# Training parameters (from WORKFLOW_03-04)
n_archetypes = 5
hidden_dims = [256, 128, 64]
n_epochs = 50
seed = 42

# Gene association parameters
top_n_genes = 50  # Number of top genes to display per archetype
p_value_threshold = 0.05  # Significance threshold

## Step 1: Prepare Data with Model and Assignments (Prerequisites)

In [14]:
print("Preparing data (loading, training, assigning)...")
adata = sc.read_h5ad(data_path)
print(f"  Shape: {adata.n_obs:,} cells × {adata.n_vars:,} genes")

# Ensure PCA exists
if 'X_pca' not in adata.obsm:
    print("  Running PCA...")
    sc.tl.pca(adata, n_comps=13)

# Train model
print(f"  Training model ({n_archetypes} archetypes)...")
pc.tl.train_archetypal(
    adata,
    n_archetypes=n_archetypes,
    n_epochs=n_epochs,
    model_config={'hidden_dims': hidden_dims},
    seed=seed,
    device='cpu',
)

Preparing data (loading, training, assigning)...
  Shape: 10,000 cells × 2,500 genes
  Training model (5 archetypes)...
[OK] Using specified PCA coordinates: adata.obsm['X_pca'] (10000, 50)
[STATS] DataLoader created: 10000 cells × 50 PCA components
   Config: batch_size=128, workers=0 (Apple Silicon)
Archetypes parameter registered: True
Archetypes requires_grad: True
Deep_AA (Deep Archetypal Analysis) initialized:
  - Single-stage architecture (like Deep_2)
  - Inflation factor: 1.5
  - Direct archetypal coordinates (no bottleneck)
 Initializing with PCHA + inflation_factor=1.5...

 Consolidated Archetype Initialization
   PCHA: True, Inflation: True (factor: 1.5)
   Test inflation: False
Running PCHA initialization...
  Input shape: (1000, 50)
  Target archetypes: 5
Running PCHA with 5 archetypes...
Data shape for PCHA: (50, 1000)
PCHA Results:
  Archetypes shape: (5, 50)
  Archetype R²: 0.4446
  SSE: 182614.4223
  PCHA archetype R²: 0.4446
  Archetype shape: (5, 50)
[OK] Initialize

{'history': {'loss': [5.83913189851785,
   4.326342682295207,
   3.859934900380388,
   3.731078322929672,
   3.707883783533603,
   3.68909149230281,
   3.6781850223299823,
   3.6803502010393747,
   3.686174745801129,
   3.6782661178444007,
   3.696598252163658,
   3.6836404045925866,
   3.681182254718829,
   3.67455922199201,
   3.674242402933821,
   3.661227177970017,
   3.666130929053584,
   3.6734456653836407,
   3.682069048096862,
   3.664573286153093,
   3.6739732555196256,
   3.6662653277191937,
   3.681548402279238,
   3.665349088137663,
   3.6697103343432462,
   3.6785570277443416,
   3.6641807435434077,
   3.6585089828394635,
   3.6643326463578623,
   3.6486643658408635,
   3.6613678600214703,
   3.6678096614306486,
   3.6496482016165044,
   3.660446302800239,
   3.6525057237359544,
   3.648253815083564,
   3.657275966451138,
   3.6610085722766343,
   3.656377348718764,
   3.663751107227953,
   3.670903097225141,
   3.66756794120692,
   3.6595264990118483,
   3.647675016258336

In [15]:
pc.tl.archetypal_coordinates(adata)

 Computing archetype distances in PCA space...
   Canonical reference: adata.obs.index (10000 cells)
   Found PCA coordinates: X_pca (10000, 50)
   Found archetype coordinates: archetype_coordinates (5, 50)
 Computing pairwise distances in PCA space...
[OK] Distance computation complete
   Distance matrix shape: (10000, 5)
[OK] Stored in AnnData:
   adata.obsm['archetype_distances']: (10000, 5) distance matrix
   adata.uns['archetype_positions']: (5, 50) archetype positions
   adata.uns['archetype_distance_info']: distance computation metadata

[STATS] Distance Statistics:
   Nearest archetype distribution:
      Archetype 0: 785 cells (7.8%), mean distance: 15.1040
      Archetype 1: 1486 cells (14.9%), mean distance: 17.0829
      Archetype 2: 3987 cells (39.9%), mean distance: 16.8960
      Archetype 3: 3127 cells (31.3%), mean distance: 15.3948
      Archetype 4: 615 cells (6.2%), mean distance: 26.0079
   Overall statistics:
      Mean nearest distance: 16.8741
      Distance rang

Unnamed: 0,cell_id,cell_idx,archetype_0_distance,archetype_1_distance,archetype_2_distance,archetype_3_distance,archetype_4_distance,nearest_archetype,nearest_distance,mean_distance,std_distance
0,young4_TCAATCTGTACATCCA-1,0,30.868059,34.729964,10.722053,27.453603,33.294571,2,10.722053,27.413650,8.702386
1,Oetjen_A_TATTACCTCTGGAGCC-1,1,34.567422,33.966304,26.612939,14.541995,35.431334,3,14.541995,29.023999,7.896221
2,young4_CACAAACTCAAACCGT-1,2,20.907336,34.237827,18.358518,27.472408,34.075530,2,18.358518,27.010324,6.549316
3,MantonBM4_HiSeq_6-TTGTAGGTCTGTTGAG-1,3,36.648242,16.944750,27.587736,30.461982,34.720264,1,16.944750,29.272595,6.932625
4,BMMC_10x_GREENLEAF_REP2:GTGGTTAGTGTAAACA-1,4,34.135876,35.857042,25.095644,14.220676,34.981425,3,14.220676,28.858133,8.279438
...,...,...,...,...,...,...,...,...,...,...,...
9995,Oetjen_P_CGTCCATTCGCCCTTA-1,9995,32.290439,33.789671,25.598921,11.978496,34.261287,3,11.978496,27.583763,8.399040
9996,elderly2_GAAGCAGCAGCATGAG-1,9996,30.041344,31.902685,13.483154,24.355426,32.360471,2,13.483154,26.428616,7.072302
9997,MantonBM1_HiSeq_5-CGAGAAGAGCAGCCTC-1,9997,33.951706,33.183307,25.140412,13.244904,34.424874,3,13.244904,27.989040,8.117364
9998,BMMC_10x_GREENLEAF_REP1:CCCGAAGAGACGACTG-1,9998,44.436055,22.912607,38.325610,37.708057,44.468478,1,22.912607,37.570161,7.876051


In [16]:
pc.tl.assign_archetypes(adata)

 AnnData-centric archetype binning...
   Distance matrix: (10000, 5) (from adata.obsm['archetype_distances'])
   Canonical cell reference: adata.obs.index (10000 cells)
   Selecting top 1000 cells (10.0%) per archetype
   INCLUDING central archetype_0 (generalist cells)
   Archetype 0 (central): 1000 cells, centroid distance range: [25.4429, 26.7755], mean: 26.3968
   Archetype 1: 1000 cells, distance range: [7.5782, 24.5794], mean: 16.7386
   Archetype 2: 1000 cells, distance range: [9.6427, 18.2831], mean: 14.8333
   Archetype 3: 1000 cells, distance range: [8.0890, 12.9625], mean: 11.5525
   Archetype 4: 1000 cells, distance range: [9.4175, 13.6764], mean: 12.7843
   Archetype 5: 1000 cells, distance range: [18.2666, 29.6810], mean: 26.2502

[STATS] Assignment Summary:
   Total cells: 10000
   Archetype 0 (central): 1000 cells (10.0%)
   Archetype 1: 1000 cells (10.0%)
   Archetype 2: 1000 cells (10.0%)
   Archetype 3: 1000 cells (10.0%)
   Archetype 4: 1000 cells (10.0%)
   Archety

In [17]:
print("\nComputing gene associations...")
print("  (Differential expression per archetype)")

gene_assoc = pc.tl.gene_associations(
    adata,
    obs_key='archetypes',         # Column with archetype assignments
    test_method='mannwhitneyu',   # rank-based (non-parametric) test; or 'wilcoxon'
    fdr_scope='global',           # Global FDR correction
    verbose=True,
)

print(f"\n  Results DataFrame shape: {gene_assoc.shape}")
print(f"  Columns: {list(gene_assoc.columns)}")

# Show top genes for each archetype
print("\nTop genes per archetype:")
for archetype in sorted(gene_assoc['archetype'].unique()):
    arch_genes = gene_assoc[gene_assoc['archetype'] == archetype]

    # Filter by significance
    sig_genes = arch_genes[arch_genes['fdr_pvalue'] < p_value_threshold]

    # Sort by fold change (largest positive log fold change first)
    top_genes = sig_genes.nlargest(5, 'log_fold_change')

    print(f"\n  Archetype {archetype} ({len(sig_genes)} significant genes):")
    if len(top_genes) > 0:
        for _, row in top_genes.iterrows():
            gene = row['gene']
            fc = row['log_fold_change']
            pval = row['fdr_pvalue']
            print(f"    {gene}: logFC={fc:.2f}, fdr_p={pval:.2e}")
    else:
        print("    No significant genes found")


Computing gene associations...
  (Differential expression per archetype)
🧪 Testing archetype-gene associations (AnnData-centric)...
   Method: mannwhitneyu
   FDR correction: benjamini_hochberg (global scope)
   Test direction: two-sided
   Bin proportion: 0.1 (closest cells to each archetype)
   Minimum cells per archetype: 10
   Comparison group: all
   [OK] AnnData validation passed:
      Distance matrix: (10000, 5) from adata.obsm['archetype_distances']
      Archetype assignments: 10000 cells from adata.obs['archetypes']
      Assignment categories: ['archetype_0', 'archetype_1', 'archetype_2', 'archetype_3', 'archetype_4', 'archetype_5', 'no_archetype']
   Using adata.X
   Expression data: 10000 cells × 2500 genes
   Sparse matrix: 95.5% zeros
   Original format: csr
    Using AnnData archetype assignments for binning...
      Found 6 archetype categories: ['archetype_0', 'archetype_1', 'archetype_2', 'archetype_3', 'archetype_4', 'archetype_5']
         no_archetype: 4637 cell

By this point a lot has been added to adata.uns ('archetype_coordinates', 'trained_model', 'archetype_positions', and 'archetype_distance_info'), while the archetype-cell assignments (by default, the closest 10% of cells per archetype) are stored in adata.obs['archetypes'] for convenience.
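
To make the "closest 10%" binning concrete, here is a minimal NumPy sketch that marks, for each archetype, the fraction of cells with the smallest distance to it. The function name `bin_closest_cells`, the random distance matrix, and the handling of overlaps are all assumptions for illustration. This is not peach's implementation, and it does not cover the central `archetype_0` bin seen in the log above.

```python
import numpy as np

def bin_closest_cells(distances, bin_prop=0.1):
    """Return a boolean mask (cells x archetypes) marking the closest cells per archetype."""
    n_cells, n_arch = distances.shape
    n_top = max(1, int(round(bin_prop * n_cells)))
    mask = np.zeros((n_cells, n_arch), dtype=bool)
    for k in range(n_arch):
        # Indices of the n_top cells with the smallest distance to archetype k.
        closest = np.argsort(distances[:, k], kind="stable")[:n_top]
        mask[closest, k] = True
    return mask

# Hypothetical distance matrix: 200 cells, 3 archetypes.
rng = np.random.default_rng(0)
distances = rng.uniform(5.0, 40.0, size=(200, 3))

mask = bin_closest_cells(distances, bin_prop=0.1)
unassigned = ~mask.any(axis=1)  # cells that fall in no archetype bin

print("cells per archetype bin:", mask.sum(axis=0))
print("cells in no bin:", int(unassigned.sum()))
```

In this sketch a cell can land in more than one bin. How such overlaps are resolved in practice is a design choice that the sketch leaves open.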

In [18]:
adata

AnnData object with n_obs × n_vars = 10000 × 2500
    obs: 'AuthorCellType', 'AuthorCellType_Broad', 'Shannon.Diversity.Normalized', 'nCount_RNA', 'nFeature_RNA', 'Study', 'donor_id', 'Sorting', 'S.Score', 'G2M.Score', 'CyclePhase', 'scrublet_scores', 'assay_ontology_term_id', 'development_stage_ontology_term_id', 'disease_ontology_term_id', 'is_primary_data', 'self_reported_ethnicity_ontology_term_id', 'suspension_type', 'tissue_ontology_term_id', 'sex_ontology_term_id', 'cell_type_ontology_term_id', 'Donor_Age_Group', 'tissue_type', 'cell_type', 'assay', 'disease', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', 'archetypes'
    var: 'HCA_Hay2018', 'Oetjen2018', 'Granja2019', 'Mende2022', 'Setty2019', 'Ainciburu2023', 'HVG_intersect3000', 'nCells_Detected', 'nDatasets_Detected', 'gene_symbols', 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'feature_type', 'highly_variable', 'means', 'dispersions

## Step 2: Compute Gene Associations

Note that this uses the one-vs-all Wilcoxon rank-sum (Mann-Whitney U) test previously used in archetype analysis papers, in which each archetype is in turn treated as the 'high' group and compared against all other cells. Because the test rotates through every archetype, the same gene can be significantly associated with multiple archetypes (the same holds for the pathway associations below). This can lead to 3,000-4,000 significant genes per archetype, which is difficult to interpret, so you may want to restrict the genes or pathways you test to subsets of interest. I'm not particularly fond of this test, so future versions may include a more discriminative per-archetype characterization method.
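
One quick way to see how much the archetypes overlap is to count, for each gene, the number of archetypes in which it is significant. The sketch below does this on a made-up table that borrows the `gene`, `archetype`, and `fdr_pvalue` column names. The 0.05 cutoff and the idea of keeping only single-archetype genes are illustrative choices, not a recommendation from peach.

```python
import pandas as pd

# Toy results table; gene names and p-values are invented.
toy = pd.DataFrame({
    "gene": ["G1", "G1", "G2", "G3", "G3", "G4"],
    "archetype": [0, 1, 0, 1, 2, 2],
    "fdr_pvalue": [1e-4, 2e-3, 1e-5, 0.03, 0.20, 1e-6],
})

sig = toy[toy["fdr_pvalue"] < 0.05]

# Number of distinct archetypes in which each gene is significant.
n_hits = sig.groupby("gene")["archetype"].nunique()

shared = sorted(n_hits[n_hits > 1].index)       # significant in several archetypes
specific = sig[sig["gene"].map(n_hits) == 1]    # significant in exactly one archetype

print("shared genes:", shared)
print(specific)
```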

## Step 3: Compute Pathway Associations (Optional)
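
As a generic illustration of what an enrichment test computes (this is not peach's pathway API, which is not covered here), the sketch below runs a hypergeometric over-representation test using only the standard library. The gene universe, the pathway gene set, and the list of significant genes are all hypothetical.

```python
from math import comb

def hypergeom_pvalue(n_universe, n_pathway, n_selected, n_overlap):
    """Upper-tail probability P(X >= n_overlap) for a hypergeometric draw.

    n_universe: number of genes tested
    n_pathway:  number of tested genes that belong to the pathway
    n_selected: number of significant genes
    n_overlap:  number of significant genes that belong to the pathway
    """
    upper = min(n_pathway, n_selected)
    total = comb(n_universe, n_selected)
    tail = sum(
        comb(n_pathway, k) * comb(n_universe - n_pathway, n_selected - k)
        for k in range(n_overlap, upper + 1)
    )
    return tail / total

# Hypothetical inputs: 20 tested genes, a 5-gene pathway, 5 significant genes.
universe = {f"G{i}" for i in range(1, 21)}
pathway = {"G1", "G2", "G3", "G4", "G5"}
selected = {"G1", "G2", "G3", "G4", "G18"}

overlap = len(pathway & selected)
p = hypergeom_pvalue(len(universe), len(pathway), len(selected), overlap)
print(f"overlap={overlap}, p={p:.4g}")
```

When many pathways are tested this way, the resulting p-values would need multiple-testing correction, just as the gene-level p-values do.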

See 08_visualization.ipynb for dotplot and filtering options that help with interpreting these results.

## Step 4: Export Results

In [21]:
print("\nExporting results...")

# Save gene associations to CSV
output_file = "gene_associations.csv"
gene_assoc.to_csv(output_file, index=False)
print(f"  Saved: {output_file}")

# Summary statistics
print(f"\nSummary statistics:")
print(f"  Total genes tested: {gene_assoc['gene'].nunique()}")
print(f"  Archetypes analyzed: {gene_assoc['archetype'].nunique()}")
print(f"  Significant associations (fdr_p < {p_value_threshold}): {(gene_assoc['fdr_pvalue'] < p_value_threshold).sum()}")


Exporting results...
  Saved: gene_associations.csv

Summary statistics:
  Total genes tested: 1397
  Archetypes analyzed: 6
  Significant associations (fdr_p < 0.05): 2803


## Summary

In [22]:
print("\n" + "="*70)
print("WORKFLOW 05 COMPLETE")
print("="*70)
print(f"Gene associations computed:")
print(f"  • DataFrame with {len(gene_assoc):,} gene-archetype associations")
print(f"  • Columns: gene, archetype, log_fold_change, pvalue, fdr_pvalue, significant")
print(f"  • Exported to: {output_file}")
print(f"\nNext workflows:")
print(f"  • WORKFLOW_06: CellRank Integration (requires RNA velocity)")
print(f"  • WORKFLOW_08: Comprehensive Visualization")
print("="*70)


WORKFLOW 05 COMPLETE
Gene associations computed:
  • DataFrame with 6,693 gene-archetype associations
  • Columns: gene, archetype, log_fold_change, pvalue, fdr_pvalue, significant
  • Exported to: gene_associations.csv

Next workflows:
  • WORKFLOW_06: CellRank Integration (requires RNA velocity)
  • WORKFLOW_08: Comprehensive Visualization
