WORKFLOW 04: Archetype Coordinates & Cell Assignment
=====================================================

This workflow demonstrates how to characterize cells in terms of archetypes:
1. Compute archetype coordinates (distances in PCA space)
2. Extract archetype weights (barycentric coordinates from model)
3. Assign cells to nearest archetypes (categorical labels)

CRITICAL DISTINCTION:
- archetype_distances: Euclidean distances from each cell to archetype positions in PCA space
- cell_archetype_weights: Barycentric coordinates from the trained model (sum to 1 per cell)

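These can disagree! The workflow expects distance-based (argmin) and weight-based (argmax)
assignments to differ for roughly 40-60% of cells; in the run recorded below they differ for 98.7%.
Some disagreement is expected and reflects the difference between geometric and learned representations.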
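
A toy illustration of how the two views can disagree, using NumPy only. The archetype positions and the weight vector are made-up values, not output from peach; the cell is placed exactly at the weighted combination of the archetypes so the weights are true barycentric coordinates.

```python
import numpy as np

# Hypothetical 2-D positions for three archetypes (illustrative values only).
archetypes = np.array([[0.0, 0.0],
                       [10.0, 0.0],
                       [0.0, 2.0]])

# Made-up barycentric weights: non-negative and summing to 1.
weights = np.array([0.3, 0.4, 0.3])

# The cell sits at the convex combination of the archetypes: (4.0, 0.6).
cell = weights @ archetypes

# Geometric view: Euclidean distance from the cell to each archetype.
distances = np.linalg.norm(archetypes - cell, axis=1)

nearest_by_distance = int(distances.argmin())  # smallest distance wins
nearest_by_weight = int(weights.argmax())      # largest weight wins

print("distances:", distances.round(3))
print("distance-based:", nearest_by_distance, "| weight-based:", nearest_by_weight)
```

Here the cell is closest to archetype 0 but carries its largest weight on archetype 1, because the simplex is stretched along one axis.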

Example usage:
    python WORKFLOW_04.py

Requirements:
    - peach
    - scanpy
    - Trained model (from WORKFLOW_03)

In [9]:
import scanpy as sc
import peach as pc
from pathlib import Path

## Configuration

In [None]:
# Data path
data_path = Path("~/data/HSC.h5ad").expanduser()  # Path does not expand "~" on its own

# Training parameters (from WORKFLOW_03)
n_archetypes = 5
hidden_dims = [256, 128, 64]
n_epochs = 50
seed = 42

## Step 1: Load Data and Train Model (Prerequisites)

In [11]:
print("Loading data and training model (prerequisite)...")
adata = sc.read_h5ad(data_path)
print(f"  Shape: {adata.n_obs:,} cells × {adata.n_vars:,} genes")

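# Ensure PCA exists (the logged run below reused an existing 50-component X_pca,
# so this fallback with n_comps=13 did not run)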
if 'X_pca' not in adata.obsm:
    print("  Running PCA...")
    sc.tl.pca(adata, n_comps=13)

# Train archetypal model (required for downstream analyses)
print(f"  Training model ({n_archetypes} archetypes)...")
results = pc.tl.train_archetypal(
    adata,
    n_archetypes=n_archetypes,
    n_epochs=n_epochs,
    hidden_dims=hidden_dims,
    early_stopping_patience=10,
    seed=seed,
    device='cpu',
)
print(f"  Training complete!")

# Check that archetype coordinates were stored
if 'archetype_coordinates' in adata.uns:
    arch_coords = adata.uns['archetype_coordinates']
    print(f"  Archetype coordinates: {arch_coords.shape}")
else:
    print("  Warning: archetype_coordinates not found in adata.uns")

Loading data and training model (prerequisite)...
  Shape: 263,159 cells × 28,121 genes
  Training model (5 archetypes)...
[OK] Using specified PCA coordinates: adata.obsm['X_pca'] (263159, 50)
[STATS] DataLoader created: 263159 cells × 50 PCA components
   Config: batch_size=128, workers=0 (Apple Silicon)
Archetypes parameter registered: True
Archetypes requires_grad: True
Deep_AA (Deep Archetypal Analysis) initialized:
  - Single-stage architecture (like Deep_2)
  - Inflation factor: 1.5
  - Direct archetypal coordinates (no bottleneck)
 Initializing with PCHA + inflation_factor=1.5...

 Consolidated Archetype Initialization
   PCHA: True, Inflation: True (factor: 1.5)
   Test inflation: False
Running PCHA initialization...
  Input shape: (1000, 50)
  Target archetypes: 5
Running PCHA with 5 archetypes...
Data shape for PCHA: (50, 1000)
PCHA Results:
  Archetypes shape: (5, 50)
  Archetype R²: 0.4302
  SSE: 192140.1548
  PCHA archetype R²: 0.4302
  Archetype shape: (5, 50)
[OK] Initi

## Step 2: Compute Archetype Distances (Geometric)

In [12]:
print("\nComputing archetype distances...")
print("  (Euclidean distances in PCA space)")

pc.tl.archetypal_coordinates(adata)

# Access the created distances
distances = adata.obsm['archetype_distances']
print(f"  Created: adata.obsm['archetype_distances']")
print(f"  Shape: {distances.shape} (cells × archetypes)")
print(f"  Distance range: [{distances.min():.4f}, {distances.max():.4f}]")

# Find nearest archetype for each cell (distance-based)
nearest_by_distance = distances.argmin(axis=1)
print(f"\n  Distance-based assignment:")
for k in range(n_archetypes):
    print(f"    Archetype {k}: {(nearest_by_distance == k).sum():,} cells")


Computing archetype distances...
  (Euclidean distances in PCA space)
 Computing archetype distances in PCA space...
   Canonical reference: adata.obs.index (263159 cells)
   Found PCA coordinates: X_pca (263159, 50)
   Found archetype coordinates: archetype_coordinates (5, 50)
 Computing pairwise distances in PCA space...
[OK] Distance computation complete
   Distance matrix shape: (263159, 5)
[OK] Stored in AnnData:
   adata.obsm['archetype_distances']: (263159, 5) distance matrix
   adata.uns['archetype_positions']: (5, 50) archetype positions
   adata.uns['archetype_distance_info']: distance computation metadata

[STATS] Distance Statistics:
   Nearest archetype distribution:
      Archetype 0: 85171 cells (32.4%), mean distance: 14.9793
      Archetype 1: 120843 cells (45.9%), mean distance: 17.0471
      Archetype 2: 39120 cells (14.9%), mean distance: 17.1266
      Archetype 3: 13913 cells (5.3%), mean distance: 20.5289
      Archetype 4: 4112 cells (1.6%), mean distance: 19.55
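
For readers who want to see what a cells × archetypes distance matrix involves, here is a minimal NumPy sketch. It is not peach's implementation: the random arrays stand in for `adata.obsm['X_pca']` and `adata.uns['archetype_coordinates']`, and the summary loop only mimics the shape of the statistics logged above.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-ins for the real arrays (hypothetical sizes, random values).
X_pca = rng.normal(size=(200, 50))      # cells x PCA components
arch_coords = rng.normal(size=(5, 50))  # archetypes x PCA components

# Broadcast to (cells, archetypes, components), then take the norm over components.
diff = X_pca[:, None, :] - arch_coords[None, :, :]
distances = np.linalg.norm(diff, axis=2)  # shape (200, 5)

# Nearest archetype per cell and a per-archetype summary.
nearest = distances.argmin(axis=1)
counts = np.bincount(nearest, minlength=arch_coords.shape[0])
for k, n in enumerate(counts):
    mean_d = distances[nearest == k, k].mean() if n else float("nan")
    print(f"Archetype {k}: {n} cells ({n / len(nearest):.1%}), mean distance: {mean_d:.4f}")
```

The broadcast builds a cells × archetypes × components intermediate, so for a dataset the size of the one above it would be worth processing cells in chunks.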

## Step 3: Extract Archetype Weights (Model-based)

In [13]:
print("\nExtracting archetype weights...")
print("  (Barycentric coordinates from trained model)")

weights = pc.tl.extract_archetype_weights(adata)

print(f"  Returned: weights array")
print(f"  Shape: {weights.shape} (cells × archetypes)")
print(f"  Weights sum to 1 per cell: {weights.sum(axis=1).mean():.6f}")
print(f"  Weight range: [{weights.min():.4f}, {weights.max():.4f}]")

# Also stored in adata
if 'cell_archetype_weights' in adata.obsm:
    stored_weights = adata.obsm['cell_archetype_weights']
    print(f"  Also stored: adata.obsm['cell_archetype_weights']")
    print(f"  Shape: {stored_weights.shape}")

# Find dominant archetype for each cell (weight-based)
nearest_by_weight = weights.argmax(axis=1)
print(f"\n  Weight-based assignment:")
for k in range(n_archetypes):
    print(f"    Archetype {k}: {(nearest_by_weight == k).sum():,} cells")


Extracting archetype weights...
  (Barycentric coordinates from trained model)
[STATS] Using model from adata.uns['trained_model']
[STATS] Extracting weights for 263159 cells...
   PCA shape: (263159, 50)
   Device: cpu
   Batch size: 256
   Processed 32000/263159 cells...
   Processed 64000/263159 cells...
   Processed 96000/263159 cells...
   Processed 128000/263159 cells...
   Processed 160000/263159 cells...
   Processed 192000/263159 cells...
   Processed 224000/263159 cells...
   Processed 256000/263159 cells...
   Processed 263159/263159 cells...
[OK] Stored cell weights in adata.obsm['cell_archetype_weights']
   Shape: (263159, 5)
   Range: [0.000, 0.946]
   Mean sum: 1.0000

[STATS] Archetype weight statistics:
   Archetype 0: mean=0.221, std=0.151, max=0.854, dominant in 21378 cells
   Archetype 1: mean=0.309, std=0.228, max=0.946, dominant in 77577 cells
   Archetype 2: mean=0.245, std=0.196, max=0.778, dominant in 46663 cells
   Archetype 3: mean=0.050, std=0.054, max=0.42
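
The sketch below shows what "barycentric coordinates" means numerically. It assumes the weights come from a softmax over one output per archetype, which is a common construction but is not confirmed here as what peach does; the logits and archetype positions are random placeholders.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax; subtracting the row max keeps exp() numerically stable."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
logits = rng.normal(size=(1000, 5))  # hypothetical encoder outputs, one per archetype
weights = softmax(logits)            # cells x archetypes

# Check the simplex constraints per cell rather than on the mean of the row sums,
# since a mean close to 1 can hide individual rows that are off.
row_sums = weights.sum(axis=1)
print("all rows sum to 1:", np.allclose(row_sums, 1.0))
print("all weights non-negative:", bool((weights >= 0).all()))

# Each cell is reconstructed as a convex combination of archetype positions.
archetypes = rng.normal(size=(5, 50))  # placeholder archetype coordinates
reconstruction = weights @ archetypes  # cells x PCA components

dominant = weights.argmax(axis=1)
print("dominant counts:", np.bincount(dominant, minlength=5))
```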

## Step 4: Assign Cells to Archetypes (Categorical)

In [14]:
print("\nAssigning cells to archetypes...")
print("  (Creates categorical labels from archetype distances)")

# Per the log below, this call bins cells using adata.obsm['archetype_distances']
# (top 10% of cells per archetype, plus a central archetype_0), not the dominant weight.
pc.tl.assign_archetypes(adata)

# Access the categorical assignments
if 'archetypes' in adata.obs.columns:
    assignments = adata.obs['archetypes']
    print(f"  Created: adata.obs['archetypes']")
    print(f"  Type: {assignments.dtype}")
    print(f"  Categories: {list(assignments.cat.categories)}")

    # Count cells per archetype
    print(f"\n  Cell counts per archetype:")
    for cat in assignments.cat.categories:
        count = (assignments == cat).sum()
        print(f"    {cat}: {count:,} cells")


Assigning cells to archetypes...
  (Creates categorical labels from archetype distances)
 AnnData-centric archetype binning...
   Distance matrix: (263159, 5) (from adata.obsm['archetype_distances'])
   Canonical cell reference: adata.obs.index (263159 cells)
   Selecting top 26315 cells (10.0%) per archetype
   INCLUDING central archetype_0 (generalist cells)
   Archetype 0 (central): 26315 cells, centroid distance range: [27.2546, 29.0142], mean: 28.6545
   Archetype 1: 26315 cells, distance range: [7.8995, 12.2935], mean: 11.3832
   Archetype 2: 26315 cells, distance range: [9.0897, 13.6573], mean: 12.5763
   Archetype 3: 26315 cells, distance range: [7.3005, 18.5922], mean: 14.6817
   Archetype 4: 26315 cells, distance range: [10.6362, 28.9255], mean: 23.6167
   Archetype 5: 26315 cells, distance range: [6.6287, 43.9535], mean: 38.1339

[STATS] Assignment Summary:
   Total cells: 263159
   Archetype 0 (central): 26315 cells (10.0%)
   Archetype 1: 26315 cells (10.0%)
   Archetype 
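
The log above reports selecting the top 10% of cells per archetype by distance. A minimal sketch of that kind of binning follows. The label names, the `no_archetype` fallback, and the rule that a cell claimed by several archetypes goes to the closest one are all assumptions made for illustration, and the central archetype_0 bin is left out.

```python
import numpy as np

def bin_by_distance(distances, fraction=0.10):
    """Label the closest `fraction` of cells per archetype; leave the rest unassigned."""
    n_cells, n_archetypes = distances.shape
    n_top = int(n_cells * fraction)  # e.g. int(263159 * 0.10) == 26315
    labels = np.full(n_cells, "no_archetype", dtype=object)
    best = np.full(n_cells, np.inf)  # distance behind each cell's current label
    for k in range(n_archetypes):
        top = np.argsort(distances[:, k])[:n_top]  # n_top nearest cells to archetype k
        closer = distances[top, k] < best[top]     # keep only the closest claim per cell
        labels[top[closer]] = f"archetype_{k + 1}"
        best[top[closer]] = distances[top[closer], k]
    return labels

rng = np.random.default_rng(0)
toy_distances = rng.uniform(1.0, 30.0, size=(500, 5))  # placeholder distance matrix
labels = bin_by_distance(toy_distances)

names, counts = np.unique(labels.astype(str), return_counts=True)
for name, count in zip(names, counts):
    print(f"{name}: {count} cells")
```

With this rule most cells stay unassigned, which is the main practical difference from an argmin or argmax assignment that labels every cell.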

## Step 5: Compare Distance vs Weight Assignments

In [15]:
print("\nComparing distance-based vs weight-based assignments...")

# How often do they agree?
agreement = (nearest_by_distance == nearest_by_weight).sum()
total = len(nearest_by_distance)
agreement_pct = agreement / total * 100

print(f"  Cells where distance and weight agree: {agreement:,} / {total:,} ({agreement_pct:.1f}%)")
print(f"  Cells where they disagree: {total - agreement:,} ({100-agreement_pct:.1f}%)")
print(f"\n  Note: Disagreement is EXPECTED (~40-60% of cells).")
print(f"        Distance is geometric, weights are learned.")


Comparing distance-based vs weight-based assignments...
  Cells where distance and weight agree: 3,344 / 263,159 (1.3%)
  Cells where they disagree: 259,815 (98.7%)

  Note: Disagreement is EXPECTED (~40-60% of cells).
        Distance is geometric, weights are learned.
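
A single agreement percentage hides where the two assignments diverge. One way to look closer is to cross-tabulate them, sketched here with NumPy on small made-up label arrays (the helper name `assignment_table` is hypothetical). In the notebook, `nearest_by_distance` and `nearest_by_weight` would take the place of the toy arrays.

```python
import numpy as np

def assignment_table(by_distance, by_weight, n_archetypes):
    """Count cells for every (distance-based, weight-based) label pair."""
    table = np.zeros((n_archetypes, n_archetypes), dtype=int)
    np.add.at(table, (by_distance, by_weight), 1)  # unbuffered add handles repeated pairs
    return table

# Toy labels for eight cells (illustrative values only).
by_distance = np.array([0, 0, 1, 1, 2, 2, 2, 0])
by_weight = np.array([1, 1, 2, 2, 0, 0, 0, 0])

table = assignment_table(by_distance, by_weight, n_archetypes=3)
agreement = np.trace(table) / table.sum()  # the diagonal holds the agreeing cells

print(table)  # rows: distance-based label, columns: weight-based label
print(f"agreement: {agreement:.1%}")
```

If the off-diagonal counts concentrate in one column per row, as in this toy table, the two methods may be using different archetype orderings rather than disagreeing cell by cell; if they are spread out, the disagreement is more likely substantive.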


## Summary

In [16]:
print("\n" + "="*70)
print("WORKFLOW 04 COMPLETE")
print("="*70)
print(f"Key outputs created:")
print(f"  • adata.obsm['archetype_distances'] - Euclidean distances")
print(f"  • adata.obsm['cell_archetype_weights'] - Barycentric coordinates")
print(f"  • adata.obs['archetypes'] - Categorical assignments")
print(f"\nNext workflows:")
print(f"  • WORKFLOW_05: Gene/Pathway Enrichment Analysis")
print(f"  • WORKFLOW_06: CellRank Integration (requires velocity)")
print(f"  • WORKFLOW_08: Visualization")
print("="*70)


WORKFLOW 04 COMPLETE
Key outputs created:
  • adata.obsm['archetype_distances'] - Euclidean distances
  • adata.obsm['cell_archetype_weights'] - Barycentric coordinates
  • adata.obs['archetypes'] - Categorical assignments

Next workflows:
  • WORKFLOW_05: Gene/Pathway Enrichment Analysis
  • WORKFLOW_06: CellRank Integration (requires velocity)
  • WORKFLOW_08: Visualization
