Rayless sft val loss #5
Pull request overview
This PR adds validation loss calculation capability to the SFT training script, enabling periodic evaluation on a separate validation dataset during training.
Key Changes:
- Introduces validation data source initialization and periodic validation loss calculation
- Adds --val-prompt-data, --val-interval, and --val-steps command-line arguments
- Implements calculate_val_loss() and _val_step() methods for validation execution
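As a rough illustration of the periodic validation described above, here is a minimal framework-free sketch. It is not the PR's implementation: the function names `mean_val_loss` and `run_training`, their signatures, and the callback-based design are assumptions made for this example, and a real training script would also disable gradient tracking and switch the model to evaluation mode.

```python
from typing import Callable, Iterable, Iterator, List, Tuple


def mean_val_loss(val_batches: Iterator, val_step: Callable[[object], float], val_steps: int) -> float:
    """Average the loss over at most `val_steps` validation batches (hypothetical helper)."""
    losses = []
    for _ in range(val_steps):
        batch = next(val_batches, None)
        if batch is None:  # validation data ran out before val_steps batches
            break
        losses.append(val_step(batch))
    return sum(losses) / len(losses) if losses else float("nan")


def run_training(
    train_batches: Iterable,
    make_val_batches: Callable[[], Iterable],
    train_step: Callable[[object], None],
    val_step: Callable[[object], float],
    val_interval: int,
    val_steps: int,
) -> List[Tuple[int, float]]:
    """Train on every batch and record (step, val_loss) every `val_interval` steps."""
    history = []
    for step, batch in enumerate(train_batches, start=1):
        train_step(batch)
        if val_interval and step % val_interval == 0:
            # Restart the validation stream so every evaluation sees the same batches.
            val_loss = mean_val_loss(iter(make_val_batches()), val_step, val_steps)
            history.append((step, val_loss))
    return history
```

Restarting the validation iterator on each evaluation keeps successive validation losses comparable; whether the PR does this is not stated in the overview.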
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 11 comments.
| File | Description |
|---|---|
| train_sft.py | Adds validation loss calculation methods, updates data source initialization to support separate validation data, and integrates periodic validation into the training loop |
| miles/rollout/data_source.py | Modifies RolloutDataSource constructor to accept optional prompt_data parameter, enabling reuse for both training and validation datasets |
| miles/utils/arguments.py | Adds three new command-line arguments for configuring validation: data path, interval, and number of steps |
| scripts/run-sft-torchrun.sh | Provides example usage of the new validation arguments with validation data path and configuration |
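The three validation arguments listed above could be registered roughly as follows. This is only a sketch: the helper name `add_validation_arguments`, the integer types and `None` defaults for the interval and step count, and all help strings other than the validation data path are assumptions, not taken from miles/utils/arguments.py.

```python
import argparse


def add_validation_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register the validation flags on an existing parser (hypothetical helper)."""
    group = parser.add_argument_group("validation")
    group.add_argument(
        "--val-prompt-data",
        type=str,
        default=None,
        help="The path to the validation prompt data.",
    )
    group.add_argument(
        "--val-interval",
        type=int,
        default=None,  # assumption: None means validation is disabled
        help="Run validation every this many training steps.",
    )
    group.add_argument(
        "--val-steps",
        type=int,
        default=None,  # assumption: number of validation batches per evaluation
        help="Number of validation steps per evaluation.",
    )
    return parser


if __name__ == "__main__":
    args = add_validation_arguments(argparse.ArgumentParser()).parse_args()
    print(args.val_prompt_data, args.val_interval, args.val_steps)
```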
scripts/run-sft-torchrun.sh
Outdated
SFT_ARGS=(
   --rollout-function-path miles.rollout.sft_rollout.generate_rollout
   --prompt-data /fast/project/HFMI_SynergyUnit/tab_model/huggingface/nemo_hf_part_jsonl_4k_tokens.parquet
   --val-prompt-data /fast/project/HFMI_SynergyUnit/tab_model/huggingface/nemo_hf_part_jsonl_4k_tokens_validation.parquet
Use jsonl for now.
Co-authored-by: Franz Srambical <79149449+emergenz@users.noreply.github.com>
miles/utils/arguments.py
Outdated
type=str,
default=None,
help=(
    "The path to the prompt data. "
Ah shit sorry, I didn't see this. In that case, let's add "The path to the validation prompt data." below.
scripts/run-sft-torchrun.sh
Outdated
   --prompt-data /fast/project/HFMI_SynergyUnit/tab_model/huggingface/nemo_hf_part_jsonl_4k_tokens.parquet
   --val-prompt-data /fast/project/HFMI_SynergyUnit/tab_model/huggingface/nemo_hf_part_jsonl_4k_tokens_validation.parquet
   --val-interval 100
   --val-steps 50
That sounds like a lot of steps for an interval of 100.
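To put a number on that concern, the sketch below computes what share of all executed steps would be validation steps. The function name is hypothetical, and the calculation counts steps only: a validation step has no backward pass, so its share of wall-clock time is likely smaller than its share of steps.

```python
def validation_step_share(val_interval: int, val_steps: int) -> float:
    """Fraction of all executed steps that are validation steps (hypothetical helper)."""
    if val_interval <= 0 or val_steps < 0:
        raise ValueError("val_interval must be positive and val_steps non-negative")
    # Each cycle runs val_interval training steps followed by val_steps validation steps.
    return val_steps / (val_interval + val_steps)


share = validation_step_share(val_interval=100, val_steps=50)
print(f"{share:.1%} of executed steps are validation steps")
```

With --val-interval 100 and --val-steps 50, one step in three is a validation step.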
No description provided.