diff --git a/examples/README.rst b/examples/README.rst index 2b04bcd83..4eb903b79 100644 --- a/examples/README.rst +++ b/examples/README.rst @@ -50,6 +50,7 @@ Other Operations - :doc:`embedding.py `: Embedding lookup operation - :doc:`all_gather_matmul.py `: All-gather operation followed by matrix multiplication - :doc:`all_reduce.py `: All-reduce operation (one-shot) +- :doc:`grpo_loss.py `: Group Relative Policy Optimization (GRPO) loss function .. toctree:: :maxdepth: 2 diff --git a/examples/grpo_loss.py b/examples/grpo_loss.py new file mode 100644 index 000000000..356705cf3 --- /dev/null +++ b/examples/grpo_loss.py @@ -0,0 +1,753 @@ +""" +Helion GRPO Loss Implementation +=============================== + +This example demonstrates a Helion kernel implementation of Group Relative Policy Optimization (GRPO) loss. +GRPO is a reinforcement learning algorithm used for training language models with human feedback. + +The implementation includes: +1. Forward pass computing GRPO loss with clipping and KL regularization +2. Backward pass for gradient computation +3. Support for completion masking and temperature scaling +4. Comparison with PyTorch reference implementation +""" + +# %% +# Imports +# ------- + +from __future__ import annotations + +import time +from typing import Callable +from typing import cast + +import torch + +import helion +from helion._testing import DEVICE +import helion.language as hl + +# %% +# Helper Functions +# ---------------- + + +def extract_selected_logits_pytorch( + logits: torch.Tensor, completion_ids: torch.Tensor, temperature: float +) -> torch.Tensor: + # Gather only the needed elements; avoid full-tensor cast and huge index grids + sel = logits.gather(dim=2, index=completion_ids.unsqueeze(-1)).squeeze(-1) + return sel.to(torch.float32) / temperature + + +def get_log_probs(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor: + """ + Compute log probabilities for given logits and input IDs. + + Args: + logits: Logits tensor of shape [B, L+1, V] + input_ids: Input token IDs of shape [B, L] + + Returns: + Log probabilities of shape [B, L] + """ + per_token_logps = [] + for logits_row, input_ids_row in zip( + logits, input_ids[:, -logits.size(1) :], strict=True + ): + log_probs = logits_row.log_softmax(dim=-1) + token_log_prob = torch.gather( + log_probs, dim=1, index=input_ids_row.unsqueeze(1) + ).squeeze(1) + per_token_logps.append(token_log_prob) + return torch.stack(per_token_logps) + + +def torch_grpo_loss( + logits: torch.Tensor, + old_logp: torch.Tensor | None, + ref_logp: torch.Tensor | None, + completion_ids: torch.Tensor, + advantages: torch.Tensor, + completion_mask: torch.Tensor | None, + temperature: float, + beta: float, + eps_low: float, + eps_high: float, +) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]: + """ + PyTorch reference implementation of GRPO loss. + + Args: + logits: Logits tensor of shape [B, L+1, V] + old_logp: Old log probabilities of shape [B, L] or None + ref_logp: Reference log probabilities of shape [B, L] or None + completion_ids: Completion token IDs of shape [B, L] + advantages: Advantages of shape [B] + completion_mask: Completion mask of shape [B, L] or None + temperature: Temperature scaling factor + beta: KL regularization weight + eps_low: Lower clipping bound + eps_high: Upper clipping bound + + Returns: + Tuple of (loss, kl_loss, is_clipped) + """ + assert logits.is_contiguous() and completion_ids.is_contiguous() + assert old_logp is None or old_logp.is_contiguous() + assert (ref_logp is not None and ref_logp.is_contiguous()) if beta != 0.0 else True + + logits = logits[:, :-1] # Remove last token + per_token_logps = get_log_probs(logits / temperature, completion_ids) + ref_per_token_logps = ref_logp + + if old_logp is None: + old_logp = per_token_logps.detach() + + coef_1 = torch.exp(per_token_logps - old_logp) + coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high) + per_token_loss1 = coef_1 * advantages.unsqueeze(1) + per_token_loss2 = coef_2 * advantages.unsqueeze(1) + per_token_loss = -torch.min(per_token_loss1, per_token_loss2) + + if completion_mask is not None: + per_token_loss = per_token_loss * completion_mask + + per_token_kl = None + if beta != 0.0 and ref_per_token_logps is not None: + per_token_kl = ( + torch.exp(ref_per_token_logps - per_token_logps) + - (ref_per_token_logps - per_token_logps) + - 1 + ) + if completion_mask is not None: + per_token_kl *= completion_mask + per_token_loss = per_token_loss + beta * per_token_kl + + is_clipped = (per_token_loss1 < per_token_loss2).float() + return per_token_loss, per_token_kl, is_clipped + + +# %% +# Helion GRPO Loss Kernels +# ------------------------ + + +@helion.kernel( + ignore_warnings=[helion.exc.TensorOperationInWrapper], autotune_effort="quick" +) +def grpo_loss_forward( + logits: torch.Tensor, # [B, L+1, V] input logits + selected_logits: torch.Tensor, # [B, L] pre-computed selected logits + old_logp: torch.Tensor | None, # [B, L] old log probabilities + ref_logp: torch.Tensor | None, # [B, L] reference log probabilities + advantages: torch.Tensor, # [B] advantages + completion_mask: torch.Tensor | None, # [B, L] completion mask + temperature: float, + beta: float, + eps_low: float, + eps_high: float, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Helion kernel for GRPO loss forward pass. + + Args: + logits: Logits tensor of shape [B, L+1, V] + selected_logits: Pre-computed selected logits of shape [B, L] + old_logp: Old log probabilities of shape [B, L] or None + ref_logp: Reference log probabilities of shape [B, L] or None + advantages: Advantages of shape [B] + completion_mask: Completion mask of shape [B, L] or None + temperature: Temperature scaling factor + beta: KL regularization weight + eps_low: Lower clipping bound + eps_high: Upper clipping bound + + Returns: + Tuple of (loss, kl_loss, is_clipped, lse) + """ + B, L_ADD_1, V = logits.shape + L = L_ADD_1 - 1 + + logits = logits[:, :-1, :] # [B, L, V] + + loss = torch.zeros([B, L], dtype=torch.float32, device=logits.device) + is_clipped = torch.zeros([B, L], dtype=torch.float32, device=logits.device) + kl_loss = torch.zeros([B, L], dtype=torch.float32, device=logits.device) + lse = torch.zeros([B, L], dtype=torch.float32, device=logits.device) + + for tile_b, tile_l in hl.tile([B, L]): + max_logits = hl.full([tile_b, tile_l], float("-inf"), dtype=torch.float32) + sum_exp = hl.zeros([tile_b, tile_l], dtype=torch.float32) + + for tile_v in hl.tile(V): + logits_tile = logits[tile_b, tile_l, tile_v].to(torch.float32) / temperature + new_m_i = torch.maximum(max_logits, torch.amax(logits_tile, dim=-1)) + alpha = torch.exp(max_logits - new_m_i) + sum_exp = sum_exp * alpha + torch.sum( + torch.exp(logits_tile - new_m_i[:, :, None]), dim=-1 + ) + max_logits = new_m_i + + log_sum_exp = max_logits + torch.log(sum_exp) # [tile_b, tile_l] + lse[tile_b, tile_l] = log_sum_exp + + logp = selected_logits[tile_b, tile_l] - log_sum_exp + + if old_logp is None: + old_logp_val = logp + else: + old_logp_val = old_logp[tile_b, tile_l] + + coef_1 = torch.exp(logp - old_logp_val) + coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high) + + advantage = advantages[tile_b] + + per_token_loss1 = coef_1 * advantage[:, None] + per_token_loss2 = coef_2 * advantage[:, None] + per_token_loss = -torch.minimum(per_token_loss1, per_token_loss2) + + if completion_mask is not None: + per_token_loss *= completion_mask[tile_b, tile_l] + + if beta != 0.0 and ref_logp is not None: + ref_logp_val = ref_logp[tile_b, tile_l] + kl = torch.exp(ref_logp_val - logp) - (ref_logp_val - logp) - 1 + if completion_mask is not None: + kl *= completion_mask[tile_b, tile_l] + per_token_loss += beta * kl + kl_loss[tile_b, tile_l] = kl + + loss[tile_b, tile_l] = per_token_loss + is_clipped[tile_b, tile_l] = (per_token_loss1 < per_token_loss2).float() + + return loss, kl_loss, is_clipped, lse + + +@helion.kernel( + ignore_warnings=[helion.exc.TensorOperationInWrapper], autotune_effort="quick" +) +def grpo_loss_backward( + grad_output: torch.Tensor, # [B, L] gradient from downstream + logits: torch.Tensor, # [B, L+1, V] original logits + selected_logits: torch.Tensor, # [B, L] pre-computed selected logits + completion_ids: torch.Tensor, # [B, L] completion token IDs (needed for gradients) + old_logp: torch.Tensor | None, # [B, L] old log probabilities + ref_logp: torch.Tensor | None, # [B, L] reference log probabilities + advantages: torch.Tensor, # [B] advantages + completion_mask: torch.Tensor | None, # [B, L] completion mask + lse: torch.Tensor, # [B, L] stored log-sum-exp values + temperature: float, + beta: float, + eps_low: float, + eps_high: float, +) -> torch.Tensor: + """ + Helion kernel for GRPO loss backward pass. + + Args: + grad_output: Gradient from downstream layers [B, L] + logits: Original logits tensor [B, L+1, V] + selected_logits: Pre-computed selected logits [B, L] + completion_ids: Completion token IDs [B, L] (needed for gradients) + old_logp: Old log probabilities [B, L] or None + ref_logp: Reference log probabilities [B, L] or None + advantages: Advantages [B] + completion_mask: Completion mask [B, L] or None + lse: Stored log-sum-exp values [B, L] + temperature: Temperature scaling factor + beta: KL regularization weight + eps_low: Lower clipping bound + eps_high: Upper clipping bound + + Returns: + Gradient with respect to logits [B, L+1, V] + """ + B, L_ADD_1, V = logits.shape + L = L_ADD_1 - 1 + + logits_fwd = logits[:, :-1, :] # [B, L, V] + + grad_logits = torch.zeros_like(logits) + + for tile_b, tile_l in hl.tile([B, L]): + completion_id = completion_ids[tile_b, tile_l] + + log_sum_exp = lse[tile_b, tile_l] + + logp = selected_logits[tile_b, tile_l] - log_sum_exp + + if old_logp is None: + old_logp_val = logp + else: + old_logp_val = old_logp[tile_b, tile_l] + + coef_1 = torch.exp(logp - old_logp_val) + coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high) + + advantage = advantages[tile_b] + + per_token_loss1 = coef_1 * advantage[:, None] + per_token_loss2 = coef_2 * advantage[:, None] + + mask = (per_token_loss2 >= per_token_loss1).float() + + dlogp = -per_token_loss1 * mask + + if beta != 0.0 and ref_logp is not None: + ref_logp_val = ref_logp[tile_b, tile_l] + dlogp += beta * (1 - torch.exp(ref_logp_val - logp)) + + dlogp = dlogp * grad_output[tile_b, tile_l] / temperature + + if completion_mask is not None: + mask_val = completion_mask[tile_b, tile_l] + dlogp *= mask_val + + for tile_v in hl.tile(V): + logits_tile = ( + logits_fwd[tile_b, tile_l, tile_v].to(torch.float32) / temperature + ) + probs = torch.exp(logits_tile - log_sum_exp[:, :, None]) + + v_indices = tile_v.index + sel = v_indices[None, None, :] == completion_id[:, :, None] + + grad_logits_tile = torch.where( + sel, + dlogp[:, :, None] * (1 - probs), + -dlogp[:, :, None] * probs, + ) + grad_logits[tile_b, tile_l, tile_v] = grad_logits_tile + + grad_logits[:, -1, :] = 0 + + return grad_logits + + +# %% +# GRPO Loss Function Class +# ------------------------ + + +class GrpoLossFunction(torch.autograd.Function): + """Custom autograd function for GRPO loss with forward and backward passes.""" + + @staticmethod + def forward( + ctx: object, + logits: torch.Tensor, + old_logp: torch.Tensor | None, + ref_logp: torch.Tensor | None, + completion_ids: torch.Tensor, + advantages: torch.Tensor, + completion_mask: torch.Tensor | None, + temperature: float, + beta: float, + eps_low: float, + eps_high: float, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward pass of GRPO loss.""" + selected_logits = extract_selected_logits_pytorch( + logits[:, :-1, :], completion_ids, temperature + ) + + loss, kl_loss, is_clipped, lse = grpo_loss_forward( + logits, + selected_logits, + old_logp, + ref_logp, + advantages, + completion_mask, + temperature, + beta, + eps_low, + eps_high, + ) + + ctx.save_for_backward( # type: ignore[attr-defined] + logits, + selected_logits, + completion_ids, + old_logp, + ref_logp, + advantages, + completion_mask, + lse, + ) + ctx.temperature = temperature # type: ignore[attr-defined] + ctx.beta = beta # type: ignore[attr-defined] + ctx.eps_low = eps_low # type: ignore[attr-defined] + ctx.eps_high = eps_high # type: ignore[attr-defined] + + return loss, kl_loss, is_clipped + + @staticmethod + def backward( + ctx: object, + *grad_outputs: torch.Tensor, + ) -> tuple[torch.Tensor | None, ...]: + """Backward pass of GRPO loss.""" + # Unpack incoming gradients (we only need the first one for 'loss') + grad_loss = grad_outputs[0] + + ( + logits, + selected_logits, + completion_ids, + old_logp, + ref_logp, + advantages, + completion_mask, + lse, + ) = ctx.saved_tensors # type: ignore[attr-defined] + + grad_logits = grpo_loss_backward( + grad_loss, + logits, + selected_logits, + completion_ids, + old_logp, + ref_logp, + advantages, + completion_mask, + lse, + ctx.temperature, # type: ignore[attr-defined] + ctx.beta, # type: ignore[attr-defined] + ctx.eps_low, # type: ignore[attr-defined] + ctx.eps_high, # type: ignore[attr-defined] + ) + + return ( + grad_logits, # d(logits) + None, # d(old_logp) + None, # d(ref_logp) + None, # d(completion_ids) + None, # d(advantages) + None, # d(completion_mask) + None, # d(temperature) + None, # d(beta) + None, # d(eps_low) + None, # d(eps_high) + ) + + +def helion_grpo_loss( + logits: torch.Tensor, + old_logp: torch.Tensor | None, + ref_logp: torch.Tensor | None, + completion_ids: torch.Tensor, + advantages: torch.Tensor, + completion_mask: torch.Tensor | None = None, + temperature: float = 0.9, + beta: float = 0.04, + eps_low: float = 0.2, + eps_high: float = 0.4, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Helion implementation of GRPO loss. + + Args: + logits: Logits tensor of shape [B, L+1, V] + old_logp: Old log probabilities of shape [B, L] or None + ref_logp: Reference log probabilities of shape [B, L] or None + completion_ids: Completion token IDs of shape [B, L] + advantages: Advantages of shape [B] + completion_mask: Completion mask of shape [B, L] or None + temperature: Temperature scaling factor + beta: KL regularization weight + eps_low: Lower clipping bound + eps_high: Upper clipping bound + + Returns: + Tuple of (loss, kl_loss, is_clipped) + """ + result = cast( + "tuple[torch.Tensor, torch.Tensor, torch.Tensor]", + GrpoLossFunction.apply( + logits, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + temperature, + beta, + eps_low, + eps_high, + ), + ) + loss, kl_loss, is_clipped = result + return loss, kl_loss, is_clipped + + +# %% +# Verification and Testing +# ------------------------ + + +def compare_tensors( + tensor1: torch.Tensor | None, tensor2: torch.Tensor | None, name: str = "" +) -> None: + """Compare two tensors and print statistics.""" + if tensor1 is None or tensor2 is None: + return + if any([tensor1.dtype == torch.float32, tensor2.dtype == torch.float32]): + tensor1, tensor2 = tensor1.float(), tensor2.float() + diff = (tensor1 - tensor2).abs() + diff = diff / (torch.max(tensor1.abs(), tensor2.abs()) + 1e-5) + print(f"Max difference: {diff.max().item()}, Mean difference: {diff.mean().item()}") + + +def test_grpo_loss( + B: int = 8, + L: int = 1024, + V: int = 12800, + temperature: float = 0.9, + beta: float = 0.2, + eps_low: float = 0.2, + eps_high: float = 0.4, +) -> None: + """Test GRPO loss implementation against PyTorch reference.""" + print(f"Testing GRPO Loss: B={B}, L={L}, V={V}") + + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + logits1 = torch.randn( + B, L + 1, V, device=DEVICE, dtype=torch.bfloat16, requires_grad=True + ) + logits2 = logits1.clone().detach().requires_grad_(True) + logits_ref = logits1.detach().clone().float().requires_grad_(True) + + completion_ids = torch.randint(0, V - 1, (B, L), dtype=torch.int64, device=DEVICE) + completion_mask = torch.ones_like(completion_ids, dtype=torch.float32) + ref_logp = torch.randn(B, L, device=DEVICE, dtype=torch.float32) + old_logp = torch.randn(B, L, device=DEVICE, dtype=torch.float32) + advantages = torch.randn(B, device=DEVICE, dtype=torch.float32) + + print("\n=== Forward Pass Test ===") + + loss_ref, kl_ref, clipped_ref = torch_grpo_loss( + logits_ref, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + temperature, + beta, + eps_low, + eps_high, + ) + + loss_helion, kl_helion, clipped_helion = helion_grpo_loss( + logits2, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + temperature, + beta, + eps_low, + eps_high, + ) + + compare_tensors(loss_helion, loss_ref, "Loss") + compare_tensors(kl_helion, kl_ref, "KL Loss") + compare_tensors(clipped_helion, clipped_ref, "Is Clipped") + + print("\n=== Backward Pass Test ===") + + grad_output = torch.randn_like(loss_ref) + + loss_ref.backward(grad_output, retain_graph=True) + grad_ref = logits_ref.grad.clone() if logits_ref.grad is not None else None + + logits_ref.grad = None + + loss_helion.backward(grad_output, retain_graph=True) + grad_helion = logits2.grad.clone() if logits2.grad is not None else None + + compare_tensors(grad_helion, grad_ref, "Gradient") + + print("\n=== Test Complete ===") + + +def _cuda_sync() -> None: + if torch.cuda.is_available(): + torch.cuda.synchronize() + + +def _measure_timing(run_fn: Callable[[], None], iters: int, warmup: int) -> float: + times = [] + for _ in range(warmup): + run_fn() + _cuda_sync() + for _ in range(iters): + t0 = time.perf_counter() + run_fn() + _cuda_sync() + t1 = time.perf_counter() + times.append((t1 - t0) * 1000.0) + times.sort() + mid = len(times) // 2 + return times[mid] if len(times) % 2 == 1 else 0.5 * (times[mid - 1] + times[mid]) + + +def benchmark_grpo_loss( + B: int = 8, + L: int = 1024, + V: int = 12800, + temperature: float = 0.9, + beta: float = 0.2, + eps_low: float = 0.2, + eps_high: float = 0.4, + iters: int = 50, + warmup: int = 10, +) -> None: + print( + f"Benchmarking GRPO Loss: B={B}, L={L}, V={V} (iters={iters}, warmup={warmup})" + ) + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed(0) + + logits_ref = torch.randn( + B, L + 1, V, device=DEVICE, dtype=torch.float32, requires_grad=True + ) + logits_hel = logits_ref.detach().clone().to(torch.bfloat16).requires_grad_(True) + + completion_ids = torch.randint(0, V - 1, (B, L), dtype=torch.int64, device=DEVICE) + completion_mask = torch.ones_like(completion_ids, dtype=torch.int32) + ref_logp = torch.randn(B, L, device=DEVICE, dtype=torch.float32) + old_logp = torch.randn(B, L, device=DEVICE, dtype=torch.float32) + advantages = torch.randn(B, device=DEVICE, dtype=torch.float32) + + grad_out = torch.randn(B, L, device=DEVICE, dtype=torch.float32) + + def run_torch_fwd() -> None: + torch_grpo_loss( + logits_ref, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + temperature, + beta, + eps_low, + eps_high, + ) + + def run_torch_bwd() -> None: + logits_ref.grad = None + loss_ref, _, _ = torch_grpo_loss( + logits_ref, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + temperature, + beta, + eps_low, + eps_high, + ) + loss_ref.backward(grad_out, retain_graph=False) + + def run_helion_fwd() -> None: + helion_grpo_loss( + logits_hel, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + temperature, + beta, + eps_low, + eps_high, + ) + + def run_helion_bwd() -> None: + logits_hel.grad = None + loss_hel, _, _ = helion_grpo_loss( + logits_hel, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + temperature, + beta, + eps_low, + eps_high, + ) + loss_hel.backward(grad_out, retain_graph=False) + + torch_fwd_ms = _measure_timing(run_torch_fwd, iters, warmup) + torch_bwd_ms = _measure_timing(run_torch_bwd, iters, warmup) + hel_fwd_ms = _measure_timing(run_helion_fwd, iters, warmup) + hel_bwd_ms = _measure_timing(run_helion_bwd, iters, warmup) + + def speedup(a: float, b: float) -> float: + return a / b if b > 0 else float("inf") + + print("\n=== Timing (median ms) ===") + print(f"PyTorch Forward: {torch_fwd_ms:.3f} ms") + print(f"PyTorch Backward: {torch_bwd_ms:.3f} ms") + print( + f"Helion Forward: {hel_fwd_ms:.3f} ms (x{speedup(torch_fwd_ms, hel_fwd_ms):.2f} vs Torch)" + ) + print( + f"Helion Backward: {hel_bwd_ms:.3f} ms (x{speedup(torch_bwd_ms, hel_bwd_ms):.2f} vs Torch)" + ) + + tokens = B * L + print("\n=== Throughput ===") + print(f"PyTorch Fwd tokens/s: {tokens / (torch_fwd_ms / 1000.0):.1f}") + print(f"PyTorch Bwd tokens/s: {tokens / (torch_bwd_ms / 1000.0):.1f}") + print(f"Helion Fwd tokens/s: {tokens / (hel_fwd_ms / 1000.0):.1f}") + print(f"Helion Bwd tokens/s: {tokens / (hel_bwd_ms / 1000.0):.1f}") + + +# %% +# Main Function +# ------------- + + +def main() -> None: + """Main entry point for GRPO loss testing.""" + print("Helion GRPO Loss Implementation") + print("=" * 50) + + test_configs = [ + {"B": 8, "L": 2048, "V": 64000}, + # {"B": 4, "L": 2048, "V": 128000}, + # {"B": 8, "L": 4096, "V": 100000}, + ] + + for config in test_configs: + test_grpo_loss(**config) + print() + + benchmark_grpo_loss( + B=8, + L=2048, + V=64000, + temperature=0.9, + beta=0.2, + eps_low=0.2, + eps_high=0.4, + iters=50, + warmup=10, + ) + + +if __name__ == "__main__": + main() diff --git a/test/test_examples.expected b/test/test_examples.expected index e665d87b0..342b715d4 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -2102,6 +2102,346 @@ def grouped_gemm_jagged_persistent(A_packed: torch.Tensor, B: torch.Tensor, grou # src[grouped_gemm.py:N]: return out return out +--- assertExpectedJournal(TestExamples.test_grpo_loss_bwd) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from torch._inductor.runtime import triton_helpers +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_grpo_loss_backward(completion_ids, lse, selected_logits, old_logp, advantages, ref_logp, grad_output, completion_mask, logits_fwd, grad_logits, eps_low, eps_high, beta, temperature, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + # src[grpo_loss.py:N]: for tile_b, tile_l in hl.tile([B, L]): + num_blocks_0 = tl.cdiv(2, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + # src[grpo_loss.py:N]: completion_id = completion_ids[tile_b, tile_l] + completion_id = tl.load(completion_ids + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None) + # src[grpo_loss.py:N]: log_sum_exp = lse[tile_b, tile_l] + log_sum_exp = tl.load(lse + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None) + # src[grpo_loss.py:N]: logp = selected_logits[tile_b, tile_l] - log_sum_exp + load_2 = tl.load(selected_logits + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None) + v_0 = load_2 - log_sum_exp + # src[grpo_loss.py:N]: old_logp_val = old_logp[tile_b, tile_l] + old_logp_val = tl.load(old_logp + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None) + # src[grpo_loss.py:N]: coef_1 = torch.exp(logp - old_logp_val) + v_1 = v_0 - old_logp_val + v_2 = libdevice.exp(v_1) + # src[grpo_loss.py:N]: coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high) + sub_2 = 1.0 + -1 * eps_low + add = 1.0 + eps_high + v_3 = triton_helpers.maximum(v_2, sub_2) + v_4 = triton_helpers.minimum(v_3, add) + # src[grpo_loss.py:N]: advantage = advantages[tile_b] + advantage = tl.load(advantages + indices_0 * 1, None) + # src[grpo_loss.py:N]: per_token_loss1 = coef_1 * advantage[:, None] + subscript = advantage[:, None] + v_5 = v_2 * subscript + # src[grpo_loss.py:N]: per_token_loss2 = coef_2 * advantage[:, None] + subscript_1 = advantage[:, None] + v_6 = v_4 * subscript_1 + # src[grpo_loss.py:N]: mask = (per_token_loss2 >= per_token_loss1).float() + v_7 = v_6 >= v_5 + v_8 = tl.cast(v_7, tl.float32) + # src[grpo_loss.py:N]: dlogp = -per_token_loss1 * mask + v_9 = -v_5 + v_10 = v_9 * v_8 + # src[grpo_loss.py:N]: if beta != 0.0 and ref_logp is not None: + ne = beta != 0.0 + _and = ne and True + # src[grpo_loss.py:N]: if beta != 0.0 and ref_logp is not None: + # src[grpo_loss.py:N]: ref_logp_val = ref_logp[tile_b, tile_l] + # src[grpo_loss.py:N]: dlogp += beta * (1 - torch.exp(ref_logp_val - logp)) + if _and: + v_0_copy = v_0 + v_10_copy = v_10 + v_0_copy_0 = v_0_copy + v_10_copy_0 = v_10_copy + # src[grpo_loss.py:N]: ref_logp_val = ref_logp[tile_b, tile_l] + ref_logp_val = tl.load(ref_logp + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None) + # src[grpo_loss.py:N]: dlogp += beta * (1 - torch.exp(ref_logp_val - logp)) + v_11 = ref_logp_val - v_0_copy_0 + v_12 = libdevice.exp(v_11) + v_13 = 1.0 + v_14 = v_13 - v_12 + v_15 = v_14 * beta + v_10 = v_10_copy_0 + v_15 + # src[grpo_loss.py:N]: dlogp = dlogp * grad_output[tile_b, tile_l] / temperature + load_5 = tl.load(grad_output + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None) + v_17 = v_10 * load_5 + v_18 = v_17 / temperature + # src[grpo_loss.py:N]: mask_val = completion_mask[tile_b, tile_l] + mask_val = tl.load(completion_mask + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None) + # src[grpo_loss.py:N]: dlogp *= mask_val + v_19 = v_18 * mask_val + # src[grpo_loss.py:N]: for tile_v in hl.tile(V): + # src[grpo_loss.py:N]: logits_tile = ( + # src[grpo_loss.py:N]: logits_fwd[tile_b, tile_l, tile_v].to(torch.float32) / temperature + # src[grpo_loss.py:N-N]: ... + for offset_2 in tl.range(0, 128, _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + log_sum_exp_copy = log_sum_exp + completion_id_copy = completion_id + v_19_copy = v_19 + log_sum_exp_copy_0 = log_sum_exp_copy + completion_id_copy_0 = completion_id_copy + v_19_copy_0 = v_19_copy + # src[grpo_loss.py:N]: logits_fwd[tile_b, tile_l, tile_v].to(torch.float32) / temperature + load = tl.load(logits_fwd + (indices_0[:, None, None] * 8320 + indices_1[None, :, None] * 128 + indices_2[None, None, :] * 1), None) + v_20 = tl.cast(load, tl.float32) + v_21 = v_20 / temperature + # src[grpo_loss.py:N]: probs = torch.exp(logits_tile - log_sum_exp[:, :, None]) + subscript_2 = log_sum_exp_copy_0[:, :, None] + v_22 = v_21 - subscript_2 + v_23 = libdevice.exp(v_22) + # src[grpo_loss.py:N]: sel = v_indices[None, None, :] == completion_id[:, :, None] + subscript_3 = indices_2[None, None, :] + subscript_4 = completion_id_copy_0[:, :, None] + v_24 = tl.cast(subscript_3, tl.int64) + v_25 = v_24 == subscript_4 + # src[grpo_loss.py:N]: dlogp[:, :, None] * (1 - probs), + subscript_5 = v_19_copy_0[:, :, None] + v_26 = 1.0 + v_27 = v_26 - v_23 + v_28 = subscript_5 * v_27 + # src[grpo_loss.py:N]: -dlogp[:, :, None] * probs, + subscript_6 = v_19_copy_0[:, :, None] + v_29 = -subscript_6 + v_30 = v_29 * v_23 + # src[grpo_loss.py:N]: grad_logits_tile = torch.where( + # src[grpo_loss.py:N]: sel, + # src[grpo_loss.py:N]: dlogp[:, :, None] * (1 - probs), + # src[grpo_loss.py:N-N]: ... + v_31 = tl.where(v_25, v_28, v_30) + # src[grpo_loss.py:N]: grad_logits[tile_b, tile_l, tile_v] = grad_logits_tile + v_32 = tl.cast(v_31, tl.bfloat16) + tl.store(grad_logits + (indices_0[:, None, None] * 8320 + indices_1[None, :, None] * 128 + indices_2[None, None, :] * 1), v_32, None) + +def grpo_loss_backward(grad_output: torch.Tensor, logits: torch.Tensor, selected_logits: torch.Tensor, completion_ids: torch.Tensor, old_logp: torch.Tensor | None, ref_logp: torch.Tensor | None, advantages: torch.Tensor, completion_mask: torch.Tensor | None, lse: torch.Tensor, temperature: float, beta: float, eps_low: float, eps_high: float, *, _launcher=_default_launcher): + """ + Helion kernel for GRPO loss backward pass. + + Args: + grad_output: Gradient from downstream layers [B, L] + logits: Original logits tensor [B, L+1, V] + selected_logits: Pre-computed selected logits [B, L] + completion_ids: Completion token IDs [B, L] (needed for gradients) + old_logp: Old log probabilities [B, L] or None + ref_logp: Reference log probabilities [B, L] or None + advantages: Advantages [B] + completion_mask: Completion mask [B, L] or None + lse: Stored log-sum-exp values [B, L] + temperature: Temperature scaling factor + beta: KL regularization weight + eps_low: Lower clipping bound + eps_high: Upper clipping bound + + Returns: + Gradient with respect to logits [B, L+1, V] + """ + # src[grpo_loss.py:N]: B, L_ADD_1, V = logits.shape + B, L_ADD_1, V = logits.shape + # src[grpo_loss.py:N]: logits_fwd = logits[:, :-1, :] # [B, L, V] + logits_fwd = logits[:, :-1, :] + # src[grpo_loss.py:N]: grad_logits = torch.zeros_like(logits) + grad_logits = torch.zeros_like(logits) + # src[grpo_loss.py:N]: for tile_b, tile_l in hl.tile([B, L]): + _BLOCK_SIZE_0 = 2 + _BLOCK_SIZE_1 = 16 + # src[grpo_loss.py:N]: for tile_v in hl.tile(V): + # src[grpo_loss.py:N]: logits_tile = ( + # src[grpo_loss.py:N]: logits_fwd[tile_b, tile_l, tile_v].to(torch.float32) / temperature + # src[grpo_loss.py:N-N]: ... + _BLOCK_SIZE_2 = 16 + # src[grpo_loss.py:N]: for tile_b, tile_l in hl.tile([B, L]): + # src[grpo_loss.py:N]: completion_id = completion_ids[tile_b, tile_l] + # src[grpo_loss.py:N-N]: ... + _launcher(_helion_grpo_loss_backward, (triton.cdiv(2, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), completion_ids, lse, selected_logits, old_logp, advantages, ref_logp, grad_output, completion_mask, logits_fwd, grad_logits, eps_low, eps_high, beta, temperature, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) + # src[grpo_loss.py:N]: grad_logits[:, -1, :] = 0 + grad_logits[:, -1, :] = 0 + # src[grpo_loss.py:N]: return grad_logits + return grad_logits + +--- assertExpectedJournal(TestExamples.test_grpo_loss_fwd) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from torch._inductor.runtime import triton_helpers +from torch._inductor.runtime.triton_helpers import math as tl_math +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_grpo_loss_forward(logits, lse, selected_logits, old_logp, advantages, completion_mask, ref_logp, kl_loss, loss, is_clipped, temperature, eps_low, eps_high, beta, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + # src[grpo_loss.py:N]: for tile_b, tile_l in hl.tile([B, L]): + num_blocks_0 = tl.cdiv(4, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + # src[grpo_loss.py:N]: max_logits = hl.full([tile_b, tile_l], float("-inf"), dtype=torch.float32) + max_logits = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], float('-inf'), tl.float32) + # src[grpo_loss.py:N]: sum_exp = hl.zeros([tile_b, tile_l], dtype=torch.float32) + sum_exp = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + # src[grpo_loss.py:N]: for tile_v in hl.tile(V): + # src[grpo_loss.py:N]: logits_tile = logits[tile_b, tile_l, tile_v].to(torch.float32) / temperature + # src[grpo_loss.py:N]: new_m_i = torch.maximum(max_logits, torch.amax(logits_tile, dim=-1)) + # src[grpo_loss.py:N-N]: ... + for offset_2 in tl.range(0, 2048, _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + max_logits_copy = max_logits + sum_exp_copy = sum_exp + max_logits_copy_0 = max_logits_copy + sum_exp_copy_0 = sum_exp_copy + # src[grpo_loss.py:N]: logits_tile = logits[tile_b, tile_l, tile_v].to(torch.float32) / temperature + load = tl.load(logits + (indices_0[:, None, None] * 1050624 + indices_1[None, :, None] * 2048 + indices_2[None, None, :] * 1), None) + v_0 = tl.cast(load, tl.float32) + v_1 = v_0 / temperature + # src[grpo_loss.py:N]: new_m_i = torch.maximum(max_logits, torch.amax(logits_tile, dim=-1)) + amax = tl.cast(tl.max(v_1, 2), tl.float32) + v_2 = triton_helpers.maximum(max_logits_copy_0, amax) + # src[grpo_loss.py:N]: alpha = torch.exp(max_logits - new_m_i) + v_3 = max_logits_copy_0 - v_2 + v_4 = libdevice.exp(v_3) + # src[grpo_loss.py:N]: sum_exp = sum_exp * alpha + torch.sum( + v_5 = sum_exp_copy_0 * v_4 + # src[grpo_loss.py:N]: torch.exp(logits_tile - new_m_i[:, :, None]), dim=-1 + subscript = v_2[:, :, None] + v_6 = v_1 - subscript + v_7 = libdevice.exp(v_6) + # src[grpo_loss.py:N]: sum_exp = sum_exp * alpha + torch.sum( + # src[grpo_loss.py:N]: torch.exp(logits_tile - new_m_i[:, :, None]), dim=-1 + # src[grpo_loss.py:N]: ) + sum_1 = tl.cast(tl.sum(v_7, 2), tl.float32) + sum_exp = v_5 + sum_1 + # src[grpo_loss.py:N]: max_logits = new_m_i + max_logits = v_2 + # src[grpo_loss.py:N]: log_sum_exp = max_logits + torch.log(sum_exp) # [tile_b, tile_l] + v_9 = tl_math.log(sum_exp) + v_10 = max_logits + v_9 + # src[grpo_loss.py:N]: lse[tile_b, tile_l] = log_sum_exp + tl.store(lse + (indices_0[:, None] * 512 + indices_1[None, :] * 1), v_10, None) + # src[grpo_loss.py:N]: logp = selected_logits[tile_b, tile_l] - log_sum_exp + load_1 = tl.load(selected_logits + (indices_0[:, None] * 512 + indices_1[None, :] * 1), None) + v_11 = load_1 - v_10 + # src[grpo_loss.py:N]: old_logp_val = old_logp[tile_b, tile_l] + old_logp_val = tl.load(old_logp + (indices_0[:, None] * 512 + indices_1[None, :] * 1), None) + # src[grpo_loss.py:N]: coef_1 = torch.exp(logp - old_logp_val) + v_12 = v_11 - old_logp_val + v_13 = libdevice.exp(v_12) + # src[grpo_loss.py:N]: coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high) + sub_2 = 1.0 + -1 * eps_low + add_1 = 1.0 + eps_high + v_14 = triton_helpers.maximum(v_13, sub_2) + v_15 = triton_helpers.minimum(v_14, add_1) + # src[grpo_loss.py:N]: advantage = advantages[tile_b] + advantage = tl.load(advantages + indices_0 * 1, None) + # src[grpo_loss.py:N]: per_token_loss1 = coef_1 * advantage[:, None] + subscript_1 = advantage[:, None] + v_16 = v_13 * subscript_1 + # src[grpo_loss.py:N]: per_token_loss2 = coef_2 * advantage[:, None] + subscript_2 = advantage[:, None] + v_17 = v_15 * subscript_2 + # src[grpo_loss.py:N]: per_token_loss = -torch.minimum(per_token_loss1, per_token_loss2) + v_18 = triton_helpers.minimum(v_16, v_17) + v_19 = -v_18 + # src[grpo_loss.py:N]: per_token_loss *= completion_mask[tile_b, tile_l] + load_3 = tl.load(completion_mask + (indices_0[:, None] * 512 + indices_1[None, :] * 1), None) + v_20 = v_19 * load_3 + # src[grpo_loss.py:N]: if beta != 0.0 and ref_logp is not None: + ne = beta != 0.0 + _and = ne and True + # src[grpo_loss.py:N]: if beta != 0.0 and ref_logp is not None: + # src[grpo_loss.py:N]: ref_logp_val = ref_logp[tile_b, tile_l] + # src[grpo_loss.py:N]: kl = torch.exp(ref_logp_val - logp) - (ref_logp_val - logp) - 1 + # src[grpo_loss.py:N-N]: ... + if _and: + v_11_copy = v_11 + v_20_copy = v_20 + v_11_copy_0 = v_11_copy + v_20_copy_0 = v_20_copy + # src[grpo_loss.py:N]: ref_logp_val = ref_logp[tile_b, tile_l] + ref_logp_val = tl.load(ref_logp + (indices_0[:, None] * 512 + indices_1[None, :] * 1), None) + # src[grpo_loss.py:N]: kl = torch.exp(ref_logp_val - logp) - (ref_logp_val - logp) - 1 + v_21 = ref_logp_val - v_11_copy_0 + v_22 = libdevice.exp(v_21) + v_23 = ref_logp_val - v_11_copy_0 + v_24 = v_22 - v_23 + v_25 = 1.0 + v_26 = v_24 - v_25 + # src[grpo_loss.py:N]: kl *= completion_mask[tile_b, tile_l] + load_2 = tl.load(completion_mask + (indices_0[:, None] * 512 + indices_1[None, :] * 1), None) + v_27 = v_26 * load_2 + # src[grpo_loss.py:N]: per_token_loss += beta * kl + v_28 = v_27 * beta + v_20 = v_20_copy_0 + v_28 + # src[grpo_loss.py:N]: kl_loss[tile_b, tile_l] = kl + tl.store(kl_loss + (indices_0[:, None] * 512 + indices_1[None, :] * 1), v_27, None) + # src[grpo_loss.py:N]: loss[tile_b, tile_l] = per_token_loss + tl.store(loss + (indices_0[:, None] * 512 + indices_1[None, :] * 1), v_20, None) + # src[grpo_loss.py:N]: is_clipped[tile_b, tile_l] = (per_token_loss1 < per_token_loss2).float() + v_30 = v_16 < v_17 + v_31 = tl.cast(v_30, tl.float32) + tl.store(is_clipped + (indices_0[:, None] * 512 + indices_1[None, :] * 1), v_31, None) + +def grpo_loss_forward(logits: torch.Tensor, selected_logits: torch.Tensor, old_logp: torch.Tensor | None, ref_logp: torch.Tensor | None, advantages: torch.Tensor, completion_mask: torch.Tensor | None, temperature: float, beta: float, eps_low: float, eps_high: float, *, _launcher=_default_launcher): + """ + Helion kernel for GRPO loss forward pass. + + Args: + logits: Logits tensor of shape [B, L+1, V] + selected_logits: Pre-computed selected logits of shape [B, L] + old_logp: Old log probabilities of shape [B, L] or None + ref_logp: Reference log probabilities of shape [B, L] or None + advantages: Advantages of shape [B] + completion_mask: Completion mask of shape [B, L] or None + temperature: Temperature scaling factor + beta: KL regularization weight + eps_low: Lower clipping bound + eps_high: Upper clipping bound + + Returns: + Tuple of (loss, kl_loss, is_clipped, lse) + """ + # src[grpo_loss.py:N]: B, L_ADD_1, V = logits.shape + B, L_ADD_1, V = logits.shape + # src[grpo_loss.py:N]: L = L_ADD_1 - 1 + L = L_ADD_1 - 1 + # src[grpo_loss.py:N]: logits = logits[:, :-1, :] # [B, L, V] + logits = logits[:, :-1, :] + # src[grpo_loss.py:N]: loss = torch.zeros([B, L], dtype=torch.float32, device=logits.device) + loss = torch.zeros([B, L], dtype=torch.float32, device=logits.device) + # src[grpo_loss.py:N]: is_clipped = torch.zeros([B, L], dtype=torch.float32, device=logits.device) + is_clipped = torch.zeros([B, L], dtype=torch.float32, device=logits.device) + # src[grpo_loss.py:N]: kl_loss = torch.zeros([B, L], dtype=torch.float32, device=logits.device) + kl_loss = torch.zeros([B, L], dtype=torch.float32, device=logits.device) + # src[grpo_loss.py:N]: lse = torch.zeros([B, L], dtype=torch.float32, device=logits.device) + lse = torch.zeros([B, L], dtype=torch.float32, device=logits.device) + # src[grpo_loss.py:N]: for tile_b, tile_l in hl.tile([B, L]): + _BLOCK_SIZE_0 = 4 + _BLOCK_SIZE_1 = 16 + # src[grpo_loss.py:N]: for tile_v in hl.tile(V): + # src[grpo_loss.py:N]: logits_tile = logits[tile_b, tile_l, tile_v].to(torch.float32) / temperature + # src[grpo_loss.py:N]: new_m_i = torch.maximum(max_logits, torch.amax(logits_tile, dim=-1)) + # src[grpo_loss.py:N-N]: ... + _BLOCK_SIZE_2 = 16 + # src[grpo_loss.py:N]: for tile_b, tile_l in hl.tile([B, L]): + # src[grpo_loss.py:N]: max_logits = hl.full([tile_b, tile_l], float("-inf"), dtype=torch.float32) + # src[grpo_loss.py:N]: sum_exp = hl.zeros([tile_b, tile_l], dtype=torch.float32) + # src[grpo_loss.py:N-N]: ... + _launcher(_helion_grpo_loss_forward, (triton.cdiv(4, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), logits, lse, selected_logits, old_logp, advantages, completion_mask, ref_logp, kl_loss, loss, is_clipped, temperature, eps_low, eps_high, beta, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) + # src[grpo_loss.py:N]: return loss, kl_loss, is_clipped, lse + return (loss, kl_loss, is_clipped, lse) + --- assertExpectedJournal(TestExamples.test_int4_gemm) from __future__ import annotations diff --git a/test/test_examples.py b/test/test_examples.py index 61eb0a07d..aecf53dd7 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -772,7 +772,6 @@ def test_jagged_mean(self): ) args = (x_data, x_offsets, feature_counts, M) - # Import and use the reference implementation mod = import_path(EXAMPLES_DIR / "jagged_mean.py") expected = mod.reference_jagged_mean_kernel_pytorch( x_data, x_offsets, feature_counts, M @@ -1636,6 +1635,157 @@ def test_squeeze_and_excitation_net_bwd_db(self): ) ) + def test_grpo_loss_fwd(self): + """Test forward pass for GRPO loss.""" + B, L, V = 4, 512, 2048 + temperature = 0.9 + beta = 0.04 + eps_low = 0.2 + eps_high = 0.4 + + torch.manual_seed(42) + logits = torch.randn([B, L + 1, V], device=DEVICE, dtype=torch.bfloat16) + completion_ids = torch.randint(0, V, (B, L), device=DEVICE, dtype=torch.int64) + old_logp = torch.randn(B, L, device=DEVICE, dtype=torch.float32) + ref_logp = torch.randn(B, L, device=DEVICE, dtype=torch.float32) + advantages = torch.randn(B, device=DEVICE, dtype=torch.float32) + completion_mask = torch.ones(B, L, device=DEVICE, dtype=torch.float32) + + from examples.grpo_loss import extract_selected_logits_pytorch + + selected_logits = extract_selected_logits_pytorch( + logits[:, :-1, :], completion_ids, temperature + ) + + from examples.grpo_loss import torch_grpo_loss + + expected_loss, expected_kl, expected_clipped = torch_grpo_loss( + logits.float(), + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + temperature, + beta, + eps_low, + eps_high, + ) + + args = ( + logits, + selected_logits, + old_logp, + ref_logp, + advantages, + completion_mask, + temperature, + beta, + eps_low, + eps_high, + ) + + # grpo_loss_forward returns (loss, kl_loss, is_clipped, lse) + # We only check loss, kl_loss, is_clipped (lse is None in expected) + expected = (expected_loss, expected_kl, expected_clipped, None) + + self.assertExpectedJournal( + check_example( + "grpo_loss", + args, + expected, + fn_name="grpo_loss_forward", + rtol=1e-2, + atol=1e-1, + ) + ) + + def test_grpo_loss_bwd(self): + """Test backward pass for GRPO loss.""" + B, L, V = 2, 64, 128 + temperature = 0.9 + beta = 0.04 + eps_low = 0.2 + eps_high = 0.4 + + torch.manual_seed(42) + logits = torch.randn( + [B, L + 1, V], device=DEVICE, dtype=torch.bfloat16, requires_grad=True + ) + completion_ids = torch.randint(0, V, (B, L), device=DEVICE, dtype=torch.int64) + old_logp = torch.randn(B, L, device=DEVICE, dtype=torch.float32) + ref_logp = torch.randn(B, L, device=DEVICE, dtype=torch.float32) + advantages = torch.randn(B, device=DEVICE, dtype=torch.float32) + completion_mask = torch.ones(B, L, device=DEVICE, dtype=torch.float32) + + # Pre-compute selected logits and run forward pass to get lse + from examples.grpo_loss import extract_selected_logits_pytorch + from examples.grpo_loss import grpo_loss_forward + + selected_logits = extract_selected_logits_pytorch( + logits[:, :-1, :], completion_ids, temperature + ) + + _, _, _, lse = grpo_loss_forward( + logits, + selected_logits, + old_logp, + ref_logp, + advantages, + completion_mask, + temperature, + beta, + eps_low, + eps_high, + ) + + grad_output = torch.randn(B, L, device=DEVICE, dtype=torch.float32) + + logits_torch = logits.detach().clone().float().requires_grad_(True) + from examples.grpo_loss import torch_grpo_loss + + loss_torch, _, _ = torch_grpo_loss( + logits_torch, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + temperature, + beta, + eps_low, + eps_high, + ) + loss_torch.backward(grad_output) + expected_grad = logits_torch.grad + + args = ( + grad_output, + logits, + selected_logits, + completion_ids, + old_logp, + ref_logp, + advantages, + completion_mask, + lse, + temperature, + beta, + eps_low, + eps_high, + ) + + self.assertExpectedJournal( + check_example( + "grpo_loss", + args, + expected_grad, + fn_name="grpo_loss_backward", + rtol=1e-2, + atol=1e-1, + ) + ) + if __name__ == "__main__": unittest.main()