13 changes: 11 additions & 2 deletions helion/_compiler/device_ir.py
@@ -73,6 +73,15 @@ class _TLS(Protocol):
tls: _TLS = cast("_TLS", threading.local())


def _get_custom_decomp_table() -> dict[torch._ops.OpOverload, Callable[..., object]]:
decomp_table = select_decomp_table().copy()
# Normally, aten.stack is decomposed to aten.unsqueeze + aten.cat, but it's difficult to
# figure out the right Triton implementation for aten.cat. As a workaround, we disable
# the decomp for aten.stack and implement aten.stack in Triton (codegen_stack) instead.
decomp_table.pop(torch.ops.aten.stack.default, None)
return decomp_table


def _make_fx(fn: Callable[..., object], *args: object) -> torch.fx.Graph:
"""
We monkey patch get_proxy_slot to support Tensor/SymInt/SymFloat/SymBool in the
@@ -628,7 +637,7 @@ def run_subgraph(*args: object) -> list[object]:

with self.disable_tracing() as tracer:
graph = proxy_tensor.make_fx(
run_subgraph, decomposition_table=select_decomp_table()
run_subgraph, decomposition_table=_get_custom_decomp_table()
)(*inputs.get_tensor_args()).graph
graph_idx = self.device_ir.add_graph(
graph,
@@ -711,7 +720,7 @@ def run_body(*args: object) -> list[object]:

with self.disable_tracing() as tracer:
body_graph = proxy_tensor.make_fx(
run_body, decomposition_table=select_decomp_table()
run_body, decomposition_table=_get_custom_decomp_table()
)(*inputs.get_tensor_args()).graph
assert outputs is not None
graph_idx = self.device_ir.add_graph(
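Note (illustrative, not part of the diff): a minimal sketch of what popping aten.stack.default from the decomposition table changes during tracing. It assumes select_decomp_table is torch._inductor.decomposition.select_decomp_table, matching Helion's existing import; the exact graphs make_fx produces may differ across PyTorch versions.

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._inductor.decomposition import select_decomp_table

    def f(a, b):
        return torch.stack([a, b], dim=0)

    default_table = select_decomp_table()
    custom_table = default_table.copy()
    custom_table.pop(torch.ops.aten.stack.default, None)  # same pop as _get_custom_decomp_table

    x, y = torch.randn(4, 4), torch.randn(4, 4)

    # With the default table, aten.stack is decomposed (normally into unsqueeze + cat),
    # so no aten.stack node is expected to survive tracing.
    g_default = make_fx(f, decomposition_table=default_table)(x, y).graph

    # With the custom table, the aten.stack node is preserved, which is what lets
    # Helion's codegen_stack lowering handle it directly.
    g_custom = make_fx(f, decomposition_table=custom_table)(x, y).graph

    has_stack = lambda g: any(n.target == torch.ops.aten.stack.default for n in g.nodes)
    print(has_stack(g_default), has_stack(g_custom))  # expected: False, True
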
55 changes: 55 additions & 0 deletions helion/_compiler/inductor_lowering.py
@@ -836,6 +836,61 @@ def codegen_permute(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
)


@register_lowering(
torch.ops.aten.stack.default, # pyright: ignore[reportAttributeAccessIssue]
masked_value_fn=passthrough_masked_value,
)
def codegen_stack(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
tensors = node.args[0]
dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)

assert isinstance(tensors, (list, tuple))
tensor_asts = [ctx.env[t] for t in tensors] # pyright: ignore[reportArgumentType]
n = len(tensor_asts)

if n == 0:
raise ValueError("Cannot stack empty tensor list")

# Round up to power of 2 for efficient masking
padded_size = 1 << (n - 1).bit_length()

# Create index array [0, 1, 2, 3, ...] for tensor selection
idx = ctx.cg.device_function.new_var("stack_idx")
ctx.cg.add_statement(statement_from_string(f"{idx} = tl.arange(0, {padded_size})"))

# Broadcast index to target dimension shape
# e.g., dim=0: [:, None, None], dim=1: [None, :, None], dim=2: [None, None, :]
bidx = ctx.cg.device_function.new_var("broadcast_idx")
assert isinstance(dim, int)
pattern = "[" + ", ".join(["None"] * dim + [":"] + ["None"] * max(0, 2 - dim)) + "]"
ctx.cg.add_statement(statement_from_string(f"{bidx} = {idx}{pattern}"))

# Expand each input tensor along the stack dimension
expanded = [ctx.cg.device_function.new_var(f"expanded_{i}") for i in range(n)]
for var, tensor in zip(expanded, tensor_asts, strict=False):
ctx.cg.add_statement(
statement_from_string(f"{var} = tl.expand_dims({{t}}, {dim})", t=tensor)
)

# Initialize result with zeros
result = ctx.cg.device_function.new_var("stacked_result")
ctx.cg.add_statement(
statement_from_string(f"{result} = tl.zeros_like({expanded[0]})")
)

# Select each tensor using masks
for i in range(n):
mask = ctx.cg.device_function.new_var(f"mask_{i}")
ctx.cg.add_statement(statement_from_string(f"{mask} = {bidx} == {i}"))
ctx.cg.add_statement(
statement_from_string(
f"{result} = tl.where({mask}, {expanded[i]}, {result})"
)
)

return expr_from_string(result)


@register_lowering(
torch.ops.aten.expand.default, # pyright: ignore[reportAttributeAccessIssue]
masked_value_fn=passthrough_masked_value,
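Note (illustrative, not part of the diff): codegen_stack builds the stacked tensor with an index vector plus tl.where selects rather than a concat. Below is a rough PyTorch equivalent of that strategy, written only to clarify the approach; it assumes a valid stack dim and same-shaped inputs, and the real kernel masks out the padded slots at tl.store instead of narrowing.

    import torch

    def stack_via_masks(tensors, dim=0):
        n = len(tensors)
        padded = 1 << (n - 1).bit_length()  # round up to a power of 2, as tl.arange requires
        idx_shape = [1] * (tensors[0].dim() + 1)
        idx_shape[dim] = padded
        idx = torch.arange(padded).view(idx_shape)      # broadcastable index along the new dim
        expanded = [t.unsqueeze(dim) for t in tensors]  # analogue of tl.expand_dims
        result = torch.zeros_like(expanded[0])
        for i, t in enumerate(expanded):
            result = torch.where(idx == i, t, result)   # place tensor i into slot i
        # The generated kernel keeps the padded size and masks the extra slots at tl.store;
        # here we narrow back to n so the result matches torch.stack exactly.
        return result.narrow(dim, 0, n)

    a, b, c = (torch.randn(2, 3) for _ in range(3))
    torch.testing.assert_close(stack_via_masks([a, b, c], dim=1), torch.stack([a, b, c], dim=1))
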
90 changes: 90 additions & 0 deletions test/test_views.expected
@@ -140,3 +140,93 @@ def fn(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
_BLOCK_SIZE_1 = 32
_launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, y, out, out.size(0), out.size(1), x.size(0), x.size(1), y.size(0), out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
return out

--- assertExpectedJournal(TestViews.test_stack_dim0)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_test_stack_dim0_kernel(a, b, c, result, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < 65
indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
mask_2 = indices_3 < 3
for offset_2 in tl.range(0, 129, _BLOCK_SIZE_1):
indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
mask_1 = indices_2 < 129
a_tile = tl.load(a + (indices_0[:, None] * 129 + indices_2[None, :] * 1), mask_0[:, None] & mask_1[None, :], other=0)
b_tile = tl.load(b + (indices_0[:, None] * 129 + indices_2[None, :] * 1), mask_0[:, None] & mask_1[None, :], other=0)
c_tile = tl.load(c + (indices_0[:, None] * 129 + indices_2[None, :] * 1), mask_0[:, None] & mask_1[None, :], other=0)
stack_idx = tl.arange(0, 4)
broadcast_idx = stack_idx[:, None, None]
expanded_0 = tl.expand_dims(a_tile, 0)
expanded_1 = tl.expand_dims(b_tile, 0)
expanded_2 = tl.expand_dims(c_tile, 0)
stacked_result = tl.zeros_like(expanded_0)
mask_3 = broadcast_idx == 0
stacked_result = tl.where(mask_3, expanded_0, stacked_result)
mask_4 = broadcast_idx == 1
stacked_result = tl.where(mask_4, expanded_1, stacked_result)
mask_5 = broadcast_idx == 2
stacked_result = tl.where(mask_5, expanded_2, stacked_result)
tl.store(result + (indices_3[:, None, None] * 8385 + indices_0[None, :, None] * 129 + indices_2[None, None, :] * 1), stacked_result, mask_2[:, None, None] & mask_0[None, :, None] & mask_1[None, None, :])

def test_stack_dim0_kernel(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, *, _launcher=_default_launcher):
M, N = a.shape
result = torch.zeros(3, M, N, dtype=a.dtype, device=a.device)
_BLOCK_SIZE_0 = 32
_RDIM_SIZE_2 = 4
_BLOCK_SIZE_1 = 32
_launcher(_helion_test_stack_dim0_kernel, (triton.cdiv(65, _BLOCK_SIZE_0),), a, b, c, result, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
return result

--- assertExpectedJournal(TestViews.test_stack_non_power_of_2)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_test_stack_non_power_of_2_kernel(a, b, c, result, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < 65
indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
mask_2 = indices_3 < 3
for offset_2 in tl.range(0, 129, _BLOCK_SIZE_1):
indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
mask_1 = indices_2 < 129
a_tile = tl.load(a + (indices_0[:, None] * 129 + indices_2[None, :] * 1), mask_0[:, None] & mask_1[None, :], other=0)
b_tile = tl.load(b + (indices_0[:, None] * 129 + indices_2[None, :] * 1), mask_0[:, None] & mask_1[None, :], other=0)
c_tile = tl.load(c + (indices_0[:, None] * 129 + indices_2[None, :] * 1), mask_0[:, None] & mask_1[None, :], other=0)
stack_idx = tl.arange(0, 4)
broadcast_idx = stack_idx[None, :, None]
expanded_0 = tl.expand_dims(a_tile, 1)
expanded_1 = tl.expand_dims(b_tile, 1)
expanded_2 = tl.expand_dims(c_tile, 1)
stacked_result = tl.zeros_like(expanded_0)
mask_3 = broadcast_idx == 0
stacked_result = tl.where(mask_3, expanded_0, stacked_result)
mask_4 = broadcast_idx == 1
stacked_result = tl.where(mask_4, expanded_1, stacked_result)
mask_5 = broadcast_idx == 2
stacked_result = tl.where(mask_5, expanded_2, stacked_result)
tl.store(result + (indices_0[:, None, None] * 387 + indices_3[None, :, None] * 129 + indices_2[None, None, :] * 1), stacked_result, mask_0[:, None, None] & mask_2[None, :, None] & mask_1[None, None, :])

def test_stack_non_power_of_2_kernel(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, *, _launcher=_default_launcher):
M, N = a.shape
result = torch.zeros(M, 3, N, dtype=a.dtype, device=a.device)
_BLOCK_SIZE_0 = 32
_RDIM_SIZE_2 = 4
_BLOCK_SIZE_1 = 32
_launcher(_helion_test_stack_non_power_of_2_kernel, (triton.cdiv(65, _BLOCK_SIZE_0),), a, b, c, result, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
return result
122 changes: 122 additions & 0 deletions test/test_views.py
@@ -209,6 +209,128 @@ def fn(x: torch.Tensor) -> torch.Tensor:
torch.testing.assert_close(result, expected)
self.assertExpectedJournal(code)

def test_stack_power_of_2(self):
@helion.kernel(use_default_config=True, static_shapes=True)
def test_stack_power_of_2_kernel(
a: torch.Tensor, b: torch.Tensor
) -> torch.Tensor:
M, N = a.shape
result = torch.zeros(M * 2, N, dtype=a.dtype, device=a.device)

for tile_m in hl.tile(M):
for tile_n in hl.tile(N):
a_tile = a[tile_m, tile_n]
b_tile = b[tile_m, tile_n]

# Stack tensors along dim=1 (creates [BLOCK_M, 2, BLOCK_N])
stacked = torch.stack([a_tile, b_tile], dim=1)

# Reshape to [BLOCK_M * 2, BLOCK_N]
reshaped = stacked.reshape(tile_m.block_size * 2, tile_n.block_size)

result[
(tile_m.begin * 2) : (tile_m.begin * 2 + tile_m.block_size * 2),
tile_n,
] = reshaped

return result

M, N = 64, 128
device = DEVICE

a = torch.randn(M, N, dtype=torch.float32, device=device)
b = torch.randn(M, N, dtype=torch.float32, device=device)

result = test_stack_power_of_2_kernel(a, b)
expected = torch.zeros(M * 2, N, dtype=torch.float32, device=device)
expected[0::2] = a # Every 2nd row starting from 0
expected[1::2] = b # Every 2nd row starting from 1
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)

def test_stack_non_power_of_2(self):
@helion.kernel(use_default_config=True, static_shapes=True)
def test_stack_non_power_of_2_kernel(
a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
) -> torch.Tensor:
M, N = a.shape
result = torch.zeros(M, 3, N, dtype=a.dtype, device=a.device)

for tile_m in hl.tile(M):
for tile_n in hl.tile(N):
a_tile = a[tile_m, tile_n]
b_tile = b[tile_m, tile_n]
c_tile = c[tile_m, tile_n]

# Stack tensors along dim=1 (creates [BLOCK_M, 3, BLOCK_N])
stacked = torch.stack([a_tile, b_tile, c_tile], dim=1)

result[tile_m, :, tile_n] = stacked

return result

M, N = 65, 129
device = DEVICE

a = torch.randn(M, N, dtype=torch.float32, device=device)
b = torch.randn(M, N, dtype=torch.float32, device=device)
c = torch.randn(M, N, dtype=torch.float32, device=device)

code, result = code_and_output(test_stack_non_power_of_2_kernel, (a, b, c))
expected = torch.stack([a, b, c], dim=1)
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
self.assertExpectedJournal(code)

def test_stack_dim0(self):
@helion.kernel(use_default_config=True, static_shapes=True)
def test_stack_dim0_kernel(
a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
) -> torch.Tensor:
M, N = a.shape
result = torch.zeros(3, M, N, dtype=a.dtype, device=a.device)

for tile_m in hl.tile(M):
for tile_n in hl.tile(N):
a_tile = a[tile_m, tile_n]
b_tile = b[tile_m, tile_n]
c_tile = c[tile_m, tile_n]

# Stack 3 tensors along dim=0
# This creates [3, BLOCK_M, BLOCK_N]
stacked = torch.stack([a_tile, b_tile, c_tile], dim=0)

result[:, tile_m, tile_n] = stacked

return result

M, N = 65, 129
device = DEVICE

a = torch.randn(M, N, dtype=torch.float32, device=device)
b = torch.randn(M, N, dtype=torch.float32, device=device)
c = torch.randn(M, N, dtype=torch.float32, device=device)

code, result = code_and_output(test_stack_dim0_kernel, (a, b, c))
expected = torch.stack([a, b, c], dim=0)
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
self.assertExpectedJournal(code)

# Verify torch.compile still decomposes aten.stack to aten.cat
Contributor Author: Added test to make sure _get_custom_decomp_table doesn't affect the normal torch.compile decomp for torch.stack.

from torch._inductor import config as inductor_config

def capture_graph(graph):
self._graph = str(graph)
return graph

with inductor_config.patch(post_grad_custom_pre_pass=capture_graph):
torch.compile(
lambda x, y, z: torch.stack([x, y, z], dim=0), backend="inductor"
)(
torch.randn(4, 4, device=device),
torch.randn(4, 4, device=device),
torch.randn(4, 4, device=device),
)
assert "aten.cat" in self._graph and "aten.stack" not in self._graph


if __name__ == "__main__":
unittest.main()