# Orchard ML: Optuna Model Search on Galaxy10 (GPU)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tomrussobuilds/orchard-ml/blob/main/notebooks/02_galaxy10_optuna_model_search.ipynb)

> **Runtime**: This notebook requires a **GPU runtime**. In Colab: `Runtime > Change runtime type > T4 GPU`

This notebook demonstrates Orchard ML's **automatic hyperparameter optimization with model search** using Optuna:

- **Dataset**: [Galaxy10 DECals](https://zenodo.org/records/10845026) (224x224 RGB, 10 galaxy morphology classes)
- **Model search**: Optuna explores EfficientNet-B0, ConvNeXt-Tiny, ViT-Tiny, and ResNet-18 automatically
- **Time**: ~30-45 minutes on Colab T4 GPU

### What you'll learn
1. How to configure Optuna hyperparameter search with `enable_model_search: true`
2. How Orchard ML runs optimization trials, then trains with the best config
3. How to interpret optimization results (parameter importances, trial history)

## 1. Setup

In [None]:
import os

%cd /content
if not os.path.isdir("orchard-ml"):
    !git clone --depth 1 https://github.com/tomrussobuilds/orchard-ml.git

%cd /content/orchard-ml
!git pull --ff-only
%pip install -q -r requirements.txt

In [None]:
# Verify GPU is available
import torch

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    raise RuntimeError(
        "GPU not detected. This notebook requires a GPU runtime.\n"
        "In Colab: Runtime > Change runtime type > T4 GPU > Save\n"
        "If already set to T4, Colab may be out of GPU quota — try again later."
    )

## 2. Configuration

We create a custom YAML config that enables **model search** — Optuna will explore different architectures
(EfficientNet-B0, ViT-Tiny, ConvNeXt-Tiny, ResNet-18) alongside hyperparameters like learning rate, dropout, and augmentation.

Key settings for Colab:
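- `n_trials: 5`: a small trial budget that fits within Colab runtime limits (it samples the search space rather than covering it; see Next Steps to scale up)
- `epochs: 5` (under `optuna`): short optimization trials; final training uses the full 20 epochs set under `training`
- `enable_model_search: true`: lets Optuna choose the architecture as part of the search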
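
A back-of-the-envelope epoch budget for these settings, assuming no trial is pruned and nothing stops early. It is an upper bound on the number of epochs, not a time estimate, since per-epoch cost varies by architecture.

```python
# Values taken from the key settings listed above.
n_trials = 5        # optuna.n_trials
trial_epochs = 5    # optuna.epochs
final_epochs = 20   # training.epochs

search_epochs = n_trials * trial_epochs      # epochs spent inside Optuna trials
total_epochs = search_epochs + final_epochs  # plus the final training run

print(f"Search phase: up to {search_epochs} epochs")
print(f"Whole pipeline: up to {total_epochs} epochs")
```

Pruning and early stopping can only reduce these numbers.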

In [None]:
%%writefile colab_galaxy10_search.yaml
# Galaxy10 Model Search — Colab-optimized config

dataset:
  name: "galaxy10"
  data_root: ./dataset
  resolution: 224
  force_rgb: true
  use_weighted_sampler: false

architecture:
  name: "efficientnet_b0"        # Default; Optuna will override with model search
  pretrained: true
  dropout: 0.3

training:
  seed: 42
  batch_size: 16
  learning_rate: 0.0001
  weight_decay: 0.0005
  momentum: 0.9
  min_lr: 1e-7
  mixup_alpha: 0.0
  label_smoothing: 0.0
  epochs: 20                      # Final training after optimization
  patience: 10
  grad_clip: 1.0
  mixup_epochs: 0
  scheduler_type: "cosine"
  cosine_fraction: 0.5
  scheduler_patience: 5
  scheduler_factor: 0.1
  step_size: 20
  use_amp: true
  use_tta: false
  criterion_type: "cross_entropy"
  weighted_loss: false
  focal_gamma: 2.0

augmentation:
  hflip: 0.3
  rotation_angle: 3
  jitter_val: 0.05
  min_scale: 0.97
  tta_translate: 0.5
  tta_scale: 1.02
  tta_blur_sigma: 0.1

hardware:
  device: "auto"

telemetry:
  output_dir: ./outputs
  log_level: "INFO"
  log_interval: 50

evaluation:
  batch_size: 32
  n_samples: 12
  fig_dpi: 150
  cmap_confusion: Blues
  plot_style: seaborn-v0_8-muted
  grid_cols: 4
  fig_size_predictions: [12, 8]
  report_format: xlsx
  save_confusion_matrix: true
  save_predictions_grid: true

tracking:
  enabled: false

optuna:
  study_name: "galaxy10_model_search_colab"
  n_trials: 5
  epochs: 5                       # Short trials for speed
  timeout: null
  metric_name: "auc"
  direction: "maximize"
  enable_early_stopping: true
  early_stopping_threshold: 0.9999
  early_stopping_patience: 2
  sampler_type: "tpe"
  search_space_preset: "full"
  enable_model_search: true       # Explore architectures automatically
  enable_pruning: true
  pruner_type: "median"
  pruning_warmup_epochs: 3
  storage_type: "sqlite"
  storage_path: null
  n_jobs: 1
  load_if_exists: true
  show_progress_bar: false
  save_plots: true
  save_best_config: true

export:
  format: onnx
  opset_version: 18
  validate_export: true
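
With `enable_pruning: true` and `pruner_type: "median"`, unpromising trials can be cut short. The sketch below illustrates the median rule in plain Python: once the warm-up period is over, a trial is pruned when its best score so far is worse than the median of the scores earlier trials reported at the same epoch. This is a simplified illustration rather than Optuna's implementation, and reading `pruning_warmup_epochs` as "the number of initial epochs exempt from pruning" is an assumption about how Orchard ML maps the setting. The helper `should_prune` and the score curves are made up for the example.

```python
from statistics import median


def should_prune(history, current, epoch, warmup_epochs=3):
    """Return True if the running trial should stop at `epoch` (0-based).

    history: per-epoch score lists from earlier trials (higher is better).
    current: scores of the running trial up to and including `epoch`.
    """
    if epoch < warmup_epochs:
        return False  # never prune during warm-up
    # Scores that earlier trials reported at this same epoch.
    peers = [scores[epoch] for scores in history if len(scores) > epoch]
    if not peers:
        return False  # nothing to compare against yet
    return max(current[: epoch + 1]) < median(peers)


# Toy AUC curves, invented for illustration only.
earlier_trials = [
    [0.70, 0.78, 0.83, 0.86, 0.88],
    [0.65, 0.75, 0.80, 0.84, 0.85],
    [0.72, 0.80, 0.85, 0.88, 0.90],
]
weak_trial = [0.55, 0.60, 0.64, 0.66]

print(should_prune(earlier_trials, weak_trial, epoch=2))  # still in warm-up
print(should_prune(earlier_trials, weak_trial, epoch=3))  # compared with the median
```

With 5-epoch trials and a 3-epoch warm-up, a rule like this can only act on the last two epochs of a trial, so the time saved per pruned trial is modest.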

## 3. Run Optimization + Training

Orchard ML automatically executes the full pipeline:
1. **Optimization**: 5 Optuna trials, each training for up to 5 epochs (pruning can stop a trial sooner). Optuna explores different architectures and hyperparameters.
2. **Training**: Full 20-epoch training using the best configuration found.
3. **Export**: ONNX model export for deployment.

In [None]:
!python forge.py --config colab_galaxy10_search.yaml

## 4. Explore Results

### 4.1 Optimization artifacts

The optimization phase saves interactive HTML plots and a `best_config.yaml` with the winning hyperparameters to the run directory.

In [None]:
import glob
import os

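# Find the latest run directory (run folders are sorted by name)
run_dirs = sorted(glob.glob("outputs/*/"))
if not run_dirs:
    raise FileNotFoundError("No run directory under outputs/; check that the pipeline completed.")
# Drop the trailing separator so basename() and the indentation below work
latest_run = os.path.normpath(run_dirs[-1])
print(f"Run directory: {latest_run}\n")

# List all generated artifacts as an indented tree
for root, dirs, files in os.walk(latest_run):
    dirs.sort()  # visit subdirectories in a stable order
    level = root.replace(latest_run, "").count(os.sep)
    indent = "  " * level
    print(f"{indent}{os.path.basename(root)}/")
    sub_indent = "  " * (level + 1)
    for file in sorted(files):
        size = os.path.getsize(os.path.join(root, file))
        print(f"{sub_indent}{file} ({size / 1024:.1f} KB)")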

In [None]:
import yaml

# Show the best config found by Optuna
best_configs = glob.glob(f"{latest_run}/reports/best_config*.yaml")
if best_configs:
    with open(best_configs[0]) as f:
        best_cfg = yaml.safe_load(f)
    print("Best configuration found by Optuna:")
    print(yaml.dump(best_cfg, default_flow_style=False, sort_keys=False))
else:
    print("best_config.yaml not found (check outputs directory)")
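
To see at a glance what the search changed, the best config can be compared with the baseline YAML. The helper below is a small sketch that walks two nested dictionaries and reports the leaves that differ. `diff_config` and the sample values are hypothetical. In the notebook you would pass the result of `yaml.safe_load` on `colab_galaxy10_search.yaml` as `base` and `best_cfg` as `best`, assuming both files share the same nested layout.

```python
def diff_config(base, best, prefix=""):
    """Return {dotted_key: (base_value, best_value)} for every differing leaf."""
    changes = {}
    for key in sorted(set(base) | set(best)):
        path = f"{prefix}{key}"
        old, new = base.get(key), best.get(key)
        if isinstance(old, dict) and isinstance(new, dict):
            # Recurse into nested sections such as "training" or "architecture".
            changes.update(diff_config(old, new, prefix=f"{path}."))
        elif old != new:
            changes[path] = (old, new)
    return changes


# Made-up values for illustration; not results from an actual run.
base = {
    "architecture": {"name": "efficientnet_b0", "dropout": 0.3},
    "training": {"learning_rate": 0.0001, "batch_size": 16},
}
best = {
    "architecture": {"name": "resnet_18", "dropout": 0.2},
    "training": {"learning_rate": 0.0003, "batch_size": 16},
}

for key, (old, new) in diff_config(base, best).items():
    print(f"{key}: {old} -> {new}")
```

A key that exists on only one side shows up with `None` for the missing value.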

### 4.2 Optuna visualizations

Orchard ML generates interactive Plotly plots showing parameter importance and optimization history.

In [None]:
from IPython.display import display, HTML

# Display parameter importance plot
importance_files = glob.glob(f"{latest_run}/figures/param_importances*.html")
if importance_files:
    print("Parameter Importances (which hyperparameters matter most):")
    with open(importance_files[0]) as f:
        display(HTML(f.read()))

In [None]:
# Display optimization history
history_files = glob.glob(f"{latest_run}/figures/optimization_history*.html")
if history_files:
    print("Optimization History (AUC across trials):")
    with open(history_files[0]) as f:
        display(HTML(f.read()))

### 4.3 Final training results

After optimization, Orchard ML trains the best model for the full 20 epochs and generates evaluation artifacts.

In [None]:
from IPython.display import display, Image

# Display confusion matrix
cm_files = glob.glob(f"{latest_run}/figures/confusion_matrix*.png")
if cm_files:
    print("Confusion Matrix (final model):")
    display(Image(filename=cm_files[0], width=600))

In [None]:
# Display predictions grid
pred_files = glob.glob(
    f"{latest_run}/figures/sample_predictions*.png"
)
if pred_files:
    print("Sample Predictions (Galaxy morphology):")
    display(Image(filename=pred_files[0], width=700))
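
When reading the confusion matrix, per-class recall (the diagonal entry divided by its row sum) is a quick way to spot which morphology classes the model confuses. Here is a minimal sketch with a made-up 3-class matrix, assuming rows are true labels and columns are predictions. The numbers are invented and the helper `per_class_recall` is hypothetical.

```python
def per_class_recall(matrix):
    """Recall per class for a square confusion matrix (rows = true, cols = predicted)."""
    recalls = []
    for i, row in enumerate(matrix):
        total = sum(row)
        # Guard against classes with no samples in the evaluation set.
        recalls.append(row[i] / total if total else 0.0)
    return recalls


# Toy counts for illustration only.
toy_matrix = [
    [45, 3, 2],
    [5, 40, 5],
    [0, 10, 40],
]

for idx, recall in enumerate(per_class_recall(toy_matrix)):
    print(f"class {idx}: recall = {recall:.2f}")
```

Large off-diagonal entries in a row show which other class absorbs the errors, such as class 2 being mistaken for class 1 in the toy matrix.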

## 5. Next Steps

- **Scale up**: Increase `n_trials` to 20 and the per-trial `epochs` (under `optuna`) to 15 for more thorough optimization (see `recipes/optuna_galaxy10_efficientnet_b0.yaml`)
- **CPU-friendly demo**: See [01_quickstart_bloodmnist_cpu.ipynb](./01_quickstart_bloodmnist_cpu.ipynb) for a quick intro without GPU
- **MedMNIST datasets**: Swap `galaxy10` with any of the MedMNIST datasets (e.g., `pathmnist`, `dermamnist`, `octmnist`)
- **Documentation**: Check the [Optimization Guide](https://github.com/tomrussobuilds/orchard-ml/blob/main/docs/guide/OPTIMIZATION.md) for advanced search space configuration