Add GRPO loss example #1063
Conversation
jansel left a comment:
Thanks for the contribution! Can you add a test for this similar to the tests for the other examples?
Force-pushed from d3277a5 to 048ecf1.
Yeah @jansel, I added them just now, with changes in both the .expected file and the examples!
@oulgen I have removed the redundant hardware-specific configs. Is it OK to merge?
Force-pushed from ec54cdb to 97393e9.
Looks like many of the tests are failing. Can we change the test to run faster, perhaps with smaller inputs?
Overview
This PR adds a complete example implementation of Group Relative Policy Optimization (GRPO) loss using Helion kernels, including:
Benchmarks
=== Timing (median ms) ===
PyTorch Forward:  5.286 ms
PyTorch Backward: 14.828 ms
Helion Forward:   1.053 ms (5.02x vs PyTorch)
Helion Backward:  2.461 ms (6.03x vs PyTorch)

=== Throughput ===
PyTorch Fwd tokens/s:  3099305.7
PyTorch Bwd tokens/s:  1104966.6
Helion Fwd tokens/s:  15564766.3
Helion Bwd tokens/s:   6658797.2
Motivation
GRPO stabilizes RLHF-style policy optimization by clipping the policy-ratio update, with optional KL regularization against a reference model. This example demonstrates how to express a numerically stable, fused GRPO loss in Helion and how to integrate Helion kernels into an end-to-end PyTorch workflow.
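For readers unfamiliar with the loss itself, a minimal plain-PyTorch reference of the per-token GRPO objective might look like the sketch below. This is not the Helion kernel from this PR; the function name, the `eps`/`beta` defaults, and the choice of the k3 KL estimator are illustrative assumptions.

```python
import torch

def grpo_loss_reference(logp, old_logp, ref_logp, advantages, eps=0.2, beta=0.04):
    """Per-token GRPO loss sketch (hypothetical reference, not the PR's kernel).

    Combines a PPO-style clipped surrogate with a KL penalty against a
    reference policy, using the k3 KL estimator.
    """
    # Importance ratio between current and behavior policy.
    ratio = torch.exp(logp - old_logp)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * advantages
    # Pessimistic (clipped) surrogate, negated so we can minimize it.
    policy_loss = -torch.minimum(unclipped, clipped)
    # k3 KL estimator: exp(x) - x - 1 with x = ref_logp - logp; always >= 0.
    kl = torch.exp(ref_logp - logp) - (ref_logp - logp) - 1.0
    return policy_loss + beta * kl
```

When all three log-probability tensors coincide, the ratio is 1 and the KL term vanishes, so the loss reduces to the negated advantages, which is a handy sanity check when wiring up a fused kernel against a reference implementation.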