Skip to content

Conversation

@ighoshsubho
Copy link
Contributor

Overview

This PR adds a complete example implementation of Group Relative Policy Optimization (GRPO) loss using Helion kernels, including:

  • Helion kernels for forward and backward passes
  • PyTorch reference implementation for validation
  • Autograd integration via a custom torch.autograd.Function
  • Unit-style verification and a micro-benchmark harness

Benchmarks

=== Timing (median ms) ===
PyTorch Forward: 5.286 ms
PyTorch Backward: 14.828 ms
Helion Forward: 1.053 ms (x5.02 vs Torch)
Helion Backward: 2.461 ms (x6.03 vs Torch)

=== Throughput ===
PyTorch Fwd tokens/s: 3099305.7
PyTorch Bwd tokens/s: 1104966.6
Helion Fwd tokens/s: 15564766.3
Helion Bwd tokens/s: 6658797.2

Motivation

GRPO is used to optimize RLHF-style training by stabilizing policy updates through clipping with optional KL regularization against a reference model. This example demonstrates how to express a numerically stable, fused GRPO loss in Helion, and how to integrate Helion kernels into an end-to-end PyTorch workflow.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Oct 31, 2025
Copy link
Contributor

@jansel jansel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution! Can you add a test for this similar to the tests for the other examples?

@ighoshsubho ighoshsubho force-pushed the add-fused-grpo-loss-example branch from d3277a5 to 048ecf1 Compare November 2, 2025 07:29
@ighoshsubho
Copy link
Contributor Author

Thanks for the contribution! Can you add a test for this similar to the tests for the other examples?

Yeah @jansel I added them just now with changes in both .expected file and the examples!

@ighoshsubho
Copy link
Contributor Author

@oulgen I have removed the redundant configs for a specific hardware. Is it ok to merge?

@oulgen oulgen force-pushed the add-fused-grpo-loss-example branch from ec54cdb to 97393e9 Compare November 3, 2025 19:28
@jansel
Copy link
Contributor

jansel commented Nov 4, 2025

Looks like many of the tests are failing with:
FAILED test/test_examples.py::TestExamples::test_grpo_loss_bwd - Failed: Timeout (>60.0s) from pytest-timeout.

Can we change the test to run faster? Perhaps smaller inputs?

@ighoshsubho
Copy link
Contributor Author

ighoshsubho commented Nov 4, 2025

I have made the tensor shape small enough to fit under 60 sec runtime compilation, should not throw any error for backward now.

cc: @jansel @oulgen

@oulgen
Copy link
Contributor

oulgen commented Nov 4, 2025

I have made the tensor shape small enough to fit under 60 sec runtime compilation, should not throw any error for backward now.

cc: @jansel @oulgen

thanks, i'll merge if/when the tests pass

@oulgen oulgen merged commit 5ef76af into pytorch:main Nov 4, 2025
14 of 15 checks passed
@ighoshsubho ighoshsubho deleted the add-fused-grpo-loss-example branch November 4, 2025 10:32
tianrengao pushed a commit that referenced this pull request Nov 5, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants