WORKFLOW 02: Hyperparameter Search with Cross-Validation
=========================================================

This workflow demonstrates how to find optimal hyperparameters for archetypal analysis:
1. Load preprocessed data (with PCA from WORKFLOW_01)
2. Configure search space for hyperparameters
3. Run cross-validation grid search
4. Analyze results and select best configuration

The hyperparameter search tests combinations of:
- Number of archetypes (discrete values)
- Hidden layer dimensions (architecture options)
- Inflation factor (PCA inflation for Deep AA)
- CV folds and training settings (fixed across all combinations)

Output structure (CVSummary):
- cv_summary.ranked_configs: List of configs ranked by performance
- cv_summary.summary_df: DataFrame with all results
- cv_summary.config_results: Detailed per-config results

Example usage:
    python WORKFLOW_02.py

Requirements:
    - peach
    - scanpy
    - Data with PCA (from WORKFLOW_01 or equivalent)

In [None]:
import scanpy as sc
import peach as pc
from pathlib import Path

## Configuration

In [None]:
# Data path - should have PCA already computed
data_path = Path("data/HSC.h5ad")

# Hyperparameter search space
n_archetypes_range = [7, 9, 11]  # Number of archetypes to test
hidden_dims_options = [
    [128, 64],           # Simpler architecture
    [256, 128, 64],      # Deeper architecture
]
inflation_factor_range = [1.0, 1.5]  # PCA inflation factor

# Cross-validation settings
cv_folds = 3                    # Number of CV folds
max_epochs_cv = 50              # Max epochs per fold
early_stopping_patience = 5     # Stop if no improvement
speed_preset = 'fast'           # 'fast', 'balanced', or 'thorough'
subsample_fraction = 0.5        # Use 50% of data for CV
max_cells_cv = 5000             # Maximum cells per CV run
random_state = 42               # For reproducibility
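
As a rough sketch of what the settings above imply, the search space can be enumerated as a Cartesian product of the three lists. The count matches the "Total combinations" value printed in Step 2, but the dict keys and the one-fit-per-fold estimate are illustrative assumptions, not PEACH internals.

```python
from itertools import product

# Same search space as the configuration cell above
n_archetypes_range = [7, 9, 11]
hidden_dims_options = [[128, 64], [256, 128, 64]]
inflation_factor_range = [1.0, 1.5]
cv_folds = 3

# One dict per combination (key names are illustrative)
grid = [
    {"n_archetypes": k, "hidden_dims": dims, "inflation_factor": lam}
    for k, dims, lam in product(
        n_archetypes_range, hidden_dims_options, inflation_factor_range
    )
]

n_configs = len(grid)
# Assumption: each configuration is trained once per CV fold
n_fits = n_configs * cv_folds
print(f"{n_configs} configurations, about {n_fits} model fits")
```

Each extra value in any list multiplies the run time, so it is worth checking this count before widening the grid.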

## Step 1: Load Data with PCA

In [None]:
print("Loading data...")
adata = sc.read_h5ad(data_path)
print(f"  Shape: {adata.n_obs:,} cells × {adata.n_vars:,} genes")

# Ensure PCA exists (required for archetypal analysis)
if 'X_pca' not in adata.obsm:
    print("  Running PCA (required for archetypal analysis)...")
    sc.tl.pca(adata, n_comps=13)
    print(f"  PCA computed: {adata.obsm['X_pca'].shape}")
else:
    print(f"  PCA found: {adata.obsm['X_pca'].shape}")

**NB:** In my experience, archetypal analysis generally gives the best results when the PCA uses the smallest number of components that explains >99% of the variance. The extra, very low-variance PCs mostly add noise to archetypal training. For many datasets I have gotten the best results with 5-11 PCs. Use Scanpy's `sc.pl.pca_variance_ratio()` to explore.
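
The rule of thumb above can be sketched as a small helper. This is a minimal illustration, not part of PEACH or Scanpy: `smallest_n_pcs` is a hypothetical function and the ratios below are made-up numbers. On real data the ratios would come from `adata.uns['pca']['variance_ratio']` after `sc.tl.pca`.

```python
import numpy as np

def smallest_n_pcs(variance_ratio, threshold=0.99):
    """Return the smallest number of PCs whose cumulative variance ratio
    reaches `threshold`, or None if the computed PCs never reach it."""
    cumulative = np.cumsum(np.asarray(variance_ratio, dtype=float))
    reached = np.nonzero(cumulative >= threshold)[0]
    if reached.size == 0:
        return None
    return int(reached[0]) + 1  # convert 0-based index to a PC count

# Made-up ratios for illustration only
example_ratios = [0.60, 0.25, 0.10, 0.045, 0.004, 0.001]
n_pcs = smallest_n_pcs(example_ratios)
print(f"PCs needed for 99% variance: {n_pcs}")
```

If the helper returns `None`, the PCs that were computed do not reach the threshold, so either compute more components or pick a cutoff from the variance-ratio plot instead.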

## Step 2: Run Hyperparameter Search

In [4]:
print("\nRunning hyperparameter search...")
print(f"  Configurations to test:")
print(f"    n_archetypes: {n_archetypes_range}")
print(f"    hidden_dims: {len(hidden_dims_options)} options")
print(f"    inflation_factor: {inflation_factor_range}")
print(f"  Total combinations: {len(n_archetypes_range) * len(hidden_dims_options) * len(inflation_factor_range)}")
print(f"  CV folds: {cv_folds}")
print(f"  This may take several minutes...")

cv_summary = pc.tl.hyperparameter_search(
    adata,
    n_archetypes_range=n_archetypes_range,
    hidden_dims_options=hidden_dims_options,
    inflation_factor_range=inflation_factor_range,
    cv_folds=cv_folds,
    max_epochs_cv=max_epochs_cv,
    early_stopping_patience=early_stopping_patience,
    speed_preset=speed_preset,
    subsample_fraction=subsample_fraction,
    max_cells_cv=max_cells_cv,
    random_state=random_state,
    device='cpu',  # Use 'cuda' if GPU available
)

print("  Hyperparameter search complete!")


Epoch 1/1
Average loss: 4.2398
Archetypal loss: 4.2398
KLD loss: 30.8061
Reconstruction loss: 4.2398
Archetype R²: 0.3888
fc_Y grad norm: 0.000000

TRAINING COMPLETED

Final Performance:
  loss: 4.2398 (range: 4.2398 - 4.2398)
  archetypal_loss: 4.2398 (range: 4.2398 - 4.2398)
  archetype_r2: 0.3888 (range: 0.3888 - 0.3888)

fc_Y Learning Summary:
Starting training for 1 epochs...
Device: cpu
Archetypal weight: 1.0, KLD weight: 0.0, Reconstruction weight: 0.0
  (Model configured: arch=1.0, kld=0.0)
Tracking stability: False, Validating constraints: False

Epoch 0 Debug:
z row sums (should be ~1.0): 1.0000 ± 0.0000
z stats: min=0.0000, max=0.9969, mean=0.1111
Batch reconstruction MSE: 4.3782
Archetype stats: min=-11.0884, max=18.2340
Archetype change since last debug: 0.233413
Archetype gradients: norm=0.023383, mean=0.000720

Epoch 1/1
Average loss: 4.1372
Archetypal loss: 4.1372
KLD loss: 24.3915
Reconstruction loss: 4.1372
Archetype R²: 0.4038
fc_Y grad norm: 0.000000

TRAINING C

The parameter values passed here are the defaults; change these or other parameters (e.g., 'archetypal_loss') as needed.
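
The `device` argument in the call above can also be chosen at run time. A minimal sketch, assuming PEACH trains with PyTorch (suggested by the `'cuda'` option but not stated here); `pick_device` is a hypothetical helper, not a PEACH function.

```python
import importlib.util

def pick_device():
    """Return 'cuda' when PyTorch is installed and sees a GPU, else 'cpu'."""
    # Assumption: the training backend is PyTorch
    if importlib.util.find_spec("torch") is None:
        return "cpu"
    import torch
    return "cuda" if torch.cuda.is_available() else "cpu"

device = pick_device()
print(f"Using device: {device}")
```

The result can then be passed as `device=device` in place of the hard-coded `'cpu'`.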

## Step 3: Analyze Results

In [5]:
print("\nAnalyzing results...")

# Access ranked configurations (best to worst)
ranked_configs = cv_summary.ranked_configs

print(f"\nTop 3 configurations:")
for i, config in enumerate(ranked_configs[:3], 1):
    print(f"\n  {i}. Configuration:")
    print(f"     Performance (R²): {config['metric_value']:.4f} ± {config['std_error']:.4f}")
    print(f"     Settings: {config['config_summary']}")
    # Access hyperparameters dict
    hparams = config['hyperparameters']
    print(f"     Details:")
    print(f"       - n_archetypes: {hparams['n_archetypes']}")
    print(f"       - hidden_dims: {hparams['hidden_dims']}")
    print(f"       - inflation_factor: {hparams['inflation_factor']}")

# Get best configuration
best_config = ranked_configs[0]
print(f"\nRecommended configuration:")
print(f"  n_archetypes = {best_config['hyperparameters']['n_archetypes']}")
print(f"  hidden_dims = {best_config['hyperparameters']['hidden_dims']}")
print(f"  inflation_factor = {best_config['hyperparameters']['inflation_factor']}")
print(f"  Expected R² = {best_config['metric_value']:.4f}")


Analyzing results...

Top 3 configurations:

  1. Configuration:
     Performance (R²): 0.6089 ± 0.0057
     Settings: 11 archetypes, [256, 128, 64] hidden dims, λ=1.5
     Details:
       - n_archetypes: 11
       - hidden_dims: [256, 128, 64]
       - inflation_factor: 1.5

  2. Configuration:
     Performance (R²): 0.6027 ± 0.0095
     Settings: 11 archetypes, [128, 64] hidden dims, λ=1.5
     Details:
       - n_archetypes: 11
       - hidden_dims: [128, 64]
       - inflation_factor: 1.5

  3. Configuration:
     Performance (R²): 0.5921 ± 0.0013
     Settings: 9 archetypes, [128, 64] hidden dims, λ=1.5
     Details:
       - n_archetypes: 9
       - hidden_dims: [128, 64]
       - inflation_factor: 1.5

Recommended configuration:
  n_archetypes = 11
  hidden_dims = [256, 128, 64]
  inflation_factor = 1.5
  Expected R² = 0.6089


`best_config = ranked_configs[0]` returns the configuration with the highest $R^2$. This is a convenient way to access the best configuration programmatically when the analysis runs in a script (e.g., as part of a Snakemake workflow). You can also inspect the results visually and choose the best configuration for your dataset with `pc.pl.elbow_curve()`, covered in Step 4.
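
For scripted pipelines, one option is to serialize the winning hyperparameters so a later step can read them. The sketch below uses a hand-written stand-in for `cv_summary.ranked_configs`, filled with the top three results printed above; the JSON hand-off is an illustrative choice, not something PEACH prescribes.

```python
import json

# Stand-in for cv_summary.ranked_configs (values copied from the output above)
ranked_configs = [
    {"metric_value": 0.6089,
     "hyperparameters": {"n_archetypes": 11, "hidden_dims": [256, 128, 64], "inflation_factor": 1.5}},
    {"metric_value": 0.6027,
     "hyperparameters": {"n_archetypes": 11, "hidden_dims": [128, 64], "inflation_factor": 1.5}},
    {"metric_value": 0.5921,
     "hyperparameters": {"n_archetypes": 9, "hidden_dims": [128, 64], "inflation_factor": 1.5}},
]

# The list is already ranked, so the first entry is the best configuration
best_config = ranked_configs[0]

# Serialize the hyperparameters; a pipeline step could write this string to a file
payload = json.dumps(best_config["hyperparameters"], indent=2)

# A downstream step would load it back like this
restored = json.loads(payload)
print(restored)
```

If the real hyperparameter values turn out to be NumPy scalars rather than plain Python numbers, they may need converting (e.g., with `int()` or `float()`) before `json.dumps` accepts them.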

## Step 4: Visualize Results with Elbow Plot

The `pc.pl.elbow_curve()` function creates an interactive visualization showing:
- Performance metrics (R², RMSE) across different numbers of archetypes
- Results for all tested hyperparameter combinations (hidden_dims, inflation_factor)
- Error bars from cross-validation folds
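
To make the elbow idea concrete, the gain in R² per step in the number of archetypes can be computed directly. This is a small worked example, not what `pc.pl.elbow_curve()` does internally. The values are the mean R² for `hidden_dims=[128, 64]` and `inflation_factor=1.5` copied from this notebook's outputs, assuming the ranked `metric_value` and the DataFrame's `mean_archetype_r2` are the same metric.

```python
# Mean R² per number of archetypes (hidden_dims=[128, 64], inflation_factor=1.5)
r2_by_k = {7: 0.5253, 9: 0.5921, 11: 0.6027}

ks = sorted(r2_by_k)

# Gain in R² when moving from each tested k to the next larger one
gains = {
    k_next: r2_by_k[k_next] - r2_by_k[k_prev]
    for k_prev, k_next in zip(ks, ks[1:])
}

for k, gain in gains.items():
    print(f"k={k}: R² gain {gain:+.4f}")
```

Here the step from 7 to 9 archetypes gains about 0.067 in R², while 9 to 11 gains only about 0.011. That flattening is the kind of elbow the plot makes visible.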

In [6]:
print("\nGenerating elbow plot using pc.pl.elbow_curve()...")

# Use PEACH's built-in elbow curve visualization
# Shows multiple metrics across all hyperparameter configurations
fig = pc.pl.elbow_curve(
    cv_summary,
    metrics=["archetype_r2", "mean_val_rmse"],  # Show both R² and RMSE
)

# Display interactive plot
fig.show()

# Display the summary DataFrame to see all hyperparameters explored
print("\nFull hyperparameter search results:")
print(cv_summary.summary_df.to_string())


Generating elbow plot using pc.pl.elbow_curve()...



Full hyperparameter search results:
    n_archetypes     hidden_dims  inflation_factor  use_pcha_init  use_inflation  mean_convergence_epoch  mean_val_archetype_r2  mean_val_mae  mean_archetype_r2  mean_val_rmse  mean_early_stopped  std_convergence_epoch  std_val_archetype_r2  std_val_mae  std_archetype_r2  std_val_rmse  std_early_stopped  training_time  early_stopping_rate
0              7       [128, 64]               1.0           True          False                    25.0               0.523364      1.223734           0.523364       1.800805                 0.0                    0.0              0.006811     0.018873          0.006811      0.026436                0.0      14.967988                  0.0
1              7       [128, 64]               1.5           True           True                    25.0               0.525295      1.216038           0.525295       1.797145                 0.0                    0.0              0.002924     0.003426          0.002924      0.01

## Summary

In [7]:
print("\n" + "="*70)
print("WORKFLOW 02 COMPLETE")
print("="*70)
print(f"Configurations tested: {len(ranked_configs)}")
print(f"Best performance: R² = {best_config['metric_value']:.4f}")
print(f"\nBest hyperparameters:")
print(f"  • n_archetypes: {best_config['hyperparameters']['n_archetypes']}")
print(f"  • hidden_dims: {best_config['hyperparameters']['hidden_dims']}")
print(f"  • inflation_factor: {best_config['hyperparameters']['inflation_factor']}")
print(f"\nKey outputs:")
print(f"  • cv_summary.ranked_configs - Ranked configurations")
print(f"  • cv_summary.summary_df - Full results DataFrame")
print(f"  • pc.pl.elbow_curve() - Interactive visualization")
print("\nNext workflow: WORKFLOW_03 (Model Training with best config)")
print("="*70)


WORKFLOW 02 COMPLETE
Configurations tested: 12
Best performance: R² = 0.6089

Best hyperparameters:
  • n_archetypes: 11
  • hidden_dims: [256, 128, 64]
  • inflation_factor: 1.5

Key outputs:
  • cv_summary.ranked_configs - Ranked configurations
  • cv_summary.summary_df - Full results DataFrame
  • pc.pl.elbow_curve() - Interactive visualization

Next workflow: WORKFLOW_03 (Model Training with best config)
