diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py
index 924623464..cb04619e7 100644
--- a/helion/_compiler/device_function.py
+++ b/helion/_compiler/device_function.py
@@ -282,7 +282,10 @@ def block_size_var(self, block_id: int) -> str | None:
             var_name = self.new_var(f"_BLOCK_SIZE_{block_id}")
             self.block_size_var_cache[key] = var_name
             host_expr = HostFunction.current().literal_expr(block_value)
-            self.constexpr_arg(var_name, host_expr)
+            if self.constexpr_arg(var_name, host_expr):
+                self.codegen.host_statements.append(
+                    statement_from_string(f"{var_name} = {host_expr}")
+                )
 
         return self.block_size_var_cache[key]
 
diff --git a/test/test_constexpr.expected b/test/test_constexpr.expected
index a724b3091..1a3c225c5 100644
--- a/test/test_constexpr.expected
+++ b/test/test_constexpr.expected
@@ -1,6 +1,77 @@
 This file is automatically generated by assertExpectedJournal calls in test_constexpr.py.
 Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
 
+--- assertExpectedJournal(TestConstExpr.test_block_size_constexpr_assignment_in_host_code)
+from __future__ import annotations
+
+import torch
+import helion
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_matmul_int4_block_expr(B, A, C, _NUM_SM: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, mul: tl.constexpr):
+    total_pids = 16 * tl.cdiv(16, _BLOCK_SIZE_2)
+    block_size = tl.cdiv(total_pids, _NUM_SM)
+    start_pid = tl.program_id(0) * block_size
+    end_pid = tl.minimum(start_pid + block_size, total_pids)
+    for virtual_pid in tl.range(start_pid, end_pid, loop_unroll_factor=1, num_stages=3, flatten=True):
+        num_pid_m = 16
+        num_pid_n = tl.cdiv(16, _BLOCK_SIZE_2)
+        inner_2d_pid = virtual_pid
+        num_pid_in_group = 8 * num_pid_n
+        group_id = inner_2d_pid // num_pid_in_group
+        first_pid_m = group_id * 8
+        group_size_m = min(num_pid_m - first_pid_m, 8)
+        pid_0 = first_pid_m + inner_2d_pid % num_pid_in_group % group_size_m
+        pid_1 = inner_2d_pid % num_pid_in_group // group_size_m
+        offset_1 = pid_0
+        indices_1 = offset_1 + tl.zeros([1], tl.int32)
+        offset_2 = pid_1 * _BLOCK_SIZE_2
+        indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
+        acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
+        for offset_3 in tl.range(0, 16, loop_unroll_factor=4, num_stages=1, disallow_acc_multi_buffer=True, flatten=True):
+            indices_3 = offset_3 + tl.arange(0, 1).to(tl.int32)
+            acc_copy = acc
+            acc_copy_0 = acc_copy
+            packed = tl.load(B + (indices_3[:, None] * 16 + indices_2[None, :] * 1), None)
+            v_0 = tl.full([], 4, tl.int8)
+            v_1 = packed << v_0
+            v_2 = tl.full([], 4, tl.int8)
+            v_3 = v_1 >> v_2
+            v_4 = tl.full([], 4, tl.int8)
+            v_5 = packed >> v_4
+            v_6 = tl.cast(v_3, tl.bfloat16)
+            v_7 = tl.cast(v_5, tl.bfloat16)
+            stack_idx = tl.arange(0, 2)
+            broadcast_idx = stack_idx[None, :, None]
+            expanded_0 = tl.expand_dims(v_6, 1)
+            expanded_1 = tl.expand_dims(v_7, 1)
+            stacked_result = tl.zeros_like(expanded_0)
+            mask_0 = broadcast_idx == 0
+            stacked_result = tl.where(mask_0, expanded_0, stacked_result)
+            mask_1 = broadcast_idx == 1
+            stacked_result = tl.where(mask_1, expanded_1, stacked_result)
+            unpacked = tl.reshape(stacked_result, [2 * _BLOCK_SIZE_0, _BLOCK_SIZE_2])
+            mul_5 = 2 * offset_3
+            iota = mul_5 + tl.arange(0, mul)
+            a_tile = tl.load(A + (indices_1[:, None] * 32 + iota[None, :] * 1), None)
+            dot = tl.dot(tl.cast(a_tile, tl.bfloat16), tl.cast(unpacked, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
+            acc = acc_copy_0 + dot
+        v_9 = tl.cast(acc, tl.bfloat16)
+        tl.store(C + (indices_1[:, None] * 16 + indices_2[None, :] * 1), v_9, None)
+
+def matmul_int4_block_expr(A: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher):
+    M, K = A.shape
+    _, N = B.shape
+    C = torch.zeros(M, N, dtype=torch.bfloat16, device=A.device)
+    _NUM_SM = helion.runtime.get_num_sm(A.device)
+    _BLOCK_SIZE_2 = 16
+    _BLOCK_SIZE_0 = 1
+    _launcher(_helion_matmul_int4_block_expr, (_NUM_SM,), B, A, C, _NUM_SM, _BLOCK_SIZE_2, 1, 1, 2 * _BLOCK_SIZE_0, num_warps=1, num_stages=8)
+    return C
+
 --- assertExpectedJournal(TestConstExpr.test_constexpr_float)
 from __future__ import annotations
 
diff --git a/test/test_constexpr.py b/test/test_constexpr.py
index c030ee23c..d7c1d99b4 100644
--- a/test/test_constexpr.py
+++ b/test/test_constexpr.py
@@ -9,6 +9,7 @@
 from helion._testing import RefEagerTestBase
 from helion._testing import TestCase
 from helion._testing import code_and_output
+from helion._testing import skipIfRefEager
 import helion.language as hl
 
 
@@ -92,6 +93,71 @@ def fn(x: torch.Tensor, mode: str) -> torch.Tensor:
         torch.testing.assert_close(result, x)
         self.assertExpectedJournal(code)
 
+    @skipIfRefEager("Triton codegen does not work in ref eager mode")
+    def test_block_size_constexpr_assignment_in_host_code(self) -> None:
+        @helion.kernel(
+            config=helion.Config(
+                block_sizes=[1, 1, 16],
+                indexing="pointer",
+                l2_groupings=[8],
+                loop_orders=[[0, 1]],
+                num_stages=8,
+                num_warps=1,
+                pid_type="persistent_blocked",
+                range_flattens=[True, True],
+                range_multi_buffers=[None, False],
+                range_num_stages=[3, 1],
+                range_unroll_factors=[1, 4],
+            ),
+            static_shapes=True,
+        )
+        def matmul_int4_block_expr(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
+            M, K = A.shape
+            _, N = B.shape
+
+            C = torch.zeros(M, N, dtype=torch.bfloat16, device=A.device)
+            block_size_k_packed = hl.register_block_size(K // 2)
+
+            for tile_m, tile_n in hl.tile([M, N]):
+                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+
+                for tile_k_packed in hl.tile(K // 2, block_size=block_size_k_packed):
+                    packed = B[tile_k_packed, tile_n]
+                    lo = ((packed << 4) >> 4).to(torch.int8)
+                    hi = (packed >> 4).to(torch.int8)
+                    lo_bf16 = lo.to(torch.bfloat16)
+                    hi_bf16 = hi.to(torch.bfloat16)
+                    stacked = torch.stack([lo_bf16, hi_bf16], dim=1)
+                    unpacked = stacked.reshape(
+                        tile_k_packed.block_size * 2, tile_n.block_size
+                    )
+
+                    k_begin = tile_k_packed.begin * 2
+                    k_len = tile_k_packed.block_size * 2
+                    a_tile = A[tile_m, k_begin : (k_begin + k_len)]
+
+                    acc = acc + hl.dot(a_tile, unpacked)
+
+                C[tile_m, tile_n] = acc.to(torch.bfloat16)
+
+            return C
+
+        M, K, N = 16, 32, 16
+        A = torch.randn(M, K, dtype=torch.bfloat16, device=DEVICE)
+        B_unpacked = torch.randint(-8, 8, (K, N), dtype=torch.int8, device=DEVICE)
+        B_halves = B_unpacked.reshape(K // 2, 2, N).permute(1, 0, 2)
+        B_packed = ((B_halves[0] & 0xF) | (B_halves[1] << 4)).to(torch.int8)
+
+        bound = matmul_int4_block_expr.bind((A, B_packed))
+        (config,) = matmul_int4_block_expr.configs
+        code = bound.to_triton_code(config)
+        self.assertExpectedJournal(code)
+
+        device_code, host_code = code.split("def matmul_int4_block_expr(")
+        self.assertIn("_BLOCK_SIZE_0 = 1", host_code)
+        self.assertIn("2 * _BLOCK_SIZE_0, ", host_code)
+        self.assertIn("[2 * _BLOCK_SIZE_0, ", device_code)
+
 if __name__ == "__main__":
     unittest.main()
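
Note on the device_function.py change: it depends on constexpr_arg reporting whether the argument was newly registered, so the host-side binding (e.g. `_BLOCK_SIZE_0 = 1`) is appended exactly once per block-size variable, and later host expressions such as `2 * _BLOCK_SIZE_0` can refer to the name. The sketch below is a minimal toy model of that register-once pattern, not Helion's implementation; ArgRegistry and its host_statements list are hypothetical stand-ins.

class ArgRegistry:
    """Toy model: constexpr_arg returns True only on first registration."""

    def __init__(self) -> None:
        self._seen: set[str] = set()

    def constexpr_arg(self, name: str, host_expr: str) -> bool:
        if name in self._seen:
            return False  # already an argument; emit no second binding
        self._seen.add(name)
        return True

host_statements: list[str] = []
reg = ArgRegistry()
for _ in range(2):  # a repeated request must not duplicate the binding
    if reg.constexpr_arg("_BLOCK_SIZE_0", "1"):
        host_statements.append("_BLOCK_SIZE_0 = 1")
assert host_statements == ["_BLOCK_SIZE_0 = 1"]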
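
The new test also leans on an int4 nibble-packing round-trip: two signed int4 values share one int8, the first in the low nibble and the second in the high nibble, and the kernel's `(packed << 4) >> 4` / `packed >> 4` shift pair (the v_1/v_3/v_5 lines in the generated code) undoes the packing. Below is a standalone sketch of that round-trip; the variable names are illustrative, not taken from the diff.

import torch

# Pack: keep the low nibble of `lo`, place `hi` in the high nibble.
lo = torch.randint(-8, 8, (16,), dtype=torch.int8)
hi = torch.randint(-8, 8, (16,), dtype=torch.int8)
packed = (lo & 0xF) | (hi << 4)

# Unpack: `<< 4` then arithmetic `>> 4` sign-extends the low nibble;
# a plain arithmetic `>> 4` recovers the signed high nibble.
lo_out = (packed << 4) >> 4
hi_out = packed >> 4

assert torch.equal(lo_out, lo)
assert torch.equal(hi_out, hi)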