From 71cfb789951fe5982fdfad9469ad38315dca1797 Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Sat, 30 Aug 2025 15:58:19 -0700
Subject: [PATCH] torch.stack support

---
 helion/_compiler/device_ir.py         |  19 +++-
 helion/_compiler/inductor_lowering.py |  74 ++++++++++++
 test/test_views.expected              |  90 +++++++++++++++
 test/test_views.py                    | 155 ++++++++++++++++++++++++++
 4 files changed, 336 insertions(+), 2 deletions(-)

diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py
index 6f677694f..91539854b 100644
--- a/helion/_compiler/device_ir.py
+++ b/helion/_compiler/device_ir.py
@@ -73,6 +73,21 @@ class _TLS(Protocol):
 tls: _TLS = cast("_TLS", threading.local())
 
 
+def _get_custom_decomp_table() -> dict[torch._ops.OpOverload, Callable[..., object]]:
+    """
+    Build the decomposition table passed to the make_fx calls in this module.
+
+    This is a copy of select_decomp_table() with the aten.stack.default entry
+    removed; the copy keeps the table returned by select_decomp_table() unmodified.
+    """
+    decomp_table = select_decomp_table().copy()
+    # Normally, aten.stack is decomposed to aten.unsqueeze + aten.cat, but it's difficult to
+    # figure out the right Triton implementation for aten.cat. As a workaround, we disable
+    # the decomp for aten.stack and implement aten.stack in Triton (codegen_stack) instead.
+    decomp_table.pop(torch.ops.aten.stack.default, None)
+    return decomp_table
+
+
 def _make_fx(fn: Callable[..., object], *args: object) -> torch.fx.Graph:
     """
     We monkey patch get_proxy_slot to support Tensor/SymInt/SymFloat/SymBool in the
@@ -628,7 +643,7 @@ def run_subgraph(*args: object) -> list[object]:
 
             with self.disable_tracing() as tracer:
                 graph = proxy_tensor.make_fx(
-                    run_subgraph, decomposition_table=select_decomp_table()
+                    run_subgraph, decomposition_table=_get_custom_decomp_table()
                 )(*inputs.get_tensor_args()).graph
                 graph_idx = self.device_ir.add_graph(
                     graph,
@@ -711,7 +726,7 @@ def run_body(*args: object) -> list[object]:
 
         with self.disable_tracing() as tracer:
             body_graph = proxy_tensor.make_fx(
-                run_body, decomposition_table=select_decomp_table()
+                run_body, decomposition_table=_get_custom_decomp_table()
             )(*inputs.get_tensor_args()).graph
             assert outputs is not None
             graph_idx = self.device_ir.add_graph(
diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py
index e4c8c32c7..94d2af3be 100644
--- a/helion/_compiler/inductor_lowering.py
+++ b/helion/_compiler/inductor_lowering.py
@@ -836,6 +836,80 @@ def codegen_permute(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     )
 
 
+@register_lowering(
+    torch.ops.aten.stack.default,  # pyright: ignore[reportAttributeAccessIssue]
+    masked_value_fn=passthrough_masked_value,
+)
+def codegen_stack(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
+    """
+    Generate Triton code for aten.stack.
+
+    Each input is expanded with tl.expand_dims at the stack dimension, and the
+    result starts as tl.zeros_like of the first expanded input.  Every input is
+    then selected with tl.where wherever an index running along the stack axis
+    equals its position.  The index is tl.arange over the input count rounded up
+    to a power of 2; positions past the last input keep the zero fill.
+
+    Raises:
+        ValueError: If the list of tensors to stack is empty.
+    """
+    tensors = node.args[0]
+    dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
+
+    assert isinstance(tensors, (list, tuple))
+    assert isinstance(dim, int)
+    tensor_asts = [ctx.env[t] for t in tensors]  # pyright: ignore[reportArgumentType]
+    n = len(tensor_asts)
+
+    if n == 0:
+        raise ValueError("Cannot stack empty tensor list")
+
+    # Normalize dim against the rank of the stacked output (input rank + 1) so that
+    # negative dims and inputs of any rank get an index with the right number of axes.
+    val = node.meta["val"]
+    assert isinstance(val, torch.Tensor)
+    out_rank = val.ndim
+    if dim < 0:
+        dim += out_rank
+    assert 0 <= dim < out_rank
+
+    # Round up to a power of 2, as tl.arange needs a power-of-2 length
+    padded_size = 1 << (n - 1).bit_length()
+
+    # Create index array [0, 1, 2, 3, ...] for tensor selection
+    idx = ctx.cg.device_function.new_var("stack_idx")
+    ctx.cg.add_statement(statement_from_string(f"{idx} = tl.arange(0, {padded_size})"))
+
+    # Broadcast index to target dimension shape
+    # e.g., for a rank-3 output, dim=0: [:, None, None], dim=1: [None, :, None]
+    bidx = ctx.cg.device_function.new_var("broadcast_idx")
+    pattern = ", ".join(":" if axis == dim else "None" for axis in range(out_rank))
+    ctx.cg.add_statement(statement_from_string(f"{bidx} = {idx}[{pattern}]"))
+
+    # Expand each input tensor along the stack dimension
+    expanded = [ctx.cg.device_function.new_var(f"expanded_{i}") for i in range(n)]
+    for var, tensor in zip(expanded, tensor_asts, strict=True):
+        ctx.cg.add_statement(
+            statement_from_string(f"{var} = tl.expand_dims({{t}}, {dim})", t=tensor)
+        )
+
+    # Initialize result with zeros
+    result = ctx.cg.device_function.new_var("stacked_result")
+    ctx.cg.add_statement(
+        statement_from_string(f"{result} = tl.zeros_like({expanded[0]})")
+    )
+
+    # Select each tensor using masks
+    for i, var in enumerate(expanded):
+        mask = ctx.cg.device_function.new_var(f"mask_{i}")
+        ctx.cg.add_statement(statement_from_string(f"{mask} = {bidx} == {i}"))
+        ctx.cg.add_statement(
+            statement_from_string(f"{result} = tl.where({mask}, {var}, {result})")
+        )
+
+    return expr_from_string(result)
+
+
 @register_lowering(
     torch.ops.aten.expand.default,  # pyright: ignore[reportAttributeAccessIssue]
     masked_value_fn=passthrough_masked_value,
diff --git a/test/test_views.expected b/test/test_views.expected
index 0429d4cc6..fae9f3f10 100644
--- a/test/test_views.expected
+++ b/test/test_views.expected
@@ -140,3 +140,93 @@ def fn(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
     _BLOCK_SIZE_1 = 32
     _launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, y, out, out.size(0), out.size(1), x.size(0), x.size(1), y.size(0), out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return out
+
+--- assertExpectedJournal(TestViews.test_stack_dim0)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_test_stack_dim0_kernel(a, b, c, result, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < 65
+    indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    mask_2 = indices_3 < 3
+    for offset_2 in tl.range(0, 129, _BLOCK_SIZE_1):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_2 < 129
+        a_tile = tl.load(a + (indices_0[:, None] * 129 + indices_2[None, :] * 1), mask_0[:, None] & mask_1[None, :], other=0)
+        b_tile = tl.load(b + (indices_0[:, None] * 129 + indices_2[None, :] * 1), mask_0[:, None] & mask_1[None, :], other=0)
+        c_tile = tl.load(c + (indices_0[:, None] * 129 + indices_2[None, :] * 1), mask_0[:, None] & mask_1[None, :], other=0)
+        stack_idx = tl.arange(0, 4)
+        broadcast_idx = stack_idx[:, None, None]
+        expanded_0 = tl.expand_dims(a_tile, 0)
+        expanded_1 = tl.expand_dims(b_tile, 0)
+        expanded_2 = tl.expand_dims(c_tile, 0)
+        stacked_result = tl.zeros_like(expanded_0)
+        mask_3 = broadcast_idx == 0
+        stacked_result = tl.where(mask_3, expanded_0, stacked_result)
+        mask_4 = broadcast_idx == 1
+        stacked_result = tl.where(mask_4, expanded_1, stacked_result)
+        mask_5 = broadcast_idx == 2
+        stacked_result = tl.where(mask_5, expanded_2, stacked_result)
+        tl.store(result + (indices_3[:, None, None] * 8385 + indices_0[None, :, None] * 129 + indices_2[None, None, :] * 1), stacked_result, mask_2[:, None, None] & mask_0[None, :, None] & mask_1[None, None, :])
+
+def test_stack_dim0_kernel(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, *, _launcher=_default_launcher):
+    M, N = a.shape
+    result = torch.zeros(3, M, N, dtype=a.dtype, device=a.device)
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_2 = 4
+    _BLOCK_SIZE_1 = 32
+    _launcher(_helion_test_stack_dim0_kernel, (triton.cdiv(65, _BLOCK_SIZE_0),), a, b, c, result, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return result
+
+--- assertExpectedJournal(TestViews.test_stack_non_power_of_2)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_test_stack_non_power_of_2_kernel(a, b, c, result, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < 65
+    indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    mask_2 = indices_3 < 3
+    for offset_2 in tl.range(0, 129, _BLOCK_SIZE_1):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_2 < 129
+        a_tile = tl.load(a + (indices_0[:, None] * 129 + indices_2[None, :] * 1), mask_0[:, None] & mask_1[None, :], other=0)
+        b_tile = tl.load(b + (indices_0[:, None] * 129 + indices_2[None, :] * 1), mask_0[:, None] & mask_1[None, :], other=0)
+        c_tile = tl.load(c + (indices_0[:, None] * 129 + indices_2[None, :] * 1), mask_0[:, None] & mask_1[None, :], other=0)
+        stack_idx = tl.arange(0, 4)
+        broadcast_idx = stack_idx[None, :, None]
+        expanded_0 = tl.expand_dims(a_tile, 1)
+        expanded_1 = tl.expand_dims(b_tile, 1)
+        expanded_2 = tl.expand_dims(c_tile, 1)
+        stacked_result = tl.zeros_like(expanded_0)
+        mask_3 = broadcast_idx == 0
+        stacked_result = tl.where(mask_3, expanded_0, stacked_result)
+        mask_4 = broadcast_idx == 1
+        stacked_result = tl.where(mask_4, expanded_1, stacked_result)
+        mask_5 = broadcast_idx == 2
+        stacked_result = tl.where(mask_5, expanded_2, stacked_result)
+        tl.store(result + (indices_0[:, None, None] * 387 + indices_3[None, :, None] * 129 + indices_2[None, None, :] * 1), stacked_result, mask_0[:, None, None] & mask_2[None, :, None] & mask_1[None, None, :])
+
+def test_stack_non_power_of_2_kernel(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, *, _launcher=_default_launcher):
+    M, N = a.shape
+    result = torch.zeros(M, 3, N, dtype=a.dtype, device=a.device)
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_2 = 4
+    _BLOCK_SIZE_1 = 32
+    _launcher(_helion_test_stack_non_power_of_2_kernel, (triton.cdiv(65, _BLOCK_SIZE_0),), a, b, c, result, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return result
diff --git a/test/test_views.py b/test/test_views.py
index 455e867ab..b60357787 100644
--- a/test/test_views.py
+++ b/test/test_views.py
@@ -209,6 +209,161 @@ def fn(x: torch.Tensor) -> torch.Tensor:
         torch.testing.assert_close(result, expected)
         self.assertExpectedJournal(code)
 
+    def test_stack_power_of_2(self):
+        @helion.kernel(use_default_config=True, static_shapes=True)
+        def test_stack_power_of_2_kernel(
+            a: torch.Tensor, b: torch.Tensor
+        ) -> torch.Tensor:
+            M, N = a.shape
+            result = torch.zeros(M * 2, N, dtype=a.dtype, device=a.device)
+
+            for tile_m in hl.tile(M):
+                for tile_n in hl.tile(N):
+                    a_tile = a[tile_m, tile_n]
+                    b_tile = b[tile_m, tile_n]
+
+                    # Stack tensors along dim=1 (creates [BLOCK_M, 2, BLOCK_N])
+                    stacked = torch.stack([a_tile, b_tile], dim=1)
+
+                    # Reshape to [BLOCK_M * 2, BLOCK_N]
+                    reshaped = stacked.reshape(tile_m.block_size * 2, tile_n.block_size)
+
+                    result[
+                        (tile_m.begin * 2) : (tile_m.begin * 2 + tile_m.block_size * 2),
+                        tile_n,
+                    ] = reshaped
+
+            return result
+
+        M, N = 64, 128
+        device = DEVICE
+
+        a = torch.randn(M, N, dtype=torch.float32, device=device)
+        b = torch.randn(M, N, dtype=torch.float32, device=device)
+
+        result = test_stack_power_of_2_kernel(a, b)
+        expected = torch.zeros(M * 2, N, dtype=torch.float32, device=device)
+        expected[0::2] = a  # Every 2nd row starting from 0
+        expected[1::2] = b  # Every 2nd row starting from 1
+        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
+
+    def test_stack_non_power_of_2(self):
+        @helion.kernel(use_default_config=True, static_shapes=True)
+        def test_stack_non_power_of_2_kernel(
+            a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
+        ) -> torch.Tensor:
+            M, N = a.shape
+            result = torch.zeros(M, 3, N, dtype=a.dtype, device=a.device)
+
+            for tile_m in hl.tile(M):
+                for tile_n in hl.tile(N):
+                    a_tile = a[tile_m, tile_n]
+                    b_tile = b[tile_m, tile_n]
+                    c_tile = c[tile_m, tile_n]
+
+                    # Stack tensors along dim=1 (creates [BLOCK_M, 3, BLOCK_N])
+                    stacked = torch.stack([a_tile, b_tile, c_tile], dim=1)
+
+                    result[tile_m, :, tile_n] = stacked
+
+            return result
+
+        M, N = 65, 129
+        device = DEVICE
+
+        a = torch.randn(M, N, dtype=torch.float32, device=device)
+        b = torch.randn(M, N, dtype=torch.float32, device=device)
+        c = torch.randn(M, N, dtype=torch.float32, device=device)
+
+        code, result = code_and_output(test_stack_non_power_of_2_kernel, (a, b, c))
+        expected = torch.stack([a, b, c], dim=1)
+        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
+        self.assertExpectedJournal(code)
+
+    def test_stack_negative_dim(self):
+        @helion.kernel(use_default_config=True, static_shapes=True)
+        def test_stack_negative_dim_kernel(
+            a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
+        ) -> torch.Tensor:
+            M, N = a.shape
+            result = torch.zeros(M, 3, N, dtype=a.dtype, device=a.device)
+
+            for tile_m in hl.tile(M):
+                for tile_n in hl.tile(N):
+                    a_tile = a[tile_m, tile_n]
+                    b_tile = b[tile_m, tile_n]
+                    c_tile = c[tile_m, tile_n]
+
+                    # dim=-2 on the rank-3 output is the same axis as dim=1
+                    stacked = torch.stack([a_tile, b_tile, c_tile], dim=-2)
+
+                    result[tile_m, :, tile_n] = stacked
+
+            return result
+
+        M, N = 65, 129
+        device = DEVICE
+
+        a = torch.randn(M, N, dtype=torch.float32, device=device)
+        b = torch.randn(M, N, dtype=torch.float32, device=device)
+        c = torch.randn(M, N, dtype=torch.float32, device=device)
+
+        result = test_stack_negative_dim_kernel(a, b, c)
+        expected = torch.stack([a, b, c], dim=-2)
+        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
+
+    def test_stack_dim0(self):
+        @helion.kernel(use_default_config=True, static_shapes=True)
+        def test_stack_dim0_kernel(
+            a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
+        ) -> torch.Tensor:
+            M, N = a.shape
+            result = torch.zeros(3, M, N, dtype=a.dtype, device=a.device)
+
+            for tile_m in hl.tile(M):
+                for tile_n in hl.tile(N):
+                    a_tile = a[tile_m, tile_n]
+                    b_tile = b[tile_m, tile_n]
+                    c_tile = c[tile_m, tile_n]
+
+                    # Stack 3 tensors along dim=0
+                    # This creates [3, BLOCK_M, BLOCK_N]
+                    stacked = torch.stack([a_tile, b_tile, c_tile], dim=0)
+
+                    result[:, tile_m, tile_n] = stacked
+
+            return result
+
+        M, N = 65, 129
+        device = DEVICE
+
+        a = torch.randn(M, N, dtype=torch.float32, device=device)
+        b = torch.randn(M, N, dtype=torch.float32, device=device)
+        c = torch.randn(M, N, dtype=torch.float32, device=device)
+
+        code, result = code_and_output(test_stack_dim0_kernel, (a, b, c))
+        expected = torch.stack([a, b, c], dim=0)
+        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
+        self.assertExpectedJournal(code)
+
+        # Verify torch.compile still decomposes aten.stack to aten.cat
+        from torch._inductor import config as inductor_config
+
+        def capture_graph(graph):
+            self._graph = str(graph)
+            return graph
+
+        with inductor_config.patch(post_grad_custom_pre_pass=capture_graph):
+            torch.compile(
+                lambda x, y, z: torch.stack([x, y, z], dim=0), backend="inductor"
+            )(
+                torch.randn(4, 4, device=device),
+                torch.randn(4, 4, device=device),
+                torch.randn(4, 4, device=device),
+            )
+        self.assertIn("aten.cat", self._graph)
+        self.assertNotIn("aten.stack", self._graph)
+
 
 if __name__ == "__main__":
     unittest.main()