From 9ff1f7e80b10ce6cd2788c31ab0803b5fdf86dda Mon Sep 17 00:00:00 2001
From: William Wen
Date: Fri, 3 Oct 2025 12:05:37 -0700
Subject: [PATCH 1/2] add KL divergence backward helion kernel

---
 examples/kl_div.py | 149 +++++++++++++++++++++++++++++++++++++++------
 1 file changed, 131 insertions(+), 18 deletions(-)

diff --git a/examples/kl_div.py b/examples/kl_div.py
index 5ed884da5..57794df97 100644
--- a/examples/kl_div.py
+++ b/examples/kl_div.py
@@ -23,7 +23,9 @@
 # -------
 from __future__ import annotations
 
+import math
 from typing import TYPE_CHECKING
+from typing import Any
 
 import torch
 from torch import Tensor
@@ -117,6 +119,106 @@ def kl_div_forward(
     return final_loss
 
 
+@helion.kernel
+def kl_div_backward(
+    grad_out: Tensor,
+    y_pred: Tensor,  # input predictions in log-space, shape (BT, V)
+    y_true: Tensor,  # target values, shape (BT, V)
+    log_target: hl.constexpr,
+    reduction: hl.constexpr,
+    eps: hl.constexpr,
+    compute_y_true_grad: hl.constexpr,
+) -> tuple[Tensor, Tensor | None]:
+    BT, V = y_pred.shape
+    assert y_true.shape == y_pred.shape, (
+        f"Shape mismatch: {y_true.shape} != {y_pred.shape}"
+    )
+
+    grad_y_pred = torch.empty_like(y_pred)
+    if compute_y_true_grad:
+        grad_y_true = torch.empty_like(y_true)
+    else:
+        grad_y_true = None
+
+    if reduction == "none":
+        grad_out_expanded = grad_out
+    else:
+        grad_out_expanded = grad_out.expand(y_true.shape)
+
+    log_eps = math.log(eps)
+    for tile_bt in hl.tile(BT):
+        for tile_v in hl.tile(V):
+            grad_out_val = grad_out_expanded[tile_bt, tile_v]
+            y_true_val = y_true[tile_bt, tile_v]
+
+            if log_target:
+                y_true_exp = torch.exp(y_true_val)
+
+            if reduction == "batchmean":
+                div = BT
+            elif reduction == "mean":
+                div = BT * V
+            else:  # reduction == "sum" or "none"
+                div = 1.0
+
+            if log_target:
+                grad_y_pred[tile_bt, tile_v] = -grad_out_val * y_true_exp / div  # type: ignore
+            else:
+                grad_y_pred[tile_bt, tile_v] = -grad_out_val * y_true_val / div
+
+            if compute_y_true_grad:
+                y_pred_val = y_pred[tile_bt, tile_v]
+                if log_target:
+                    tmp = y_true_exp * (y_true_val - y_pred_val + 1)  # type: ignore
+                else:
+                    lt_eps = log_eps - y_pred_val
+                    gt_eps = torch.log(y_true_val) - y_pred_val + 1
+                    tmp = torch.where(y_true_val < eps, lt_eps, gt_eps)
+
+                grad_y_true[tile_bt, tile_v] = grad_out_val * tmp / div  # type: ignore[index]
+
+    return grad_y_pred, grad_y_true
+
+
+class KLDivFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: Any,  # noqa: ANN401
+        y_pred: Tensor,  # input predictions in log-space, shape (BT, V)
+        y_true: Tensor,  # target values, shape (BT, V)
+        log_target: bool,
+        reduction: str,
+        eps: float,
+    ) -> Tensor:
+        """Forward pass for KL divergence."""
+        loss = kl_div_forward(y_pred, y_true, log_target, reduction, eps)
+        ctx.save_for_backward(y_pred, y_true)  # type: ignore[arg-type]
+        ctx.log_target = log_target
+        ctx.reduction = reduction
+        ctx.eps = eps
+        return loss
+
+    @staticmethod
+    def backward(  # type: ignore[override]
+        ctx: Any,  # noqa: ANN401
+        grad_out: Tensor,
+    ) -> tuple[Tensor, Tensor | None, None, None, None]:
+        """Backward pass for KL divergence."""
+        y_pred, y_true = ctx.saved_tensors  # type: ignore[attr-defined]
+
+        grad_y_pred, grad_y_true = kl_div_backward(
+            grad_out,
+            y_pred,
+            y_true,
+            ctx.log_target,
+            ctx.reduction,
+            ctx.eps,
+            y_true.requires_grad,
+        )
+
+        return grad_y_pred, grad_y_true, None, None, None
+
+
 # %%
 # KL Divergence Loss Module
 # -------------------------
@@ -154,7 +256,7 @@ def forward(self, input_tensor: Tensor, target_tensor: Tensor) -> Tensor:
 
         Returns:
             KL divergence loss
         """
-        return kl_div_forward(
+        return KLDivFunction.apply(  # type: ignore[no-any-return]
             input_tensor, target_tensor, self.log_target, self.reduction, self.eps
         )
@@ -181,16 +283,26 @@ def check_kl_div_kernel(
         log_target: Whether target is in log-space
         eps: Small value for numerical stability
     """
-    # Create test tensors following tritonbench pattern
-    input_tensor = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax(
-        dim=-1
-    )
-    target_tensor = torch.randn(B * T, V, device="cuda").softmax(dim=-1)
-
-    # Test forward pass
+    # Create test tensors following tritonbench pattern
+    def create_inputs() -> tuple[Tensor, Tensor]:
+        input_tensor = torch.randn(
+            B * T, V, requires_grad=True, device="cuda"
+        ).log_softmax(dim=-1)
+        input_tensor.retain_grad()
+
+        target_tensor = torch.randn(B * T, V, requires_grad=True, device="cuda")
+        if log_target:
+            target_tensor = target_tensor.log_softmax(dim=-1)
+        else:
+            target_tensor = target_tensor.softmax(dim=-1)
+        target_tensor.retain_grad()
+
+        return input_tensor, target_tensor
+
+    # Test forward + backward pass
     helion_kl = HelionKLDivLoss(reduction=reduction, log_target=log_target, eps=eps)
-    torch_kl_div = torch.nn.KLDivLoss(reduction="batchmean", log_target=log_target).to(
+    torch_kl_div = torch.nn.KLDivLoss(reduction=reduction, log_target=log_target).to(
         "cuda"
     )
@@ -200,7 +312,8 @@ def helion_wrapper(input_tensor: Tensor, target_tensor: Tensor) -> Tensor:
     def baseline_wrapper(input_tensor: Tensor, target_tensor: Tensor) -> Tensor:
         return torch_kl_div(input_tensor, target_tensor)
 
-    run_example(helion_wrapper, baseline_wrapper, (input_tensor, target_tensor))
+    run_example(helion_wrapper, baseline_wrapper, create_inputs())
+    run_example(helion_wrapper, baseline_wrapper, create_inputs(), bwd=True)
 
 
 # %%
@@ -240,17 +353,17 @@ def main() -> None:
     print("Testing KL divergence kernel...")
     B = 8
     T = 512
-    reduction = "batchmean"
-    log_target = False
     eps = 1e-10
 
     # Test with vocabulary sizes from tritonbench (2^12 to 2^17)
-    for V in [2**i for i in range(12, 18)]:
-        print(
-            f"Testing KL Div: B={B}, T={T}, V={V}, reduction={reduction}, log_target={log_target}"
-        )
-        check_kl_div_kernel(B, T, V, reduction, log_target, eps)
-        print("✓ KL Div passed")
+    for log_target in (True, False):
+        for reduction in ("batchmean", "mean", "sum"):
+            for V in [2**i for i in range(12, 17)]:
+                print(
+                    f"Testing KL Div: B={B}, T={T}, V={V}, reduction={reduction}, log_target={log_target}"
+                )
+                check_kl_div_kernel(B, T, V, reduction, log_target, eps)
+                print("✓ KL Div passed")
 
 
 # %%

From 5a35a0ff8ee4af18e17bd7738a064e1ac6942bc9 Mon Sep 17 00:00:00 2001
From: William Wen
Date: Wed, 8 Oct 2025 15:21:54 -0700
Subject: [PATCH 2/2] add tests

---
 benchmarks/run.py           | 14 +++++++++
 examples/kl_div.py          | 12 ++++----
 test/test_examples.expected | 58 +++++++++++++++++++++++++++++++++++++
 test/test_examples.py       | 48 +++++++++++++++++++++++++++++-
 4 files changed, 125 insertions(+), 7 deletions(-)

diff --git a/benchmarks/run.py b/benchmarks/run.py
index b2db48034..705cbc737 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -131,6 +131,11 @@ class RunResult:
         "examples.kl_div",
         "kl_div_tritonbench",
     ),
+    "kl_div-bwd": (
+        "tritonbench.operators.kl_div.operator",
+        "examples.kl_div",
+        "kl_div_tritonbench",
+    ),
     "ragged_attention": (
         "tritonbench.operators.ragged_attention.operator",
         "examples.jagged_hstu_attn",
@@ -410,6 +415,15 @@
         "helion_kl_div_tritonbench-speedup": "helion_speedup",
         "helion_kl_div_tritonbench-accuracy": "helion_accuracy",
     },
"kl_div-bwd": { + "torch_kl_div": "baseline", + "liger_kl_div-speedup": "triton_speedup", + "liger_kl_div-accuracy": "triton_accuracy", + "torch_compile_kl_div-speedup": "torch_compile_speedup", + "torch_compile_kl_div-accuracy": "torch_compile_accuracy", + "helion_kl_div_tritonbench-speedup": "helion_speedup", + "helion_kl_div_tritonbench-accuracy": "helion_accuracy", + }, "gather_gemv": { "eager_gather_gemv": "baseline", "triton_gather_gemv-speedup": "triton_speedup", diff --git a/examples/kl_div.py b/examples/kl_div.py index 57794df97..dd78daf95 100644 --- a/examples/kl_div.py +++ b/examples/kl_div.py @@ -124,10 +124,10 @@ def kl_div_backward( grad_out: Tensor, y_pred: Tensor, # input predictions in log-space, shape (BT, V) y_true: Tensor, # target values, shape (BT, V) - log_target: hl.constexpr, - reduction: hl.constexpr, - eps: hl.constexpr, - compute_y_true_grad: hl.constexpr, + log_target: hl.constexpr = False, # type: ignore[arg-type] + reduction: hl.constexpr = "batchmean", # type: ignore[arg-type] + eps: hl.constexpr = 1e-10, # type: ignore[arg-type] + compute_y_true_grad: hl.constexpr = True, # type: ignore[arg-type] ) -> tuple[Tensor, Tensor | None]: BT, V = y_pred.shape assert y_true.shape == y_pred.shape, ( @@ -162,14 +162,14 @@ def kl_div_backward( div = 1.0 if log_target: - grad_y_pred[tile_bt, tile_v] = -grad_out_val * y_true_exp / div # type: ignore + grad_y_pred[tile_bt, tile_v] = -grad_out_val * y_true_exp / div # type: ignore[possibly-undefined] else: grad_y_pred[tile_bt, tile_v] = -grad_out_val * y_true_val / div if compute_y_true_grad: y_pred_val = y_pred[tile_bt, tile_v] if log_target: - tmp = y_true_exp * (y_true_val - y_pred_val + 1) # type: ignore + tmp = y_true_exp * (y_true_val - y_pred_val + 1) # type: ignore[possibly-undefined] else: lt_eps = log_eps - y_pred_val gt_eps = torch.log(y_true_val) - y_pred_val + 1 diff --git a/test/test_examples.expected b/test/test_examples.expected index 2c004def2..df78d0732 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -2352,6 +2352,64 @@ def kl_div_forward(y_pred: Tensor, y_true: Tensor, log_target: bool=False, reduc final_loss = loss return final_loss +--- assertExpectedJournal(TestExamples.test_kl_div_bwd) +from __future__ import annotations + +import torch +import helion.language as hl +import triton +import triton.language as tl +from torch._inductor.runtime.triton_helpers import math as tl_math +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_kl_div_backward(grad_out_expanded, y_true, grad_y_pred, y_pred, grad_y_true, grad_out_expanded_stride_0, grad_out_expanded_stride_1, grad_y_pred_stride_0, grad_y_pred_stride_1, grad_y_true_stride_0, grad_y_true_stride_1, y_pred_stride_0, y_pred_stride_1, y_true_stride_0, y_true_stride_1, BT, V, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < BT + for offset_1 in tl.range(0, V.to(tl.int32), _BLOCK_SIZE_1): + indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) + mask_1 = indices_1 < V + grad_out_val = tl.load(grad_out_expanded + (indices_0[:, None] * grad_out_expanded_stride_0 + indices_1[None, :] * grad_out_expanded_stride_1), mask_0[:, None] & mask_1[None, :], other=0) + y_true_val = tl.load(y_true + (indices_0[:, None] * y_true_stride_0 + indices_1[None, :] * y_true_stride_1), mask_0[:, None] & mask_1[None, :], other=0) + v_0 = 
-grad_out_val + v_1 = v_0 * y_true_val + v_2 = tl.cast(BT, tl.float32) + v_3 = v_1 / v_2 + tl.store(grad_y_pred + (indices_0[:, None] * grad_y_pred_stride_0 + indices_1[None, :] * grad_y_pred_stride_1), v_3, mask_0[:, None] & mask_1[None, :]) + y_pred_val = tl.load(y_pred + (indices_0[:, None] * y_pred_stride_0 + indices_1[None, :] * y_pred_stride_1), mask_0[:, None] & mask_1[None, :], other=0) + v_4 = -23.025850929940457 + v_5 = v_4 - y_pred_val + v_6 = tl_math.log(y_true_val) + v_7 = v_6 - y_pred_val + v_8 = 1.0 + v_9 = v_7 + v_8 + v_10 = 1e-10 + v_11 = y_true_val < v_10 + v_12 = tl.where(v_11, v_5, v_9) + v_13 = grad_out_val * v_12 + v_14 = tl.cast(BT, tl.float32) + v_15 = v_13 / v_14 + tl.store(grad_y_true + (indices_0[:, None] * grad_y_true_stride_0 + indices_1[None, :] * grad_y_true_stride_1), v_15, mask_0[:, None] & mask_1[None, :]) + +def kl_div_backward(grad_out: Tensor, y_pred: Tensor, y_true: Tensor, log_target: hl.constexpr=False, reduction: hl.constexpr='batchmean', eps: hl.constexpr=1e-10, compute_y_true_grad: hl.constexpr=True, *, _launcher=_default_launcher): + BT, V = y_pred.shape + assert y_true.shape == y_pred.shape, f'Shape mismatch: {y_true.shape} != {y_pred.shape}' + grad_y_pred = torch.empty_like(y_pred) + if True: + grad_y_true = torch.empty_like(y_true) + else: + grad_y_true = None + if 'batchmean' == 'none': + grad_out_expanded = grad_out + else: + grad_out_expanded = grad_out.expand(y_true.shape) + _BLOCK_SIZE_0 = 64 + _BLOCK_SIZE_1 = 64 + _launcher(_helion_kl_div_backward, (triton.cdiv(BT, _BLOCK_SIZE_0),), grad_out_expanded, y_true, grad_y_pred, y_pred, grad_y_true, grad_out_expanded.stride(0), grad_out_expanded.stride(1), grad_y_pred.stride(0), grad_y_pred.stride(1), grad_y_true.stride(0), grad_y_true.stride(1), y_pred.stride(0), y_pred.stride(1), y_true.stride(0), y_true.stride(1), BT, V, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3) + return (grad_y_pred, grad_y_true) + --- assertExpectedJournal(TestExamples.test_layernorm_bwd_dwdb) from __future__ import annotations diff --git a/test/test_examples.py b/test/test_examples.py index b0a7df3d1..94a3eb548 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -1122,7 +1122,7 @@ def test_jsd(self): ) ) - def test_kl_div(self): + def test_kl_div_fwd(self): args = ( torch.randn( [8 * 512, 4096], device=DEVICE, dtype=torch.float32 @@ -1146,6 +1146,52 @@ def test_kl_div(self): ) ) + def test_kl_div_bwd(self): + y_pred = torch.randn( + [8 * 512, 4096], device=DEVICE, dtype=torch.float32 + ).log_softmax(dim=-1) + y_true = torch.randn( + [8 * 512, 4096], device=DEVICE, dtype=torch.float32 + ).softmax(dim=-1) + grad_out = torch.randn([], device=DEVICE, dtype=torch.float32) + log_target = False + reduction = "batchmean" + eps = 1e-10 + + # Compute forward pass to get rms + from examples.kl_div import kl_div_forward + + # Create configured kernel with explicit config + config = helion.Config(block_size=32, num_warps=4, num_stages=3) + configured_kernel = helion.kernel(kl_div_forward.fn, config=config) + _ = configured_kernel(y_pred, y_true, log_target, reduction, eps) + + # Compute expected gradients with PyTorch + y_pred_torch = y_pred.detach().clone().requires_grad_(True) + y_true_torch = y_true.detach().clone().requires_grad_(True) + loss_torch = torch.nn.functional.kl_div( + y_pred_torch, y_true_torch, log_target=log_target, reduction=reduction + ) + loss_torch.backward(grad_out) + + args = ( + grad_out, + y_pred, + y_true, + ) + + self.assertExpectedJournal( + check_example( + "kl_div", + args, + 
+                (y_pred_torch.grad, y_true_torch.grad),
+                fn_name="kl_div_backward",
+                block_sizes=[64, 64],
+                num_warps=4,
+                num_stages=3,
+            )
+        )
+
     def test_gather_gemv(self):
         args = (
             torch.randn([8, 1024, 1024], device=DEVICE, dtype=torch.float32),
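
Supplementary note (not part of the patch above): the closed-form gradients that kl_div_backward implements can be sanity-checked against PyTorch autograd. A minimal sketch, assuming log_target=False and reduction="batchmean" (the kernel defaults introduced in the second commit); the shapes are illustrative only.

    import torch

    BT, V = 32, 128
    y_pred = torch.randn(BT, V).log_softmax(dim=-1).requires_grad_(True)  # predictions in log-space
    y_true = torch.randn(BT, V).softmax(dim=-1).requires_grad_(True)      # probability-space target

    # Reference loss and autograd gradients.
    loss = torch.nn.functional.kl_div(y_pred, y_true, reduction="batchmean", log_target=False)
    loss.backward()

    # Closed-form gradients used by the kernel for this case:
    #   dL/dy_pred = -y_true / BT
    #   dL/dy_true = (log(y_true) - y_pred + 1) / BT   (the kernel clamps y_true below eps)
    # With log_target=True the kernel multiplies by exp(y_true) instead.
    with torch.no_grad():
        grad_y_pred_ref = -y_true / BT
        grad_y_true_ref = (torch.log(y_true) - y_pred + 1) / BT

    torch.testing.assert_close(y_pred.grad, grad_y_pred_ref)
    torch.testing.assert_close(y_true.grad, grad_y_true_ref)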