5 changes: 4 additions & 1 deletion helion/_compiler/device_function.py
@@ -282,7 +282,10 @@ def block_size_var(self, block_id: int) -> str | None:
var_name = self.new_var(f"_BLOCK_SIZE_{block_id}")
self.block_size_var_cache[key] = var_name
host_expr = HostFunction.current().literal_expr(block_value)
if self.constexpr_arg(var_name, host_expr):
    self.codegen.host_statements.append(
        statement_from_string(f"{var_name} = {host_expr}")
    )

return self.block_size_var_cache[key]

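For context, a minimal standalone sketch of the pattern this hunk introduces (hypothetical names, not Helion's actual internals): the host-side assignment is appended only when constexpr_arg reports the variable as newly registered, so repeated block_size_var lookups do not duplicate the "_BLOCK_SIZE_n = ..." statement. That constexpr_arg returns a truthy value only on first registration is an assumption inferred from this diff.

# Hypothetical, simplified illustration of the register-once / emit-once pattern above.
# HostCodeSketch and its attributes are illustrative only, not Helion APIs.
class HostCodeSketch:
    def __init__(self) -> None:
        self.constexpr_args: dict[str, str] = {}
        self.host_statements: list[str] = []

    def constexpr_arg(self, name: str, expr: str) -> bool:
        """Register a constexpr kernel argument; True only when newly added."""
        if name in self.constexpr_args:
            return False
        self.constexpr_args[name] = expr
        return True

    def block_size_var(self, name: str, expr: str) -> str:
        # Mirror the kernel-side constexpr as a host-side assignment, once.
        if self.constexpr_arg(name, expr):
            self.host_statements.append(f"{name} = {expr}")
        return name


codegen = HostCodeSketch()
codegen.block_size_var("_BLOCK_SIZE_0", "1")
codegen.block_size_var("_BLOCK_SIZE_0", "1")  # repeated lookup adds nothing
assert codegen.host_statements == ["_BLOCK_SIZE_0 = 1"]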
71 changes: 71 additions & 0 deletions test/test_constexpr.expected
@@ -1,6 +1,77 @@
This file is automatically generated by assertExpectedJournal calls in test_constexpr.py.
Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.

--- assertExpectedJournal(TestConstExpr.test_block_size_constexpr_assignment_in_host_code)
from __future__ import annotations

import torch
import helion
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_matmul_int4_block_expr(B, A, C, _NUM_SM: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, mul: tl.constexpr):
    total_pids = 16 * tl.cdiv(16, _BLOCK_SIZE_2)
    block_size = tl.cdiv(total_pids, _NUM_SM)
    start_pid = tl.program_id(0) * block_size
    end_pid = tl.minimum(start_pid + block_size, total_pids)
    for virtual_pid in tl.range(start_pid, end_pid, loop_unroll_factor=1, num_stages=3, flatten=True):
        num_pid_m = 16
        num_pid_n = tl.cdiv(16, _BLOCK_SIZE_2)
        inner_2d_pid = virtual_pid
        num_pid_in_group = 8 * num_pid_n
        group_id = inner_2d_pid // num_pid_in_group
        first_pid_m = group_id * 8
        group_size_m = min(num_pid_m - first_pid_m, 8)
        pid_0 = first_pid_m + inner_2d_pid % num_pid_in_group % group_size_m
        pid_1 = inner_2d_pid % num_pid_in_group // group_size_m
        offset_1 = pid_0
        indices_1 = offset_1 + tl.zeros([1], tl.int32)
        offset_2 = pid_1 * _BLOCK_SIZE_2
        indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
        acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
        for offset_3 in tl.range(0, 16, loop_unroll_factor=4, num_stages=1, disallow_acc_multi_buffer=True, flatten=True):
            indices_3 = offset_3 + tl.arange(0, 1).to(tl.int32)
            acc_copy = acc
            acc_copy_0 = acc_copy
            packed = tl.load(B + (indices_3[:, None] * 16 + indices_2[None, :] * 1), None)
            v_0 = tl.full([], 4, tl.int8)
            v_1 = packed << v_0
            v_2 = tl.full([], 4, tl.int8)
            v_3 = v_1 >> v_2
            v_4 = tl.full([], 4, tl.int8)
            v_5 = packed >> v_4
            v_6 = tl.cast(v_3, tl.bfloat16)
            v_7 = tl.cast(v_5, tl.bfloat16)
            stack_idx = tl.arange(0, 2)
            broadcast_idx = stack_idx[None, :, None]
            expanded_0 = tl.expand_dims(v_6, 1)
            expanded_1 = tl.expand_dims(v_7, 1)
            stacked_result = tl.zeros_like(expanded_0)
            mask_0 = broadcast_idx == 0
            stacked_result = tl.where(mask_0, expanded_0, stacked_result)
            mask_1 = broadcast_idx == 1
            stacked_result = tl.where(mask_1, expanded_1, stacked_result)
            unpacked = tl.reshape(stacked_result, [2 * _BLOCK_SIZE_0, _BLOCK_SIZE_2])
            mul_5 = 2 * offset_3
            iota = mul_5 + tl.arange(0, mul)
            a_tile = tl.load(A + (indices_1[:, None] * 32 + iota[None, :] * 1), None)
            dot = tl.dot(tl.cast(a_tile, tl.bfloat16), tl.cast(unpacked, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
            acc = acc_copy_0 + dot
        v_9 = tl.cast(acc, tl.bfloat16)
        tl.store(C + (indices_1[:, None] * 16 + indices_2[None, :] * 1), v_9, None)

def matmul_int4_block_expr(A: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher):
    M, K = A.shape
    _, N = B.shape
    C = torch.zeros(M, N, dtype=torch.bfloat16, device=A.device)
    _NUM_SM = helion.runtime.get_num_sm(A.device)
    _BLOCK_SIZE_2 = 16
    _BLOCK_SIZE_0 = 1
    _launcher(_helion_matmul_int4_block_expr, (_NUM_SM,), B, A, C, _NUM_SM, _BLOCK_SIZE_2, 1, 1, 2 * _BLOCK_SIZE_0, num_warps=1, num_stages=8)
    return C

--- assertExpectedJournal(TestConstExpr.test_constexpr_float)
from __future__ import annotations

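A note on the persistent_blocked schedule in the generated kernel above: the launcher starts _NUM_SM persistent programs, and each one walks a contiguous block of the virtual tile grid. Below is a small standalone Python sketch of that partitioning; _NUM_SM = 4 is an assumed example value purely for illustration (the generated host code queries it via helion.runtime.get_num_sm).

# Illustrative only: how total_pids is split across persistent programs.
def cdiv(a: int, b: int) -> int:
    return -(-a // b)

_BLOCK_SIZE_2 = 16
_NUM_SM = 4  # assumed example value; real kernels query the device

num_pid_m = 16                          # 16 tiles of size 1 along M
num_pid_n = cdiv(16, _BLOCK_SIZE_2)     # 1 tile along N
total_pids = num_pid_m * num_pid_n      # 16 virtual programs
block_size = cdiv(total_pids, _NUM_SM)  # virtual pids per persistent program

for pid in range(_NUM_SM):
    start_pid = pid * block_size
    end_pid = min(start_pid + block_size, total_pids)
    print(f"program {pid}: virtual pids {start_pid}..{end_pid - 1}")
# program 0: virtual pids 0..3, ..., program 3: virtual pids 12..15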
66 changes: 66 additions & 0 deletions test/test_constexpr.py
@@ -9,6 +9,7 @@
from helion._testing import RefEagerTestBase
from helion._testing import TestCase
from helion._testing import code_and_output
from helion._testing import skipIfRefEager
import helion.language as hl


@@ -92,6 +93,71 @@ def fn(x: torch.Tensor, mode: str) -> torch.Tensor:
        torch.testing.assert_close(result, x)
        self.assertExpectedJournal(code)

    @skipIfRefEager("Triton codegen does not work in ref eager mode")
    def test_block_size_constexpr_assignment_in_host_code(self) -> None:
        @helion.kernel(
            config=helion.Config(
                block_sizes=[1, 1, 16],
                indexing="pointer",
                l2_groupings=[8],
                loop_orders=[[0, 1]],
                num_stages=8,
                num_warps=1,
                pid_type="persistent_blocked",
                range_flattens=[True, True],
                range_multi_buffers=[None, False],
                range_num_stages=[3, 1],
                range_unroll_factors=[1, 4],
            ),
            static_shapes=True,
        )
        def matmul_int4_block_expr(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
            M, K = A.shape
            _, N = B.shape

            C = torch.zeros(M, N, dtype=torch.bfloat16, device=A.device)
            block_size_k_packed = hl.register_block_size(K // 2)

            for tile_m, tile_n in hl.tile([M, N]):
                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)

                for tile_k_packed in hl.tile(K // 2, block_size=block_size_k_packed):
                    packed = B[tile_k_packed, tile_n]
                    lo = ((packed << 4) >> 4).to(torch.int8)
                    hi = (packed >> 4).to(torch.int8)
                    lo_bf16 = lo.to(torch.bfloat16)
                    hi_bf16 = hi.to(torch.bfloat16)
                    stacked = torch.stack([lo_bf16, hi_bf16], dim=1)
                    unpacked = stacked.reshape(
                        tile_k_packed.block_size * 2, tile_n.block_size
                    )

                    k_begin = tile_k_packed.begin * 2
                    k_len = tile_k_packed.block_size * 2
                    a_tile = A[tile_m, k_begin : (k_begin + k_len)]

                    acc = acc + hl.dot(a_tile, unpacked)

                C[tile_m, tile_n] = acc.to(torch.bfloat16)

            return C

        M, K, N = 16, 32, 16
        A = torch.randn(M, K, dtype=torch.bfloat16, device=DEVICE)
        B_unpacked = torch.randint(-8, 8, (K, N), dtype=torch.int8, device=DEVICE)
        B_halves = B_unpacked.reshape(K // 2, 2, N).permute(1, 0, 2)
        B_packed = ((B_halves[0] & 0xF) | (B_halves[1] << 4)).to(torch.int8)

        bound = matmul_int4_block_expr.bind((A, B_packed))
        (config,) = matmul_int4_block_expr.configs
        code = bound.to_triton_code(config)
        self.assertExpectedJournal(code)

        device_code, host_code = code.split("def matmul_int4_block_expr(")
        self.assertIn("_BLOCK_SIZE_0 = 1", host_code)
        self.assertIn("2 * _BLOCK_SIZE_0, ", host_code)
        self.assertIn("[2 * _BLOCK_SIZE_0, ", device_code)


if __name__ == "__main__":
    unittest.main()
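As a follow-up usage sketch (standalone, not part of the test file above), the int4 packing convention the new test relies on can be round-tripped in plain PyTorch: two signed 4-bit values per int8 byte, low nibble taken from even rows and high nibble from odd rows, recovered with the same shift trick the generated kernel uses.

import torch

K, N = 32, 16
vals = torch.randint(-8, 8, (K, N), dtype=torch.int8)

# Pack: even rows become the low nibble, odd rows the high nibble.
halves = vals.reshape(K // 2, 2, N).permute(1, 0, 2)
packed = ((halves[0] & 0xF) | (halves[1] << 4)).to(torch.int8)  # shape (K // 2, N)

# Unpack: (x << 4) >> 4 sign-extends the low nibble; x >> 4 is an
# arithmetic shift that recovers the signed high nibble.
lo = ((packed << 4) >> 4).to(torch.int8)
hi = (packed >> 4).to(torch.int8)
unpacked = torch.stack([lo, hi], dim=1).reshape(K, N)

torch.testing.assert_close(unpacked, vals)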