feat: add gpu-aware keras/pytorch training runtime #183
marcellodebernardi merged 7 commits into main
Conversation
Pull request overview
This PR introduces GPU-aware training runtime support (PyTorch DDP / Keras multi-GPU + mixed precision), streaming parquet data loaders for large datasets, and canonical task_type propagation so training, inference, and metadata stay consistent across environments.
Changes:
- Add streaming parquet utilities (PyTorch `IterableDataset` + TF/Keras generators) and corresponding unit tests.
- Update local training runner to detect GPUs, choose the right launcher (python vs `torch.distributed.run`), and pass task/mixed-precision/worker flags through to templates.
- Propagate canonical `task_type` through workflow/retrain/submission and update predictors to use metadata-driven post-processing; add GPU Docker build target and adjust NN defaults/timeouts.
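The launcher-selection logic described above can be sketched as follows. This is a minimal illustration, not the actual code in `plexe/execution/training/local_runner.py`; `detect_gpu_count` and `build_launch_command` are hypothetical names, and the real runner passes additional mixed-precision and worker flags:

```python
import shutil
import subprocess
import sys


def detect_gpu_count() -> int:
    """Best-effort GPU count via nvidia-smi; returns 0 when unavailable."""
    if shutil.which("nvidia-smi") is None:
        return 0
    try:
        out = subprocess.run(
            ["nvidia-smi", "--list-gpus"], capture_output=True, text=True, timeout=10
        )
        return len([line for line in out.stdout.splitlines() if line.strip()])
    except (subprocess.SubprocessError, OSError):
        return 0


def build_launch_command(script: str, n_gpus: int, task_type: str) -> list:
    """CPU/single-GPU runs use the current interpreter; multi-GPU uses torchrun (DDP)."""
    if n_gpus > 1:
        cmd = [sys.executable, "-m", "torch.distributed.run",
               f"--nproc_per_node={n_gpus}", script]
    else:
        cmd = [sys.executable, script]
    return cmd + ["--task-type", task_type]
```

Using `sys.executable` rather than a bare `python` matches the follow-up fix mentioned later in this thread (use the current Python interpreter).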
Reviewed changes
Copilot reviewed 26 out of 26 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/unit/utils/test_parquet_dataset.py | New unit tests for parquet streaming utilities and dataset sharding. |
| tests/unit/execution/training/test_local_runner.py | New unit tests for GPU detection and training command construction. |
| tests/unit/execution/training/__init__.py | Add package marker for test discovery/imports. |
| tests/unit/execution/__init__.py | Add package marker for test discovery/imports. |
| tests/conftest.py | Add synthetic parquet fixtures used by streaming tests. |
| tests/CODE_INDEX.md | Update generated test code index to include new tests/fixtures. |
| plexe/workflow.py | Pass task_type, mixed precision, dataloader worker settings; extend NN retrain timeout. |
| plexe/utils/parquet_dataset.py | Add streaming parquet loaders + metadata helpers (row count/features/steps/size). |
| plexe/utils/dashboard/utils.py | Switch dashboard row-count helper to metadata-based counting (no full read). |
| plexe/tools/submission.py | Validate task_type against canonical enum values and update docstrings. |
| plexe/templates/training/train_xgboost.py | Accept --task-type and write canonical task type into metadata. |
| plexe/templates/training/train_pytorch.py | Rewrite to streaming parquet + optional DDP + mixed precision + richer metadata. |
| plexe/templates/training/train_lightgbm.py | Accept --task-type and write canonical task type into metadata. |
| plexe/templates/training/train_keras.py | Add streaming tf.data pipeline + optional multi-GPU + mixed precision + richer metadata. |
| plexe/templates/training/train_catboost.py | Accept --task-type and write canonical task type into metadata. |
| plexe/templates/inference/pytorch_predictor.py | Load metadata for task-type-driven post-processing; add predict_proba. |
| plexe/templates/inference/keras_predictor.py | Load metadata for task-type-driven post-processing; add predict_proba. |
| plexe/retrain.py | Pass task_type through to training runner for retrains. |
| plexe/models.py | Add canonical TaskType enum for consistent task semantics. |
| plexe/execution/training/runner.py | Extend runner interface to accept canonical task_type. |
| plexe/execution/training/local_runner.py | Add GPU detection + torchrun launch for multi-GPU + pass task/mixed precision/worker flags. |
| plexe/config.py | Add NN timeout/mixed precision/workers; set NN default epochs to 10. |
| plexe/CODE_INDEX.md | Update generated package code index for new/changed APIs. |
| config.yaml.template | Update NN default epochs example to 10. |
| Makefile | Add build-gpu target for CUDA-capable image builds. |
| Dockerfile | Add CPU/GPU base selection and GPU PyTorch install path. |
Greptile Summary

This PR adds GPU-aware neural network training with streaming data loading, multi-GPU support (DDP for PyTorch, MirroredStrategy for Keras), and canonical task type propagation across the training pipeline.

The implementation is well-tested, follows established patterns, and maintains backward compatibility with CPU training. Confidence Score: 5/5
| Filename | Overview |
|---|---|
| plexe/execution/training/local_runner.py | Adds GPU detection, multi-GPU command construction (torchrun for PyTorch DDP, MirroredStrategy for Keras), task type propagation, and mixed precision flags to training runner |
| plexe/templates/training/train_pytorch.py | Replaced in-memory data loading with streaming ParquetIterableDataset, added DDP support, mixed precision (FP16), best checkpoint tracking, and task type-aware label handling |
| plexe/templates/training/train_keras.py | Replaced in-memory loading with streaming tf.data.Dataset, added MirroredStrategy for multi-GPU, mixed precision, EarlyStopping with patience=3, and task type propagation |
| plexe/utils/parquet_dataset.py | New streaming parquet dataset implementation with PyTorch IterableDataset, DDP rank sharding, DataLoader worker sharding, and task type-aware dtype handling |
| plexe/templates/inference/pytorch_predictor.py | Added task type-aware post-processing (sigmoid for binary, argmax for multiclass) driven by metadata instead of loss function inspection |
| plexe/templates/inference/keras_predictor.py | Added task type-aware post-processing from metadata (threshold for binary, argmax for multiclass) for consistent predictions |
| plexe/config.py | Changed default NN epochs from 25 to 10, added nn_training_timeout (14400s), mixed_precision (default true), and dataloader_workers (default 4) config fields |
| plexe/workflow.py | Passes task_type, mixed_precision, and dataloader_workers to training runners; uses longer nn_training_timeout for full-dataset neural network retraining |
Last reviewed commit: 0536028
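The two-level sharding described for `plexe/utils/parquet_dataset.py` (DDP rank sharding plus DataLoader worker sharding) reduces to a small piece of index arithmetic. The sketch below shows only that arithmetic, independent of torch; `shard_indices` is an illustrative name, and the real `IterableDataset` would derive `rank`/`worker_id` from `torch.distributed` and `torch.utils.data.get_worker_info()`:

```python
from typing import Iterator


def shard_indices(n_batches: int, rank: int, world_size: int,
                  worker_id: int, num_workers: int) -> Iterator[int]:
    """Round-robin sharding across (rank, worker) pairs.

    Each of the world_size * num_workers readers gets a disjoint slice of
    batch indices, so every batch is read exactly once across the job.
    """
    shard_id = rank * num_workers + worker_id
    n_shards = world_size * num_workers
    for i in range(n_batches):
        if i % n_shards == shard_id:
            yield i
```

For example, with 2 ranks and 3 workers each, indices 0..11 split into six disjoint shards of two batches; with a single rank and worker, the generator degenerates to reading everything.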
@greptileai please review again with latest changes
Force-pushed from fe2d57b to 4aa0d89
@greptileai please review again with latest changes
Additional Comments (3)

- The EarlyStopping patience no longer matches the new default epoch count. Either bring patience in line with the new default, or document the mismatch.
- The fallback `task_type = "binary_classification" if is_classification else "regression"` never produces a multiclass label, and the same issue exists in the other training template. A safer fallback would inspect the actual number of classes:

  ```python
  if is_classification:
      n_classes = len(set(y_train.tolist()))
      task_type = "multiclass_classification" if n_classes > 2 else "binary_classification"
  else:
      task_type = "regression"
  ```

- The Keras template distinguishes between binary and multiclass outputs; consider mirroring the Keras pattern here.
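The metadata-driven post-processing that the updated predictors perform (sigmoid plus threshold for binary, argmax for multiclass) can be sketched as below. `postprocess` is a hypothetical name and the shapes are simplified; the actual predictors in `plexe/templates/inference/` read the task type from saved training metadata rather than taking it as an argument:

```python
import math
from typing import List


def postprocess(outputs: List[List[float]], task_type: str,
                threshold: float = 0.5) -> list:
    """Map raw model outputs to predictions based on the canonical task type."""
    if task_type == "binary_classification":
        # single-logit output: sigmoid, then threshold
        return [int(1.0 / (1.0 + math.exp(-row[0])) >= threshold)
                for row in outputs]
    if task_type == "multiclass_classification":
        # one logit per class: argmax
        return [max(range(len(row)), key=row.__getitem__) for row in outputs]
    # regression: pass raw values through
    return [row[0] for row in outputs]
```

Driving this from stored metadata, rather than inspecting the loss function at inference time, is what keeps training and inference consistent across environments.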
|
Implemented a minimal follow-up addressing the latest Greptile concerns:
@greptileai please review again with latest changes |
This PR adds GPU-aware neural-network training runtime support, streaming parquet data loading for Keras/PyTorch templates, canonical task type propagation, and related predictor/training metadata handling. The aim is to improve reliability and consistency of neural-network training and inference behavior across CPU and GPU environments while keeping task semantics explicit. It also includes targeted unit tests for training command construction and parquet streaming utilities, plus follow-up adjustments to use the current Python interpreter and set default neural-network epochs to 10.
Testing

```shell
poetry run ruff check plexe/execution/training/local_runner.py tests/unit/execution/training/test_local_runner.py plexe/templates/inference/keras_predictor.py plexe/templates/inference/pytorch_predictor.py plexe/config.py plexe/templates/training/train_keras.py plexe/templates/training/train_pytorch.py plexe/tools/submission.py
poetry run pytest tests/unit/execution/training/test_local_runner.py tests/unit/utils/test_parquet_dataset.py tests/unit/test_models.py tests/unit/test_config.py tests/unit/test_submission_pytorch.py tests/unit/test_imports.py -v
```