examples/jsd.py (28 additions, 31 deletions)

@@ -63,11 +63,14 @@ def jsd_forward(
     assert target.shape == _input.shape, (
         f"Shape mismatch: {target.shape} != {_input.shape}"
     )
-    n_rows = BT
+    block_size_n = hl.register_block_size(V)
+    block_size_m = hl.register_block_size(BT)
 
     # Create output tensor for accumulating loss
-    loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
-    dX = torch.empty_like(_input)
+    loss = torch.zeros([BT], dtype=torch.float32, device=_input.device)
+    dX = torch.empty_like(loss)
+
+    one_minus_beta = 1 - beta
 
     # Count non-ignored elements
     n_non_ignore = float(BT)
@@ -79,60 +82,54 @@ def jsd_forward(
         ), torch.zeros_like(_input)
 
     # Process each sequence position
-    BT_SIZE = helion.cdiv(BT, n_rows)  # The liger kernel uses 1
-    for tile_bt in hl.tile(BT, block_size=BT_SIZE):
+    for tile_bt in hl.tile(BT, block_size=block_size_m):
        # Check for label masking
         if shift_labels is not None:
             if shift_labels[tile_bt] == ignore_index:
                 for tile_X in hl.tile(V):
                     dX[tile_bt, tile_X] = 0.0
                 continue
 
-        for tile_v in hl.tile(V):
+        intermediate_loss = hl.zeros([tile_bt, block_size_n], dtype=torch.float32)
+        intermediate_dX = hl.zeros([tile_bt, block_size_n], dtype=_input.dtype)
+        for tile_v in hl.tile(V, block_size=block_size_n):
             # Load log probabilities and convert to float32
             X = _input[tile_bt, tile_v]
             Y = target[tile_bt, tile_v]
-            X_max = torch.amax(X, dim=0)
-            Y_max = torch.amax(Y, dim=0)
-
             if beta == 0.0:  # Forward KL: KL(P || Q)
+                Y_max = torch.amax(Y, dim=0)
                 Y_shift = Y - Y_max
                 Y_prob = torch.exp(Y_shift) * torch.exp(
                     Y_max
                 )  # Compensate for the shift
-                loss[tile_bt, tile_v] = Y_prob * (Y - X)
-                dX[tile_bt, tile_v] = -Y_prob
+                intermediate_loss += Y_prob * (Y - X)
+                intermediate_dX += -Y_prob
             elif beta == 1.0:  # Reverse KL: KL(Q || P)
+                X_max = torch.amax(X, dim=0)
                 X_shift = X - X_max
                 X_prob = torch.exp(X_shift) * torch.exp(
                     X_max
                 )  # Compensate for the shift
-                loss[tile_bt, tile_v] = X_prob * (X - Y)
-                dX[tile_bt, tile_v] = loss[tile_bt, tile_v] + X_prob
+                intermediate_loss += X_prob * (X - Y)
+                intermediate_dX += intermediate_loss + X_prob
             else:  # General JSD: beta*KL(P||M) + (1-beta)*KL(Q||M)
-                max_val = torch.maximum(X_max, Y_max)
-                X_shifted = X - max_val
-                Y_shifted = Y - max_val
-
-                exp_max = torch.exp(max_val)
-
-                Q = torch.exp(X_shifted) * exp_max  # = exp(X)
-                P = torch.exp(Y_shifted) * exp_max  # = exp(Y)
+                Q = torch.exp(X)  # = exp(X)
+                P = torch.exp(Y)  # = exp(Y)
 
                 beta_P = beta * P
-                one_minus_beta_Q = (1 - beta) * Q
+                one_minus_beta_Q = one_minus_beta * Q
                 M = beta_P + one_minus_beta_Q
-                log_M = torch.log(
-                    M
-                )  # No need to compensate as M is already in original scale
+                log_M = torch.log(M)
+                x_minus_log_m = X - log_M
+                kl_q_m = one_minus_beta_Q * x_minus_log_m
 
-                loss[tile_bt, tile_v] = beta_P * Y + one_minus_beta_Q * X - M * log_M
-                dX[tile_bt, tile_v] = one_minus_beta_Q * (X - log_M)
+                intermediate_loss += beta_P * (Y - log_M) + kl_q_m
+                intermediate_dX += kl_q_m
 
-            # Accumulate over vocabulary dimension
-            scale = 1.0 / n_non_ignore
-            loss[tile_bt, tile_v] = loss[tile_bt, tile_v] * scale
-            dX[tile_bt, tile_v] = dX[tile_bt, tile_v] * scale
+        # Accumulate over vocabulary dimension
+        scale = 1.0 / n_non_ignore
+        loss[tile_bt] = torch.sum(intermediate_loss * scale, dim=1)
+        dX[tile_bt] = torch.sum(intermediate_dX * scale, dim=1)
 
     # Normalize by number of non-ignored elements, run it on host to match liger_kernel
     final_loss = torch.sum(
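Note on the change above: the rewritten loop keeps per-row partial sums in registers (intermediate_loss, intermediate_dX) and writes one value per row, instead of materializing a full [BT, V] loss tensor. For the general branch (0 < beta < 1), the quantity accumulated is the standard generalized JSD. A minimal dense-PyTorch sketch for comparison (illustrative names, not part of the PR; it assumes log_q and log_p are row-wise log-probabilities, like _input and target above):

import torch

def jsd_reference(log_q: torch.Tensor, log_p: torch.Tensor, beta: float = 0.5) -> torch.Tensor:
    # Generalized JSD: beta * KL(P || M) + (1 - beta) * KL(Q || M), with M = beta * P + (1 - beta) * Q.
    P, Q = log_p.exp(), log_q.exp()
    M = beta * P + (1 - beta) * Q
    log_M = M.log()
    per_element = beta * P * (log_p - log_M) + (1 - beta) * Q * (log_q - log_M)
    # Sum over vocabulary and rows, then divide by the row count (n_non_ignore == BT here).
    return per_element.sum() / log_q.shape[0]

The same kl_q_m term feeds both accumulators because (1 - beta) * Q * (X - log_M) is simultaneously the KL(Q || M) integrand and the analytic gradient of the per-element loss with respect to X; the beta == 0.0 and beta == 1.0 branches fall back to plain forward and reverse KL rather than taking the (zero-valued) limit of this formula.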
test/test_examples.expected (76 additions, 83 deletions)

@@ -2121,121 +2121,116 @@ def jagged_sum_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, *, _launche
 from __future__ import annotations
 
 import torch
-import helion
 import triton
 import triton.language as tl
-from torch._inductor.runtime import triton_helpers
 from torch._inductor.runtime.triton_helpers import math as tl_math
 from torch._inductor.runtime.triton_compat import libdevice
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _helion_jsd_forward(_input, target, loss, dX, _input_stride_0, _input_stride_1, dX_stride_0, dX_stride_1, loss_stride_0, loss_stride_1, target_stride_0, target_stride_1, BT, V, beta, n_non_ignore, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+def _helion_jsd_forward(_input, target, loss, dX, _input_stride_0, _input_stride_1, dX_stride_0, loss_stride_0, target_stride_0, target_stride_1, BT, V, beta, one_minus_beta, n_non_ignore, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
     pid_0 = tl.program_id(0)
-    offset_0 = pid_0 * _BLOCK_SIZE_0
-    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
-    mask_0 = indices_0 < BT
-    for offset_1 in tl.range(0, V.to(tl.int32), _BLOCK_SIZE_1):
-        indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
-        mask_1 = indices_1 < V
-        X = tl.load(_input + (indices_0[:, None] * _input_stride_0 + indices_1[None, :] * _input_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
-        Y = tl.load(target + (indices_0[:, None] * target_stride_0 + indices_1[None, :] * target_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
-        _mask_to = tl.where(mask_0[:, None] & mask_1[None, :], X, tl.full([], float('-inf'), tl.float32))
-        X_max = tl.cast(tl.max(_mask_to, 0), tl.float32)
-        _mask_to_1 = tl.where(mask_0[:, None] & mask_1[None, :], Y, tl.full([], float('-inf'), tl.float32))
-        Y_max = tl.cast(tl.max(_mask_to_1, 0), tl.float32)
+    offset_1 = pid_0 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < BT
+    intermediate_loss = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
+    intermediate_dX = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
+    for offset_0 in tl.range(0, V.to(tl.int32)):
+        indices_0 = offset_0 + tl.arange(0, 1).to(tl.int32)
+        intermediate_loss_copy = intermediate_loss
+        intermediate_dX_copy = intermediate_dX
+        intermediate_loss = intermediate_loss_copy
+        intermediate_dX = intermediate_dX_copy
+        X = tl.load(_input + (indices_1[:, None] * _input_stride_0 + indices_0[None, :] * _input_stride_1), mask_1[:, None], other=0)
+        Y = tl.load(target + (indices_1[:, None] * target_stride_0 + indices_0[None, :] * target_stride_1), mask_1[:, None], other=0)
         eq = beta == 0.0
         if eq:
             Y_copy = Y
-            Y_max_copy = Y_max
             X_copy = X
+            intermediate_loss_copy_0_copy = intermediate_loss
+            intermediate_dX_copy_0_copy = intermediate_dX
             Y_copy_0 = Y_copy
-            Y_max_copy_0 = Y_max_copy
             X_copy_0 = X_copy
-            v_0 = Y_max_copy_0[None, :]
+            intermediate_loss_copy_0_copy_0 = intermediate_loss_copy_0_copy
+            intermediate_dX_copy_0_copy_0 = intermediate_dX_copy_0_copy
+            _mask_to = tl.where(tl.broadcast_to(mask_1[:, None], [_BLOCK_SIZE_1, _BLOCK_SIZE_0]), Y_copy_0, tl.full([], float('-inf'), tl.float32))
+            Y_max = tl.cast(tl.max(_mask_to, 0), tl.float32)
+            v_0 = Y_max[None, :]
             v_1 = Y_copy_0 - v_0
             v_2 = libdevice.exp(v_1)
-            v_3 = libdevice.exp(Y_max_copy_0)
+            v_3 = libdevice.exp(Y_max)
             v_4 = v_3[None, :]
             v_5 = v_2 * v_4
             v_6 = Y_copy_0 - X_copy_0
             v_7 = v_5 * v_6
-            tl.store(loss + (indices_0[:, None] * loss_stride_0 + indices_1[None, :] * loss_stride_1), v_7, mask_0[:, None] & mask_1[None, :])
-            v_8 = -v_5
-            tl.store(dX + (indices_0[:, None] * dX_stride_0 + indices_1[None, :] * dX_stride_1), v_8, mask_0[:, None] & mask_1[None, :])
+            intermediate_loss = intermediate_loss_copy_0_copy_0 + v_7
+            v_9 = -v_5
+            intermediate_dX = intermediate_dX_copy_0_copy_0 + v_9
         _not = not eq
         if _not:
             X_copy_1 = X
-            X_max_copy = X_max
             Y_copy_1 = Y
-            Y_max_copy_1 = Y_max
+            intermediate_loss_copy_0_copy_1 = intermediate_loss
+            intermediate_dX_copy_0_copy_1 = intermediate_dX
             X_copy_1_0 = X_copy_1
-            X_max_copy_0 = X_max_copy
             Y_copy_1_0 = Y_copy_1
-            Y_max_copy_1_0 = Y_max_copy_1
+            intermediate_loss = intermediate_loss_copy_0_copy_1
+            intermediate_dX = intermediate_dX_copy_0_copy_1
             eq_1 = beta == 1.0
             if eq_1:
                 X_copy_1_0_copy = X_copy_1_0
-                X_max_copy_0_copy = X_max_copy_0
                 Y_copy_1_0_copy = Y_copy_1_0
+                intermediate_loss_copy_0_copy_1_0_copy = intermediate_loss
+                intermediate_dX_copy_0_copy_1_0_copy = intermediate_dX
                 X_copy_1_0_copy_0 = X_copy_1_0_copy
-                X_max_copy_0_copy_0 = X_max_copy_0_copy
                 Y_copy_1_0_copy_0 = Y_copy_1_0_copy
-                v_9 = X_max_copy_0_copy_0[None, :]
-                v_10 = X_copy_1_0_copy_0 - v_9
-                v_11 = libdevice.exp(v_10)
-                v_12 = libdevice.exp(X_max_copy_0_copy_0)
-                v_13 = v_12[None, :]
-                v_14 = v_11 * v_13
-                v_15 = X_copy_1_0_copy_0 - Y_copy_1_0_copy_0
-                v_16 = v_14 * v_15
-                tl.store(loss + (indices_0[:, None] * loss_stride_0 + indices_1[None, :] * loss_stride_1), v_16, mask_0[:, None] & mask_1[None, :])
-                load = tl.load(loss + (indices_0[:, None] * loss_stride_0 + indices_1[None, :] * loss_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
-                v_17 = load + v_14
-                tl.store(dX + (indices_0[:, None] * dX_stride_0 + indices_1[None, :] * dX_stride_1), v_17, mask_0[:, None] & mask_1[None, :])
+                intermediate_loss_copy_0_copy_1_0_copy_0 = intermediate_loss_copy_0_copy_1_0_copy
+                intermediate_dX_copy_0_copy_1_0_copy_0 = intermediate_dX_copy_0_copy_1_0_copy
+                _mask_to_1 = tl.where(tl.broadcast_to(mask_1[:, None], [_BLOCK_SIZE_1, _BLOCK_SIZE_0]), X_copy_1_0_copy_0, tl.full([], float('-inf'), tl.float32))
+                X_max = tl.cast(tl.max(_mask_to_1, 0), tl.float32)
+                v_11 = X_max[None, :]
+                v_12 = X_copy_1_0_copy_0 - v_11
+                v_13 = libdevice.exp(v_12)
+                v_14 = libdevice.exp(X_max)
+                v_15 = v_14[None, :]
+                v_16 = v_13 * v_15
+                v_17 = X_copy_1_0_copy_0 - Y_copy_1_0_copy_0
+                v_18 = v_16 * v_17
+                intermediate_loss = intermediate_loss_copy_0_copy_1_0_copy_0 + v_18
+                v_20 = intermediate_loss + v_16
+                intermediate_dX = intermediate_dX_copy_0_copy_1_0_copy_0 + v_20
             _not_1 = not eq_1
             if _not_1:
-                X_max_copy_0_copy_1 = X_max_copy_0
-                Y_max_copy_1_0_copy = Y_max_copy_1_0
                 X_copy_1_0_copy_1 = X_copy_1_0
                 Y_copy_1_0_copy_1 = Y_copy_1_0
-                X_max_copy_0_copy_1_0 = X_max_copy_0_copy_1
-                Y_max_copy_1_0_copy_0 = Y_max_copy_1_0_copy
+                intermediate_loss_copy_0_copy_1_0_copy_1 = intermediate_loss
+                intermediate_dX_copy_0_copy_1_0_copy_1 = intermediate_dX
                 X_copy_1_0_copy_1_0 = X_copy_1_0_copy_1
                 Y_copy_1_0_copy_1_0 = Y_copy_1_0_copy_1
-                v_18 = triton_helpers.maximum(X_max_copy_0_copy_1_0, Y_max_copy_1_0_copy_0)
-                v_19 = v_18[None, :]
-                v_20 = X_copy_1_0_copy_1_0 - v_19
-                v_21 = v_18[None, :]
-                v_22 = Y_copy_1_0_copy_1_0 - v_21
-                v_23 = libdevice.exp(v_18)
-                v_24 = libdevice.exp(v_20)
-                v_25 = v_23[None, :]
-                v_26 = v_24 * v_25
-                v_27 = libdevice.exp(v_22)
-                v_28 = v_23[None, :]
-                v_29 = v_27 * v_28
-                v_30 = v_29 * beta
-                sub_2 = 1.0 + -1 * beta
-                v_31 = v_26 * sub_2
-                v_32 = v_30 + v_31
-                v_33 = tl_math.log(v_32)
-                v_34 = v_30 * Y_copy_1_0_copy_1_0
-                v_35 = v_31 * X_copy_1_0_copy_1_0
-                v_36 = v_34 + v_35
-                v_37 = v_32 * v_33
-                v_38 = v_36 - v_37
-                tl.store(loss + (indices_0[:, None] * loss_stride_0 + indices_1[None, :] * loss_stride_1), v_38, mask_0[:, None] & mask_1[None, :])
-                v_39 = X_copy_1_0_copy_1_0 - v_33
-                v_40 = v_31 * v_39
-                tl.store(dX + (indices_0[:, None] * dX_stride_0 + indices_1[None, :] * dX_stride_1), v_40, mask_0[:, None] & mask_1[None, :])
-        truediv = 1.0 / n_non_ignore
-        load_2 = tl.load(loss + (indices_0[:, None] * loss_stride_0 + indices_1[None, :] * loss_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
-        v_41 = load_2 * truediv
-        tl.store(loss + (indices_0[:, None] * loss_stride_0 + indices_1[None, :] * loss_stride_1), v_41, mask_0[:, None] & mask_1[None, :])
-        load_3 = tl.load(dX + (indices_0[:, None] * dX_stride_0 + indices_1[None, :] * dX_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
-        v_42 = load_3 * truediv
-        tl.store(dX + (indices_0[:, None] * dX_stride_0 + indices_1[None, :] * dX_stride_1), v_42, mask_0[:, None] & mask_1[None, :])
+                intermediate_loss_copy_0_copy_1_0_copy_1_0 = intermediate_loss_copy_0_copy_1_0_copy_1
+                intermediate_dX_copy_0_copy_1_0_copy_1_0 = intermediate_dX_copy_0_copy_1_0_copy_1
+                v_22 = libdevice.exp(X_copy_1_0_copy_1_0)
+                v_23 = libdevice.exp(Y_copy_1_0_copy_1_0)
+                v_24 = v_23 * beta
+                v_25 = v_22 * one_minus_beta
+                v_26 = v_24 + v_25
+                v_27 = tl_math.log(v_26)
+                v_28 = X_copy_1_0_copy_1_0 - v_27
+                v_29 = v_25 * v_28
+                v_30 = Y_copy_1_0_copy_1_0 - v_27
+                v_31 = v_24 * v_30
+                v_32 = v_31 + v_29
+                intermediate_loss = intermediate_loss_copy_0_copy_1_0_copy_1_0 + v_32
+                intermediate_dX = intermediate_dX_copy_0_copy_1_0_copy_1_0 + v_29
+    truediv = 1.0 / n_non_ignore
+    v_35 = intermediate_loss * truediv
+    _mask_to_2 = tl.where(tl.broadcast_to(mask_1[:, None], [_BLOCK_SIZE_1, _BLOCK_SIZE_0]), v_35, tl.full([], 0, tl.float32))
+    sum_1 = tl.cast(tl.sum(_mask_to_2, 1), tl.float32)
+    tl.store(loss + indices_1 * loss_stride_0, sum_1, mask_1)
+    v_36 = intermediate_dX * truediv
+    _mask_to_3 = tl.where(tl.broadcast_to(mask_1[:, None], [_BLOCK_SIZE_1, _BLOCK_SIZE_0]), v_36, tl.full([], 0, tl.float32))
+    sum_2 = tl.cast(tl.sum(_mask_to_3, 1), tl.float32)
+    tl.store(dX + indices_1 * dX_stride_0, sum_2, mask_1)
 
 def jsd_forward(_input: Tensor, target: Tensor, shift_labels: Tensor | None=None, beta: float=0.5, ignore_index: int=-100, *, _launcher=_default_launcher):
     """
@@ -2254,18 +2249,16 @@ def jsd_forward(_input: Tensor, target: Tensor, shift_labels: Tensor | None=None
     """
     BT, V = _input.shape
     assert target.shape == _input.shape, f'Shape mismatch: {target.shape} != {_input.shape}'
-    n_rows = BT
-    loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
-    dX = torch.empty_like(_input)
+    loss = torch.zeros([BT], dtype=torch.float32, device=_input.device)
+    dX = torch.empty_like(loss)
+    one_minus_beta = 1 - beta
     n_non_ignore = float(BT)
     if shift_labels is not None:
         n_non_ignore = float((shift_labels != ignore_index).sum().item())
     if n_non_ignore == 0:
         return (torch.zeros([], dtype=_input.dtype, device=_input.device), torch.zeros_like(_input))
-    BT_SIZE = helion.cdiv(BT, n_rows)
-    _BLOCK_SIZE_0 = BT_SIZE
     _BLOCK_SIZE_1 = 4096
-    _launcher(_helion_jsd_forward, (triton.cdiv(BT, _BLOCK_SIZE_0),), _input, target, loss, dX, _input.stride(0), _input.stride(1), dX.stride(0), dX.stride(1), loss.stride(0), loss.stride(1), target.stride(0), target.stride(1), BT, V, beta, n_non_ignore, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    _launcher(_helion_jsd_forward, (triton.cdiv(BT, _BLOCK_SIZE_1),), _input, target, loss, dX, _input.stride(0), _input.stride(1), dX.stride(0), loss.stride(0), target.stride(0), target.stride(1), BT, V, beta, one_minus_beta, n_non_ignore, _BLOCK_SIZE_1, 1, num_warps=4, num_stages=3)
     final_loss = torch.sum(loss)
     return (final_loss, dX)
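A numerical aside on the generated kernel above: the general branch now exponentiates X and Y directly (v_22, v_23) where the old code first subtracted a shared row maximum (v_18 = triton_helpers.maximum(...)). Dropping the shift is safe only because the inputs are log-probabilities; a quick sanity check of that assumption (our illustration, not from the PR):

import torch

# log_softmax outputs are <= 0, so exp() maps them into (0, 1] and cannot overflow;
# the max-shift trick matters only when inputs can be large and positive.
X = torch.randn(4, 16).log_softmax(dim=-1)
assert X.max().item() <= 0.0
assert torch.isfinite(X.exp()).all()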
test/test_examples.py (1 addition, 1 deletion)

@@ -1116,7 +1116,7 @@ def test_jsd(self):
             args,
             (expected(*args), None),
             fn_name="jsd_forward",
-            block_sizes=[4096],
+            block_sizes=[1, 4096],
             num_warps=4,
             num_stages=3,
         )
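The two entries of block_sizes appear to follow the order of the hl.register_block_size calls: 1 for block_size_n (the vocabulary tile, which is why the generated kernel loops tl.range(0, V) with _BLOCK_SIZE_0 fixed at 1) and 4096 for block_size_m (the BT tile). A hypothetical smoke test tying the pieces together, reusing jsd_reference from the sketch after the first file (import path assumed):

import torch
from examples.jsd import jsd_forward  # assumed import path

BT, V = 32, 512
X = torch.randn(BT, V, device="cuda").log_softmax(dim=-1)  # model log-probs
Y = torch.randn(BT, V, device="cuda").log_softmax(dim=-1)  # target log-probs
loss, dX = jsd_forward(X, Y, beta=0.5)
torch.testing.assert_close(loss, jsd_reference(X, Y), rtol=1e-4, atol=1e-4)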