From 01e8c958a2b896fc0aa089601daca175c8d346b5 Mon Sep 17 00:00:00 2001
From: PaulZhang12
Date: Tue, 30 Sep 2025 15:33:09 -0700
Subject: [PATCH] Faster Helion JSD

stack-info: PR: https://github.com/pytorch/helion/pull/733, branch: PaulZhang12/stack/9
---
 examples/jsd.py             |  59 +++++++------
 test/test_examples.expected | 159 +++++++++++++++++-------------------
 test/test_examples.py       |   2 +-
 3 files changed, 105 insertions(+), 115 deletions(-)

diff --git a/examples/jsd.py b/examples/jsd.py
index 29d306ea7..8d90a5050 100644
--- a/examples/jsd.py
+++ b/examples/jsd.py
@@ -63,11 +63,14 @@ def jsd_forward(
     assert target.shape == _input.shape, (
         f"Shape mismatch: {target.shape} != {_input.shape}"
     )
-    n_rows = BT
+    block_size_n = hl.register_block_size(V)
+    block_size_m = hl.register_block_size(BT)
 
     # Create output tensor for accumulating loss
-    loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
-    dX = torch.empty_like(_input)
+    loss = torch.zeros([BT], dtype=torch.float32, device=_input.device)
+    dX = torch.empty_like(loss)
+
+    one_minus_beta = 1 - beta
 
     # Count non-ignored elements
     n_non_ignore = float(BT)
@@ -79,60 +82,54 @@
         ), torch.zeros_like(_input)
 
     # Process each sequence position
-    BT_SIZE = helion.cdiv(BT, n_rows)  # The liger kernel uses 1
-    for tile_bt in hl.tile(BT, block_size=BT_SIZE):
+    for tile_bt in hl.tile(BT, block_size=block_size_m):
         # Check for label masking
         if shift_labels is not None:
             if shift_labels[tile_bt] == ignore_index:
                 for tile_X in hl.tile(V):
                     dX[tile_bt, tile_X] = 0.0
                 continue
-
-        for tile_v in hl.tile(V):
+        intermediate_loss = hl.zeros([tile_bt, block_size_n], dtype=torch.float32)
+        intermediate_dX = hl.zeros([tile_bt, block_size_n], dtype=_input.dtype)
+        for tile_v in hl.tile(V, block_size=block_size_n):
             # Load log probabilities and convert to float32
            X = _input[tile_bt, tile_v]
            Y = target[tile_bt, tile_v]
 
-            X_max = torch.amax(X, dim=0)
-            Y_max = torch.amax(Y, dim=0)
             if beta == 0.0:  # Forward KL: KL(P || Q)
+                Y_max = torch.amax(Y, dim=0)
                 Y_shift = Y - Y_max
                 Y_prob = torch.exp(Y_shift) * torch.exp(
                     Y_max
                 )  # Compensate for the shift
-                loss[tile_bt, tile_v] = Y_prob * (Y - X)
-                dX[tile_bt, tile_v] = -Y_prob
+                intermediate_loss += Y_prob * (Y - X)
+                intermediate_dX += -Y_prob
             elif beta == 1.0:  # Reverse KL: KL(Q || P)
+                X_max = torch.amax(X, dim=0)
                 X_shift = X - X_max
                 X_prob = torch.exp(X_shift) * torch.exp(
                     X_max
                 )  # Compensate for the shift
-                loss[tile_bt, tile_v] = X_prob * (X - Y)
-                dX[tile_bt, tile_v] = loss[tile_bt, tile_v] + X_prob
+                intermediate_loss += X_prob * (X - Y)
+                intermediate_dX += intermediate_loss + X_prob
             else:  # General JSD: beta*KL(P||M) + (1-beta)*KL(Q||M)
-                max_val = torch.maximum(X_max, Y_max)
-                X_shifted = X - max_val
-                Y_shifted = Y - max_val
-
-                exp_max = torch.exp(max_val)
-
-                Q = torch.exp(X_shifted) * exp_max  # = exp(X)
-                P = torch.exp(Y_shifted) * exp_max  # = exp(Y)
+                Q = torch.exp(X)  # = exp(X)
+                P = torch.exp(Y)  # = exp(Y)
 
                 beta_P = beta * P
-                one_minus_beta_Q = (1 - beta) * Q
+                one_minus_beta_Q = one_minus_beta * Q
                 M = beta_P + one_minus_beta_Q
-                log_M = torch.log(
-                    M
-                )  # No need to compensate as M is already in original scale
+                log_M = torch.log(M)
+                x_minus_log_m = X - log_M
+                kl_q_m = one_minus_beta_Q * x_minus_log_m
 
-                loss[tile_bt, tile_v] = beta_P * Y + one_minus_beta_Q * X - M * log_M
-                dX[tile_bt, tile_v] = one_minus_beta_Q * (X - log_M)
+                intermediate_loss += beta_P * (Y - log_M) + kl_q_m
+                intermediate_dX += kl_q_m
 
-            # Accumulate over vocabulary dimension
-            scale = 1.0 / n_non_ignore
-            loss[tile_bt, tile_v] = loss[tile_bt, tile_v] * scale
-            dX[tile_bt, tile_v] = dX[tile_bt, tile_v] * scale
+        # Accumulate over vocabulary dimension
+        scale = 1.0 / n_non_ignore
+        loss[tile_bt] = torch.sum(intermediate_loss * scale, dim=1)
+        dX[tile_bt] = torch.sum(intermediate_dX * scale, dim=1)
 
     # Normalize by number of non-ignored elements, run it on host to match liger_kernel
     final_loss = torch.sum(
diff --git a/test/test_examples.expected b/test/test_examples.expected
index 3e2bc61f0..0ad5bd037 100644
--- a/test/test_examples.expected
+++ b/test/test_examples.expected
@@ -2121,121 +2121,116 @@ def jagged_sum_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, *, _launche
 from __future__ import annotations
 
 import torch
-import helion
 import triton
 import triton.language as tl
-from torch._inductor.runtime import triton_helpers
 from torch._inductor.runtime.triton_helpers import math as tl_math
 from torch._inductor.runtime.triton_compat import libdevice
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _helion_jsd_forward(_input, target, loss, dX, _input_stride_0, _input_stride_1, dX_stride_0, dX_stride_1, loss_stride_0, loss_stride_1, target_stride_0, target_stride_1, BT, V, beta, n_non_ignore, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+def _helion_jsd_forward(_input, target, loss, dX, _input_stride_0, _input_stride_1, dX_stride_0, loss_stride_0, target_stride_0, target_stride_1, BT, V, beta, one_minus_beta, n_non_ignore, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
     pid_0 = tl.program_id(0)
-    offset_0 = pid_0 * _BLOCK_SIZE_0
-    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
-    mask_0 = indices_0 < BT
-    for offset_1 in tl.range(0, V.to(tl.int32), _BLOCK_SIZE_1):
-        indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
-        mask_1 = indices_1 < V
-        X = tl.load(_input + (indices_0[:, None] * _input_stride_0 + indices_1[None, :] * _input_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
-        Y = tl.load(target + (indices_0[:, None] * target_stride_0 + indices_1[None, :] * target_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
-        _mask_to = tl.where(mask_0[:, None] & mask_1[None, :], X, tl.full([], float('-inf'), tl.float32))
-        X_max = tl.cast(tl.max(_mask_to, 0), tl.float32)
-        _mask_to_1 = tl.where(mask_0[:, None] & mask_1[None, :], Y, tl.full([], float('-inf'), tl.float32))
-        Y_max = tl.cast(tl.max(_mask_to_1, 0), tl.float32)
+    offset_1 = pid_0 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < BT
+    intermediate_loss = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
+    intermediate_dX = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
+    for offset_0 in tl.range(0, V.to(tl.int32)):
+        indices_0 = offset_0 + tl.arange(0, 1).to(tl.int32)
+        intermediate_loss_copy = intermediate_loss
+        intermediate_dX_copy = intermediate_dX
+        intermediate_loss = intermediate_loss_copy
+        intermediate_dX = intermediate_dX_copy
+        X = tl.load(_input + (indices_1[:, None] * _input_stride_0 + indices_0[None, :] * _input_stride_1), mask_1[:, None], other=0)
+        Y = tl.load(target + (indices_1[:, None] * target_stride_0 + indices_0[None, :] * target_stride_1), mask_1[:, None], other=0)
         eq = beta == 0.0
         if eq:
             Y_copy = Y
-            Y_max_copy = Y_max
             X_copy = X
+            intermediate_loss_copy_0_copy = intermediate_loss
+            intermediate_dX_copy_0_copy = intermediate_dX
             Y_copy_0 = Y_copy
-            Y_max_copy_0 = Y_max_copy
             X_copy_0 = X_copy
-            v_0 = Y_max_copy_0[None, :]
+            intermediate_loss_copy_0_copy_0 = intermediate_loss_copy_0_copy
+            intermediate_dX_copy_0_copy_0 = intermediate_dX_copy_0_copy
+            _mask_to = tl.where(tl.broadcast_to(mask_1[:, None], [_BLOCK_SIZE_1, _BLOCK_SIZE_0]), Y_copy_0, tl.full([], float('-inf'), tl.float32))
+            Y_max = tl.cast(tl.max(_mask_to, 0), tl.float32)
+            v_0 = Y_max[None, :]
             v_1 = Y_copy_0 - v_0
             v_2 = libdevice.exp(v_1)
-            v_3 = libdevice.exp(Y_max_copy_0)
+            v_3 = libdevice.exp(Y_max)
             v_4 = v_3[None, :]
             v_5 = v_2 * v_4
             v_6 = Y_copy_0 - X_copy_0
             v_7 = v_5 * v_6
-            tl.store(loss + (indices_0[:, None] * loss_stride_0 + indices_1[None, :] * loss_stride_1), v_7, mask_0[:, None] & mask_1[None, :])
-            v_8 = -v_5
-            tl.store(dX + (indices_0[:, None] * dX_stride_0 + indices_1[None, :] * dX_stride_1), v_8, mask_0[:, None] & mask_1[None, :])
+            intermediate_loss = intermediate_loss_copy_0_copy_0 + v_7
+            v_9 = -v_5
+            intermediate_dX = intermediate_dX_copy_0_copy_0 + v_9
         _not = not eq
         if _not:
             X_copy_1 = X
-            X_max_copy = X_max
             Y_copy_1 = Y
-            Y_max_copy_1 = Y_max
+            intermediate_loss_copy_0_copy_1 = intermediate_loss
+            intermediate_dX_copy_0_copy_1 = intermediate_dX
             X_copy_1_0 = X_copy_1
-            X_max_copy_0 = X_max_copy
             Y_copy_1_0 = Y_copy_1
-            Y_max_copy_1_0 = Y_max_copy_1
+            intermediate_loss = intermediate_loss_copy_0_copy_1
+            intermediate_dX = intermediate_dX_copy_0_copy_1
             eq_1 = beta == 1.0
             if eq_1:
                 X_copy_1_0_copy = X_copy_1_0
-                X_max_copy_0_copy = X_max_copy_0
                 Y_copy_1_0_copy = Y_copy_1_0
+                intermediate_loss_copy_0_copy_1_0_copy = intermediate_loss
+                intermediate_dX_copy_0_copy_1_0_copy = intermediate_dX
                 X_copy_1_0_copy_0 = X_copy_1_0_copy
-                X_max_copy_0_copy_0 = X_max_copy_0_copy
                 Y_copy_1_0_copy_0 = Y_copy_1_0_copy
-                v_9 = X_max_copy_0_copy_0[None, :]
-                v_10 = X_copy_1_0_copy_0 - v_9
-                v_11 = libdevice.exp(v_10)
-                v_12 = libdevice.exp(X_max_copy_0_copy_0)
-                v_13 = v_12[None, :]
-                v_14 = v_11 * v_13
-                v_15 = X_copy_1_0_copy_0 - Y_copy_1_0_copy_0
-                v_16 = v_14 * v_15
-                tl.store(loss + (indices_0[:, None] * loss_stride_0 + indices_1[None, :] * loss_stride_1), v_16, mask_0[:, None] & mask_1[None, :])
-                load = tl.load(loss + (indices_0[:, None] * loss_stride_0 + indices_1[None, :] * loss_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
-                v_17 = load + v_14
-                tl.store(dX + (indices_0[:, None] * dX_stride_0 + indices_1[None, :] * dX_stride_1), v_17, mask_0[:, None] & mask_1[None, :])
+                intermediate_loss_copy_0_copy_1_0_copy_0 = intermediate_loss_copy_0_copy_1_0_copy
+                intermediate_dX_copy_0_copy_1_0_copy_0 = intermediate_dX_copy_0_copy_1_0_copy
+                _mask_to_1 = tl.where(tl.broadcast_to(mask_1[:, None], [_BLOCK_SIZE_1, _BLOCK_SIZE_0]), X_copy_1_0_copy_0, tl.full([], float('-inf'), tl.float32))
+                X_max = tl.cast(tl.max(_mask_to_1, 0), tl.float32)
+                v_11 = X_max[None, :]
+                v_12 = X_copy_1_0_copy_0 - v_11
+                v_13 = libdevice.exp(v_12)
+                v_14 = libdevice.exp(X_max)
+                v_15 = v_14[None, :]
+                v_16 = v_13 * v_15
+                v_17 = X_copy_1_0_copy_0 - Y_copy_1_0_copy_0
+                v_18 = v_16 * v_17
+                intermediate_loss = intermediate_loss_copy_0_copy_1_0_copy_0 + v_18
+                v_20 = intermediate_loss + v_16
+                intermediate_dX = intermediate_dX_copy_0_copy_1_0_copy_0 + v_20
            _not_1 = not eq_1
            if _not_1:
-                X_max_copy_0_copy_1 = X_max_copy_0
-                Y_max_copy_1_0_copy = Y_max_copy_1_0
                 X_copy_1_0_copy_1 = X_copy_1_0
                 Y_copy_1_0_copy_1 = Y_copy_1_0
-                X_max_copy_0_copy_1_0 = X_max_copy_0_copy_1
-                Y_max_copy_1_0_copy_0 = Y_max_copy_1_0_copy
+                intermediate_loss_copy_0_copy_1_0_copy_1 = intermediate_loss
+                intermediate_dX_copy_0_copy_1_0_copy_1 = intermediate_dX
                 X_copy_1_0_copy_1_0 = X_copy_1_0_copy_1
                 Y_copy_1_0_copy_1_0 = Y_copy_1_0_copy_1
-                v_18 = triton_helpers.maximum(X_max_copy_0_copy_1_0, Y_max_copy_1_0_copy_0)
-                v_19 = v_18[None, :]
-                v_20 = X_copy_1_0_copy_1_0 - v_19
-                v_21 = v_18[None, :]
-                v_22 = Y_copy_1_0_copy_1_0 - v_21
-                v_23 = libdevice.exp(v_18)
-                v_24 = libdevice.exp(v_20)
-                v_25 = v_23[None, :]
-                v_26 = v_24 * v_25
-                v_27 = libdevice.exp(v_22)
-                v_28 = v_23[None, :]
-                v_29 = v_27 * v_28
-                v_30 = v_29 * beta
-                sub_2 = 1.0 + -1 * beta
-                v_31 = v_26 * sub_2
-                v_32 = v_30 + v_31
-                v_33 = tl_math.log(v_32)
-                v_34 = v_30 * Y_copy_1_0_copy_1_0
-                v_35 = v_31 * X_copy_1_0_copy_1_0
-                v_36 = v_34 + v_35
-                v_37 = v_32 * v_33
-                v_38 = v_36 - v_37
-                tl.store(loss + (indices_0[:, None] * loss_stride_0 + indices_1[None, :] * loss_stride_1), v_38, mask_0[:, None] & mask_1[None, :])
-                v_39 = X_copy_1_0_copy_1_0 - v_33
-                v_40 = v_31 * v_39
-                tl.store(dX + (indices_0[:, None] * dX_stride_0 + indices_1[None, :] * dX_stride_1), v_40, mask_0[:, None] & mask_1[None, :])
-        truediv = 1.0 / n_non_ignore
-        load_2 = tl.load(loss + (indices_0[:, None] * loss_stride_0 + indices_1[None, :] * loss_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
-        v_41 = load_2 * truediv
-        tl.store(loss + (indices_0[:, None] * loss_stride_0 + indices_1[None, :] * loss_stride_1), v_41, mask_0[:, None] & mask_1[None, :])
-        load_3 = tl.load(dX + (indices_0[:, None] * dX_stride_0 + indices_1[None, :] * dX_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
-        v_42 = load_3 * truediv
-        tl.store(dX + (indices_0[:, None] * dX_stride_0 + indices_1[None, :] * dX_stride_1), v_42, mask_0[:, None] & mask_1[None, :])
+                intermediate_loss_copy_0_copy_1_0_copy_1_0 = intermediate_loss_copy_0_copy_1_0_copy_1
+                intermediate_dX_copy_0_copy_1_0_copy_1_0 = intermediate_dX_copy_0_copy_1_0_copy_1
+                v_22 = libdevice.exp(X_copy_1_0_copy_1_0)
+                v_23 = libdevice.exp(Y_copy_1_0_copy_1_0)
+                v_24 = v_23 * beta
+                v_25 = v_22 * one_minus_beta
+                v_26 = v_24 + v_25
+                v_27 = tl_math.log(v_26)
+                v_28 = X_copy_1_0_copy_1_0 - v_27
+                v_29 = v_25 * v_28
+                v_30 = Y_copy_1_0_copy_1_0 - v_27
+                v_31 = v_24 * v_30
+                v_32 = v_31 + v_29
+                intermediate_loss = intermediate_loss_copy_0_copy_1_0_copy_1_0 + v_32
+                intermediate_dX = intermediate_dX_copy_0_copy_1_0_copy_1_0 + v_29
+    truediv = 1.0 / n_non_ignore
+    v_35 = intermediate_loss * truediv
+    _mask_to_2 = tl.where(tl.broadcast_to(mask_1[:, None], [_BLOCK_SIZE_1, _BLOCK_SIZE_0]), v_35, tl.full([], 0, tl.float32))
+    sum_1 = tl.cast(tl.sum(_mask_to_2, 1), tl.float32)
+    tl.store(loss + indices_1 * loss_stride_0, sum_1, mask_1)
+    v_36 = intermediate_dX * truediv
+    _mask_to_3 = tl.where(tl.broadcast_to(mask_1[:, None], [_BLOCK_SIZE_1, _BLOCK_SIZE_0]), v_36, tl.full([], 0, tl.float32))
+    sum_2 = tl.cast(tl.sum(_mask_to_3, 1), tl.float32)
+    tl.store(dX + indices_1 * dX_stride_0, sum_2, mask_1)
 
 def jsd_forward(_input: Tensor, target: Tensor, shift_labels: Tensor | None=None, beta: float=0.5, ignore_index: int=-100, *, _launcher=_default_launcher):
     """
@@ -2254,18 +2249,16 @@
     """
     BT, V = _input.shape
     assert target.shape == _input.shape, f'Shape mismatch: {target.shape} != {_input.shape}'
-    n_rows = BT
-    loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
-    dX = torch.empty_like(_input)
+    loss = torch.zeros([BT], dtype=torch.float32, device=_input.device)
+    dX = torch.empty_like(loss)
+    one_minus_beta = 1 - beta
     n_non_ignore = float(BT)
     if shift_labels is not None:
         n_non_ignore = float((shift_labels != ignore_index).sum().item())
     if n_non_ignore == 0:
         return (torch.zeros([], dtype=_input.dtype, device=_input.device), torch.zeros_like(_input))
-    BT_SIZE = helion.cdiv(BT, n_rows)
-    _BLOCK_SIZE_0 = BT_SIZE
     _BLOCK_SIZE_1 = 4096
-    _launcher(_helion_jsd_forward, (triton.cdiv(BT, _BLOCK_SIZE_0),), _input, target, loss, dX, _input.stride(0), _input.stride(1), dX.stride(0), dX.stride(1), loss.stride(0), loss.stride(1), target.stride(0), target.stride(1), BT, V, beta, n_non_ignore, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    _launcher(_helion_jsd_forward, (triton.cdiv(BT, _BLOCK_SIZE_1),), _input, target, loss, dX, _input.stride(0), _input.stride(1), dX.stride(0), loss.stride(0), target.stride(0), target.stride(1), BT, V, beta, one_minus_beta, n_non_ignore, _BLOCK_SIZE_1, 1, num_warps=4, num_stages=3)
     final_loss = torch.sum(loss)
     return (final_loss, dX)
 
diff --git a/test/test_examples.py b/test/test_examples.py
index b0a7df3d1..81015cee6 100644
--- a/test/test_examples.py
+++ b/test/test_examples.py
@@ -1116,7 +1116,7 @@ def test_jsd(self):
                 args,
                 (expected(*args), None),
                 fn_name="jsd_forward",
-                block_sizes=[4096],
+                block_sizes=[1, 4096],
                 num_warps=4,
                 num_stages=3,
             )
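
For reference, not part of the patch: for 0 < beta < 1 the rewritten loop keeps
the running sums beta*KL(P || M) + (1 - beta)*KL(Q || M) in registers
(intermediate_loss / intermediate_dX) and performs a single reduction and store
per row, instead of a load/store round-trip to global memory for every
vocabulary tile. Below is a minimal eager-PyTorch sketch of the same math,
useful when checking the kernel numerically. The helper name
jsd_forward_reference is illustrative only (it does not exist in the Helion
codebase), and the sketch covers just the general-beta branch with no label
masking, so n_non_ignore == BT.

import torch

def jsd_forward_reference(_input: torch.Tensor, target: torch.Tensor, beta: float = 0.5):
    """Eager sketch of the generalized JSD above, for 0 < beta < 1, no masking.

    _input and target hold log-probabilities, as in the example:
    Q = exp(_input), P = exp(target), M = beta * P + (1 - beta) * Q.
    Loss and dX are scaled by 1 / BT and reduced over the vocab dimension,
    matching the patched [BT]-shaped loss/dX buffers.
    """
    BT, V = _input.shape
    X, Y = _input.float(), target.float()
    Q, P = torch.exp(X), torch.exp(Y)
    M = beta * P + (1 - beta) * Q
    log_M = torch.log(M)
    kl_q_m = (1 - beta) * Q * (X - log_M)        # (1 - beta) * KL(Q || M) term
    elem_loss = beta * P * (Y - log_M) + kl_q_m  # plus beta * KL(P || M) term
    scale = 1.0 / BT                             # n_non_ignore == BT when nothing is masked
    loss = torch.sum(elem_loss * scale, dim=1)   # per-row loss, shape [BT]
    dX = torch.sum(kl_q_m * scale, dim=1)        # per-row dX, shape [BT]
    return torch.sum(loss), dX

Comparing torch.sum(loss) and dX from this sketch against the outputs of the
patched jsd_forward (e.g. via torch.testing.assert_close on log-softmaxed
random inputs) is one way to sanity-check the register-accumulation rewrite.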