
feat: add gpu-aware keras/pytorch training runtime #183

Merged
marcellodebernardi merged 7 commits into main from codex/gpu-training-runtime-port
Mar 2, 2026

Conversation

@marcellodebernardi
Contributor

This PR adds GPU-aware neural-network training runtime support, streaming parquet data loading for the Keras/PyTorch templates, canonical task-type propagation, and related predictor/training metadata handling. The aim is to improve the reliability and consistency of neural-network training and inference across CPU and GPU environments while keeping task semantics explicit. It also includes targeted unit tests for training command construction and the parquet streaming utilities, plus follow-up adjustments to use the current Python interpreter and to set the default neural-network epoch count to 10.

Testing

  • poetry run ruff check plexe/execution/training/local_runner.py tests/unit/execution/training/test_local_runner.py plexe/templates/inference/keras_predictor.py plexe/templates/inference/pytorch_predictor.py plexe/config.py plexe/templates/training/train_keras.py plexe/templates/training/train_pytorch.py plexe/tools/submission.py
  • poetry run pytest tests/unit/execution/training/test_local_runner.py tests/unit/utils/test_parquet_dataset.py tests/unit/test_models.py tests/unit/test_config.py tests/unit/test_submission_pytorch.py tests/unit/test_imports.py -v

Copilot AI review requested due to automatic review settings March 2, 2026 17:30
Contributor

Copilot AI left a comment

Pull request overview

This PR introduces GPU-aware training runtime support (PyTorch DDP / Keras multi-GPU + mixed precision), streaming parquet data loaders for large datasets, and canonical task_type propagation so training, inference, and metadata stay consistent across environments.

Changes:

  • Add streaming parquet utilities (PyTorch IterableDataset + TF/Keras generators) and corresponding unit tests.
  • Update local training runner to detect GPUs, choose the right launcher (python vs torch.distributed.run), and pass task/mixed-precision/worker flags through to templates.
  • Propagate canonical task_type through workflow/retrain/submission and update predictors to use metadata-driven post-processing; add GPU Docker build target and adjust NN defaults/timeouts.
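
The launcher selection described above (plain interpreter vs. torch.distributed.run for multi-GPU DDP) can be sketched roughly as follows. This is an illustrative helper, not the actual local_runner.py API; the function name and signature are assumptions:

```python
import sys


def build_launch_command(script: str, gpu_count: int, extra_args: list) -> list:
    """Sketch: pick a launcher based on detected GPU count.

    One process (current interpreter) for CPU or a single GPU;
    torch.distributed.run spawns one process per GPU for DDP when
    more than one GPU is visible.
    """
    if gpu_count > 1:
        return [
            sys.executable, "-m", "torch.distributed.run",
            f"--nproc_per_node={gpu_count}", script, *extra_args,
        ]
    return [sys.executable, script, *extra_args]
```

Using sys.executable (rather than a bare "python") matches the PR's follow-up adjustment to launch with the current interpreter.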

Reviewed changes

Copilot reviewed 26 out of 26 changed files in this pull request and generated 7 comments.

File Description
tests/unit/utils/test_parquet_dataset.py New unit tests for parquet streaming utilities and dataset sharding.
tests/unit/execution/training/test_local_runner.py New unit tests for GPU detection and training command construction.
tests/unit/execution/training/__init__.py Add package marker for test discovery/imports.
tests/unit/execution/__init__.py Add package marker for test discovery/imports.
tests/conftest.py Add synthetic parquet fixtures used by streaming tests.
tests/CODE_INDEX.md Update generated test code index to include new tests/fixtures.
plexe/workflow.py Pass task_type, mixed precision, dataloader worker settings; extend NN retrain timeout.
plexe/utils/parquet_dataset.py Add streaming parquet loaders + metadata helpers (row count/features/steps/size).
plexe/utils/dashboard/utils.py Switch dashboard row-count helper to metadata-based counting (no full read).
plexe/tools/submission.py Validate task_type against canonical enum values and update docstrings.
plexe/templates/training/train_xgboost.py Accept --task-type and write canonical task type into metadata.
plexe/templates/training/train_pytorch.py Rewrite to streaming parquet + optional DDP + mixed precision + richer metadata.
plexe/templates/training/train_lightgbm.py Accept --task-type and write canonical task type into metadata.
plexe/templates/training/train_keras.py Add streaming tf.data pipeline + optional multi-GPU + mixed precision + richer metadata.
plexe/templates/training/train_catboost.py Accept --task-type and write canonical task type into metadata.
plexe/templates/inference/pytorch_predictor.py Load metadata for task-type-driven post-processing; add predict_proba.
plexe/templates/inference/keras_predictor.py Load metadata for task-type-driven post-processing; add predict_proba.
plexe/retrain.py Pass task_type through to training runner for retrains.
plexe/models.py Add canonical TaskType enum for consistent task semantics.
plexe/execution/training/runner.py Extend runner interface to accept canonical task_type.
plexe/execution/training/local_runner.py Add GPU detection + torchrun launch for multi-GPU + pass task/mixed precision/worker flags.
plexe/config.py Add NN timeout/mixed precision/workers; set NN default epochs to 10.
plexe/CODE_INDEX.md Update generated package code index for new/changed APIs.
config.yaml.template Update NN default epochs example to 10.
Makefile Add build-gpu target for CUDA-capable image builds.
Dockerfile Add CPU/GPU base selection and GPU PyTorch install path.

@greptile-apps
Contributor

greptile-apps bot commented Mar 2, 2026

Greptile Summary

This PR adds GPU-aware neural network training with streaming data loading, multi-GPU support (DDP for PyTorch, MirroredStrategy for Keras), and canonical task type propagation across the training pipeline.

Key improvements:

  • Streaming parquet data loading via ParquetIterableDataset handles 100GB+ datasets without OOM
  • Multi-GPU training: automatic torch.distributed.run for PyTorch DDP, MirroredStrategy for Keras
  • Mixed precision (FP16) support with auto-disable on CPU
  • Best model checkpointing to disk for PyTorch, EarlyStopping (patience=3) for Keras
  • Task type propagation eliminates re-inference in templates and ensures consistent predictor behavior
  • macOS spawn mode fallback prevents DataLoader worker hangs
  • Reduced default NN epochs from 25 to 10, added 4-hour timeout for full-dataset training
  • Comprehensive test coverage for GPU detection, command construction, and streaming dataset behavior
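
The double sharding that makes streaming safe under DDP (shard by rank, then by DataLoader worker) can be illustrated with the index arithmetic alone. The helper below is a sketch under assumed names, not the actual ParquetIterableDataset code:

```python
def assign_row_groups(n_row_groups, rank, world_size, worker_id=0, num_workers=1):
    """Sketch: decide which parquet row groups one worker reads.

    First take a strided slice by DDP rank, then slice that rank's
    share across DataLoader workers, so no row group is read twice
    and the union across all ranks/workers covers every group.
    """
    rank_share = list(range(rank, n_row_groups, world_size))
    return rank_share[worker_id::num_workers]
```

With 10 row groups and 2 ranks, rank 0 reads the even groups and rank 1 the odd ones; each rank's two workers then split its share without overlap.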

Architecture:

  • local_runner.py detects GPU count and constructs appropriate launch commands
  • Training templates read streaming data, handle distributed setup, and save task type to metadata
  • Predictors use task type from metadata for post-processing (sigmoid/argmax) instead of loss function inspection
  • Docker GPU variant with CUDA 12.9 runtime
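
The metadata-driven post-processing in the predictors can be sketched as below. This is a minimal illustration of the sigmoid/argmax dispatch on the canonical task type, not the predictor templates' actual code:

```python
import math


def postprocess(outputs, task_type):
    """Sketch: map raw model outputs to predictions by task type.

    Binary: sigmoid on the single logit, threshold at 0.5.
    Multiclass: argmax over per-class scores.
    Regression: pass the value through unchanged.
    """
    if task_type == "binary_classification":
        return [int(1.0 / (1.0 + math.exp(-row[0])) >= 0.5) for row in outputs]
    if task_type == "multiclass_classification":
        return [max(range(len(row)), key=row.__getitem__) for row in outputs]
    return [row[0] for row in outputs]
```

Reading the task type from saved metadata, rather than inspecting the loss function, keeps this dispatch identical across training-time and inference-time environments.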

The implementation is well-tested, follows established patterns, and maintains backward compatibility with CPU training.

Confidence Score: 5/5

  • This PR is safe to merge with high confidence
  • The implementation is well-architected with comprehensive test coverage across GPU detection, command construction, streaming data loading, and edge cases (macOS spawn mode). The streaming dataset design properly handles DDP sharding, worker assignment, and task type-specific dtype handling. Changes follow repository patterns and include proper error handling, fallbacks, and backward compatibility.
  • No files require special attention

Important Files Changed

Filename Overview
plexe/execution/training/local_runner.py Adds GPU detection, multi-GPU command construction (torchrun for PyTorch DDP, MirroredStrategy for Keras), task type propagation, and mixed precision flags to training runner
plexe/templates/training/train_pytorch.py Replaced in-memory data loading with streaming ParquetIterableDataset, added DDP support, mixed precision (FP16), best checkpoint tracking, and task type-aware label handling
plexe/templates/training/train_keras.py Replaced in-memory loading with streaming tf.data.Dataset, added MirroredStrategy for multi-GPU, mixed precision, EarlyStopping with patience=3, and task type propagation
plexe/utils/parquet_dataset.py New streaming parquet dataset implementation with PyTorch IterableDataset, DDP rank sharding, DataLoader worker sharding, and task type-aware dtype handling
plexe/templates/inference/pytorch_predictor.py Added task type-aware post-processing (sigmoid for binary, argmax for multiclass) driven by metadata instead of loss function inspection
plexe/templates/inference/keras_predictor.py Added task type-aware post-processing from metadata (threshold for binary, argmax for multiclass) for consistent predictions
plexe/config.py Changed default NN epochs from 25 to 10, added nn_training_timeout (14400s), mixed_precision (default true), and dataloader_workers (default 4) config fields
plexe/workflow.py Passes task_type, mixed_precision, and dataloader_workers to training runners; uses longer nn_training_timeout for full-dataset neural network retraining

Last reviewed commit: 0536028

Contributor

@greptile-apps greptile-apps bot left a comment

28 files reviewed, 5 comments


@marcellodebernardi
Contributor Author

@greptileai please review again with latest changes

marcellodebernardi force-pushed the codex/gpu-training-runtime-port branch from fe2d57b to 4aa0d89 on March 2, 2026 21:25
@marcellodebernardi
Contributor Author

@greptileai please review again with latest changes

@greptile-apps
Contributor

greptile-apps bot commented Mar 2, 2026

Additional Comments (3)

plexe/templates/training/train_keras.py, line 1025
EarlyStopping patience renders callback inert at default epoch count

patience=10 means the callback only halts training if validation loss fails to improve for 10 consecutive epochs. But nn_default_epochs is now also 10, so with default settings the loop exhausts all epochs before patience can ever fire — the callback is a no-op unless the user explicitly raises epochs well above 10.

Either bring patience in line with the new default, for example:

        keras.callbacks.EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True, verbose=1),

…or document that EarlyStopping is only meaningful when the caller passes epochs > patience (e.g. via nn_max_epochs=50).


plexe/templates/training/train_catboost.py, line 757
Fallback task_type misclassifies multiclass CatBoost models as binary

When task_type is not supplied by the caller, the fallback is:

task_type = "binary_classification" if is_classification else "regression"

But CatBoostClassifier can be trained on problems with more than two classes. In that case the saved metadata will incorrectly record "binary_classification", which could cause downstream predictors to threshold at 0.5 instead of performing an argmax.

The same issue exists in train_xgboost.py (same fallback pattern, line ~1714 in the diff).

A safer fallback would inspect the actual number of classes:

if is_classification:
    n_classes = len(set(y_train.tolist()))
    task_type = "multiclass_classification" if n_classes > 2 else "binary_classification"
else:
    task_type = "regression"

plexe/templates/training/train_pytorch.py, line 1635
Metadata epochs records configured value, not actual epochs run

The Keras template distinguishes between epochs (actual epochs run) and max_epochs (configured ceiling), and this distinction is useful for retrain.py and downstream analysis. The PyTorch metadata only records the configured epochs, losing information about whether training finished early (e.g. due to an exception mid-run or future early-stopping logic).

Consider mirroring the Keras pattern by recording the actual number of completed epochs alongside the configured ceiling:

            "epochs": actual_epochs,  # epochs actually completed
            "max_epochs": epochs,     # configured ceiling

(where actual_epochs = len(history["train_loss"]) is captured after the training loop)

@marcellodebernardi
Contributor Author

Implemented a minimal follow-up addressing the latest Greptile concerns:

  • Reduced Keras EarlyStopping patience from 10 to 3 so it is effective with the default epochs=10.
  • Hardened fallback task_type inference for classifier templates when --task-type is missing:
    • CatBoost, XGBoost, and LightGBM now infer multiclass_classification when class count is > 2, otherwise binary_classification.
  • Kept PyTorch metadata epochs semantics unchanged in this PR to avoid retrain behavior drift; we can add a non-breaking actual_epochs field in a dedicated follow-up if needed.
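
The hardened fallback described in the second bullet amounts to counting distinct classes instead of assuming binary. A minimal sketch, with an assumed helper name rather than the templates' actual code:

```python
def infer_task_type(labels, is_classification):
    """Sketch: fall back to class counting when --task-type is absent.

    Counting distinct label values prevents a 3-class problem from
    being mislabelled as binary_classification, which would make a
    downstream predictor threshold at 0.5 instead of taking argmax.
    """
    if not is_classification:
        return "regression"
    n_classes = len(set(labels))
    return "multiclass_classification" if n_classes > 2 else "binary_classification"
```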

@greptileai please review again with latest changes

marcellodebernardi merged commit d5032d0 into main Mar 2, 2026
13 checks passed
marcellodebernardi deleted the codex/gpu-training-runtime-port branch March 2, 2026 22:20