Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion helion/_compiler/compile_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,11 +522,14 @@ def mark_alternate_size(self, size: torch.SymInt | int | None) -> None:
self.size = size
if size is not None:
env = CompileEnvironment.current()
# Refresh the var_to_val hint to match the resolved block size
hint = env.size_hint(size)
env.shape_env.var_to_val[self.symbol()] = sympy.Integer(hint)
with contextlib.suppress(KeyError):
# update the size hint now that we know the size
env.config_spec.block_sizes.block_id_lookup(
self.block_id
).update_hint(env.size_hint(size))
).update_hint(hint)
elif size is None or self.size is None or self.size != size:
self.size = None

Expand Down
10 changes: 7 additions & 3 deletions helion/_compiler/indexing_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,10 @@ def create(
output_size = SubscriptIndexing.compute_shape(fake_value, index)
env = CompileEnvironment.current()
dtype = env.triton_index_type()

def _is_size_one(size: int | torch.SymInt) -> bool:
return env.known_equal(size, 1)

for n, k in enumerate(index):
if k is None:
output_idx += 1
Expand All @@ -544,7 +548,7 @@ def create(
index_values.append(f"({index_var}){expand}")
if (
mask := state.codegen.mask_var(origin.origin.block_id)
) and fake_value.size(i) != 1:
) and not _is_size_one(fake_value.size(i)):
mask_values.setdefault(f"({mask}){expand}")
output_idx += 1
else:
Expand Down Expand Up @@ -576,7 +580,7 @@ def create(
index_values.append(f"{start}{expand}")
else:
# Full slice or slice without step
if size != 1:
if not _is_size_one(size):
rdim = env.allocate_reduction_dimension(size)
block_idx = rdim.block_id
index_var = state.codegen.index_var(block_idx)
Expand Down Expand Up @@ -620,7 +624,7 @@ def create(
assert len(index_values) == fake_value.ndim
index_expr = []
for i, idx in enumerate(index_values):
if fake_value.size(i) != 1:
if not _is_size_one(fake_value.size(i)):
stride = state.device_function.tensor_stride(fake_value, i).name
index_expr.append(f"{idx} * {stride}")
if not index_expr:
Expand Down
11 changes: 11 additions & 0 deletions helion/_compiler/type_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .compile_environment import FixedBlockSizeSource
from .compile_environment import LoopSpecBlockSizeSource
from .compile_environment import warning
from .device_function import contains_only_block_size_symbols
from .host_function import HostFunction
from .host_function import SymbolOrigin
from .output_header import library_imports
Expand Down Expand Up @@ -473,6 +474,16 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
if self.origin.is_device():
output_sizes.append(output_size)
elif output_size != 1:
# If all symbols in output_size are block size symbols, we reuse them
if isinstance(output_size, torch.SymInt):
expr = output_size._sympy_()
if (
isinstance(expr, sympy.Expr)
and expr.free_symbols
and contains_only_block_size_symbols(expr)
):
output_sizes.append(output_size)
continue
rdim = CompileEnvironment.current().allocate_reduction_dimension(
output_size
)
Expand Down
4 changes: 3 additions & 1 deletion helion/_compiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,6 @@ def compute_slice_size(
step = slice_obj.step
return (stop - start + step - 1) // step
# Full slice or slice without step
return original_size
start = slice_obj.start if slice_obj.start is not None else 0
stop = slice_obj.stop if slice_obj.stop is not None else original_size
return stop - start
8 changes: 4 additions & 4 deletions test/test_constexpr.expected
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ def _helion_matmul_int4_block_expr(B, A, C, _NUM_SM: tl.constexpr, _BLOCK_SIZE_2
offset_2 = pid_1 * _BLOCK_SIZE_2
indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
for offset_3 in tl.range(0, 16, loop_unroll_factor=4, num_stages=1, disallow_acc_multi_buffer=True, flatten=True):
indices_3 = offset_3 + tl.arange(0, 1).to(tl.int32)
for offset_0 in tl.range(0, 16, loop_unroll_factor=4, num_stages=1, disallow_acc_multi_buffer=True, flatten=True):
indices_0 = offset_0 + tl.arange(0, 1).to(tl.int32)
acc_copy = acc
acc_copy_0 = acc_copy
packed = tl.load(B + (indices_3[:, None] * 16 + indices_2[None, :] * 1), None)
packed = tl.load(B + (indices_0[:, None] * 16 + indices_2[None, :] * 1), None)
v_0 = tl.full([], 4, tl.int8)
v_1 = packed << v_0
v_2 = tl.full([], 4, tl.int8)
Expand All @@ -54,7 +54,7 @@ def _helion_matmul_int4_block_expr(B, A, C, _NUM_SM: tl.constexpr, _BLOCK_SIZE_2
mask_1 = broadcast_idx == 1
stacked_result = tl.where(mask_1, expanded_1, stacked_result)
unpacked = tl.reshape(stacked_result, [2 * _BLOCK_SIZE_0, _BLOCK_SIZE_2])
mul_5 = 2 * offset_3
mul_5 = 2 * offset_0
iota = mul_5 + tl.arange(0, mul)
a_tile = tl.load(A + (indices_1[:, None] * 32 + iota[None, :] * 1), None)
dot = tl.dot(tl.cast(a_tile, tl.bfloat16), tl.cast(unpacked, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
Expand Down
25 changes: 12 additions & 13 deletions test/test_examples.expected
Original file line number Diff line number Diff line change
Expand Up @@ -1353,12 +1353,12 @@ def _helion_matmul_bf16_int4(B, A, C, A_stride_0, A_stride_1, B_stride_0, B_stri
mask_2 = indices_2 < N
acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
floordiv = triton_helpers.div_floor_integer(K, 2)
for offset_3 in tl.range(0, floordiv.to(tl.int32), _BLOCK_SIZE_0):
indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
mask_0 = indices_3 < floordiv
for offset_0 in tl.range(0, floordiv.to(tl.int32), _BLOCK_SIZE_0):
indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
mask_0 = indices_0 < floordiv
acc_copy = acc
acc_copy_0 = acc_copy
b_tile = tl.load(B + (indices_3[:, None] * B_stride_0 + indices_2[None, :] * B_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
b_tile = tl.load(B + (indices_0[:, None] * B_stride_0 + indices_2[None, :] * B_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
v_0 = tl.full([], 4, tl.int8)
v_1 = b_tile << v_0
v_2 = tl.full([], 4, tl.int8)
Expand All @@ -1372,12 +1372,12 @@ def _helion_matmul_bf16_int4(B, A, C, A_stride_0, A_stride_1, B_stride_0, B_stri
expanded_0 = tl.expand_dims(v_6, 1)
expanded_1 = tl.expand_dims(v_7, 1)
stacked_result = tl.zeros_like(expanded_0)
mask_4 = broadcast_idx == 0
stacked_result = tl.where(mask_4, expanded_0, stacked_result)
mask_5 = broadcast_idx == 1
stacked_result = tl.where(mask_5, expanded_1, stacked_result)
mask_3 = broadcast_idx == 0
stacked_result = tl.where(mask_3, expanded_0, stacked_result)
mask_4 = broadcast_idx == 1
stacked_result = tl.where(mask_4, expanded_1, stacked_result)
b_unpacked = tl.reshape(stacked_result, [2 * _BLOCK_SIZE_0, _BLOCK_SIZE_2])
mul_5 = 2 * offset_3
mul_5 = 2 * offset_0
iota = mul_5 + tl.arange(0, mul)
a_tile = tl.load(A + (indices_1[:, None] * A_stride_0 + iota[None, :] * A_stride_1), mask_1[:, None], other=0)
dot = tl.dot(tl.cast(a_tile, tl.bfloat16), tl.cast(b_unpacked, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
Expand Down Expand Up @@ -1406,7 +1406,6 @@ def matmul_bf16_int4(A: Tensor, B: Tensor, *, _launcher=_default_launcher):
C = torch.zeros(M, N, dtype=torch.bfloat16, device=A.device)
_BLOCK_SIZE_1 = 64
_BLOCK_SIZE_2 = 32
_RDIM_SIZE_3 = triton.next_power_of_2(K)
_BLOCK_SIZE_0 = 64
_launcher(_helion_matmul_bf16_int4, (triton.cdiv(M, _BLOCK_SIZE_1) * triton.cdiv(N, _BLOCK_SIZE_2),), B, A, C, A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1), M, N, K, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return C
Expand Down Expand Up @@ -3124,7 +3123,7 @@ import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_rms_norm_bwd(x, grad_out, rsqrt, weight, grad_x, grad_weight, x_size_0, grad_out_stride_0, grad_out_stride_1, grad_weight_stride_1, grad_x_stride_0, grad_x_stride_1, rsqrt_stride_0, weight_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr):
def _helion_rms_norm_bwd(x, grad_out, rsqrt, weight, grad_x, grad_weight, x_size_0, grad_out_stride_0, grad_out_stride_1, grad_weight_stride_0, grad_weight_stride_1, grad_x_stride_0, grad_x_stride_1, rsqrt_stride_0, weight_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
Expand Down Expand Up @@ -3162,7 +3161,7 @@ def _helion_rms_norm_bwd(x, grad_out, rsqrt, weight, grad_x, grad_weight, x_size
v_18 = tl.cast(v_17, tl.float16)
tl.store(grad_x + (indices_2[:, None] * grad_x_stride_0 + indices_3[None, :] * grad_x_stride_1), v_18, None)
tile_id = offset_0 // _BLOCK_SIZE_0
tl.store(grad_weight + indices_3 * grad_weight_stride_1, grad_w_m, None)
tl.store(grad_weight + (tile_id * grad_weight_stride_0 + indices_3 * grad_weight_stride_1), grad_w_m, None)

def rms_norm_bwd(grad_out: torch.Tensor, x: torch.Tensor, weight: torch.Tensor, rsqrt: torch.Tensor, *, _launcher=_default_launcher):
"""
Expand All @@ -3187,7 +3186,7 @@ def rms_norm_bwd(grad_out: torch.Tensor, x: torch.Tensor, weight: torch.Tensor,
grad_weight = x.new_empty([(x.size(0) + m_block - 1) // m_block, *weight.shape], dtype=torch.float32)
_BLOCK_SIZE_0 = 32
_RDIM_SIZE_2 = 64
_launcher(_helion_rms_norm_bwd, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_out, rsqrt, weight, grad_x, grad_weight, x.size(0), grad_out.stride(0), grad_out.stride(1), grad_weight.stride(1), grad_x.stride(0), grad_x.stride(1), rsqrt.stride(0), weight.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, num_warps=4, num_stages=3)
_launcher(_helion_rms_norm_bwd, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_out, rsqrt, weight, grad_x, grad_weight, x.size(0), grad_out.stride(0), grad_out.stride(1), grad_weight.stride(0), grad_weight.stride(1), grad_x.stride(0), grad_x.stride(1), rsqrt.stride(0), weight.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, num_warps=4, num_stages=3)
return (grad_x, grad_weight.sum(0).to(weight.dtype))

--- assertExpectedJournal(TestExamples.test_rms_norm_bwd_dw)
Expand Down
29 changes: 16 additions & 13 deletions test/test_loops.expected
Original file line number Diff line number Diff line change
Expand Up @@ -982,20 +982,22 @@ import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_kernel_fixed_block_size(loss_sum, y_true, kl_loss, loss, loss_sum_stride_0, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
def _helion_kernel_fixed_block_size(loss_sum, y_true, kl_loss, loss, loss_sum_stride_0, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _RDIM_SIZE_3: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_1 = pid_0 * _BLOCK_SIZE_1
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
indices_4 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
full = tl.full([64, 64], 0.0, tl.float32)
tl.store(loss_sum + (indices_4[:, None] * loss_sum_stride_0 + indices_4[None, :] * 1), full, None)
for offset_2 in tl.range(0, 128, _BLOCK_SIZE_3):
indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
y_true_val = tl.load(y_true + (indices_1[:, None] * 128 + indices_2[None, :] * 1), None)
tl.store(kl_loss + (indices_1[:, None] * 128 + indices_2[None, :] * 1), y_true_val, None)
load_1 = tl.load(kl_loss + (indices_1[:, None] * 128 + indices_2[None, :] * 1), None)
tl.atomic_add(loss_sum + (indices_1[:, None] * loss_sum_stride_0 + indices_2[None, :] * 1), load_1, mask=None, sem='relaxed')
load = tl.load(loss_sum + (indices_4[:, None] * loss_sum_stride_0 + indices_4[None, :] * 1), None)
indices_5 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
indices_6 = tl.arange(0, _RDIM_SIZE_3).to(tl.int32)
mask_3 = indices_6 < _BLOCK_SIZE_0
full = tl.full([64, _BLOCK_SIZE_0], 0.0, tl.float32)
tl.store(loss_sum + (indices_5[:, None] * loss_sum_stride_0 + indices_6[None, :] * 1), full, mask_3[None, :])
for offset_4 in tl.range(0, 128, _BLOCK_SIZE_0):
indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
y_true_val = tl.load(y_true + (indices_1[:, None] * 128 + indices_4[None, :] * 1), None)
tl.store(kl_loss + (indices_1[:, None] * 128 + indices_4[None, :] * 1), y_true_val, None)
load_1 = tl.load(kl_loss + (indices_1[:, None] * 128 + indices_4[None, :] * 1), None)
tl.atomic_add(loss_sum + (indices_1[:, None] * loss_sum_stride_0 + indices_4[None, :] * 1), load_1, mask=None, sem='relaxed')
load = tl.load(loss_sum + (indices_5[:, None] * loss_sum_stride_0 + indices_6[None, :] * 1), mask_3[None, :], other=0)
sum_1 = tl.cast(tl.sum(load, 1), tl.float32)
tl.store(loss + indices_1 * 1, sum_1, None)

Expand All @@ -1008,8 +1010,9 @@ def kernel_fixed_block_size(y_pred: torch.Tensor, y_true: torch.Tensor, *, _laun
loss_sum = torch.zeros([BT_SIZE, block_size_n], dtype=torch.float32, device=y_pred.device)
_BLOCK_SIZE_1 = 64
_RDIM_SIZE_2 = 64
_BLOCK_SIZE_3 = 64
_launcher(_helion_kernel_fixed_block_size, (triton.cdiv(64, _BLOCK_SIZE_1),), loss_sum, y_true, kl_loss, loss, loss_sum.stride(0), _BLOCK_SIZE_1, _RDIM_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
_BLOCK_SIZE_0 = 128
_RDIM_SIZE_3 = triton.next_power_of_2(_BLOCK_SIZE_0)
_launcher(_helion_kernel_fixed_block_size, (triton.cdiv(64, _BLOCK_SIZE_1),), loss_sum, y_true, kl_loss, loss, loss_sum.stride(0), _BLOCK_SIZE_1, _RDIM_SIZE_2, _RDIM_SIZE_3, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return torch.sum(loss) / BT

--- assertExpectedJournal(TestLoops.test_reorder_with_register_block_size)
Expand Down
3 changes: 1 addition & 2 deletions test/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,6 @@ def fn(x: torch.Tensor) -> torch.Tensor:
self.assertEqual(spec.min_size, 32)
self.assertEqual(spec.max_size, 256)

@skipIfRefEager("Triton codegen is disabled in ref eager mode")
def test_register_block_size_codegen_size_hint(self):
@helion.kernel(static_shapes=True)
def kernel_fixed_block_size(
Expand Down Expand Up @@ -368,7 +367,7 @@ def kernel_fixed_block_size(
code, result = code_and_output(kernel_fixed_block_size, args, block_sizes=[128])
self.assertExpectedJournal(code)

expected = y_true[:, : y_pred.size(0)].sum() / y_pred.size(0)
expected = y_true[:, :].sum() / y_pred.size(0)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original expected value was buggy: to match the intended kernel behavior, it should have been a sum over all of y_true[:, :], not over y_true[:, : y_pred.size(0)].

torch.testing.assert_close(result, expected)

def test_reorder_with_register_block_size(self):
Expand Down
52 changes: 52 additions & 0 deletions test/test_matmul.expected
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,58 @@ def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]]
_launcher(_helion_matmul, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
return out

--- assertExpectedJournal(TestMatmul.test_matmul_packed_rhs)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_matmul_with_packed_b(A, B, C, A_stride_0, A_stride_1, B_stride_0, B_stride_1, C_stride_0, C_stride_1, M, N, K, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, mul_2: tl.constexpr):
num_blocks_0 = tl.cdiv(M, _BLOCK_SIZE_1)
pid_0 = tl.program_id(0) % num_blocks_0
pid_1 = tl.program_id(0) // num_blocks_0
offset_1 = pid_0 * _BLOCK_SIZE_1
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
mask_1 = indices_1 < M
offset_2 = pid_1 * _BLOCK_SIZE_2
indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
mask_2 = indices_2 < N
acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
floordiv = triton_helpers.div_floor_integer(K, 2)
for offset_0 in tl.range(0, floordiv.to(tl.int32), _BLOCK_SIZE_0):
indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
mask_0 = indices_0 < floordiv
acc_copy = acc
acc_copy_0 = acc_copy
mul = 2 * offset_0
iota = mul + tl.arange(0, mul_2)
lhs = tl.load(A + (indices_1[:, None] * A_stride_0 + iota[None, :] * A_stride_1), mask_1[:, None], other=0)
packed = tl.load(B + (indices_0[:, None] * B_stride_0 + indices_2[None, :] * B_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
stack_idx = tl.arange(0, 2)
broadcast_idx = stack_idx[None, :, None]
expanded_0 = tl.expand_dims(packed, 1)
expanded_1 = tl.expand_dims(packed, 1)
stacked_result = tl.zeros_like(expanded_0)
mask_3 = broadcast_idx == 0
stacked_result = tl.where(mask_3, expanded_0, stacked_result)
mask_4 = broadcast_idx == 1
stacked_result = tl.where(mask_4, expanded_1, stacked_result)
rhs = tl.reshape(stacked_result, [2 * _BLOCK_SIZE_0, _BLOCK_SIZE_2])
acc = tl.dot(tl.cast(lhs, tl.float32), tl.cast(rhs, tl.float32), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
tl.store(C + (indices_1[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), acc, mask_1[:, None] & mask_2[None, :])

def matmul_with_packed_b(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, *, _launcher=_default_launcher):
M, K = A.shape
_, N = B.shape
_BLOCK_SIZE_1 = 16
_BLOCK_SIZE_2 = 16
_BLOCK_SIZE_0 = 16
_launcher(_helion_matmul_with_packed_b, (triton.cdiv(M, _BLOCK_SIZE_1) * triton.cdiv(N, _BLOCK_SIZE_2),), A, B, C, A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1), M, N, K, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=3)

--- assertExpectedJournal(TestMatmul.test_matmul_split_k)
from __future__ import annotations

Expand Down
Loading
Loading