diff --git a/helion/_compiler/roll_reduction.py b/helion/_compiler/roll_reduction.py
index 3a529aeff..6e500218f 100644
--- a/helion/_compiler/roll_reduction.py
+++ b/helion/_compiler/roll_reduction.py
@@ -11,6 +11,7 @@
 from ..language._tracing_ops import _get_symnode
 from ..language._tracing_ops import _host_tensor
 from ..language._tracing_ops import _if
+from ..language.memory_ops import atomic_add
 from ..language.memory_ops import store
 from ..language.reduce_ops import _reduce
 from .compile_environment import CompileEnvironment
@@ -107,6 +108,13 @@ def should_go_in_inner_graph(self, node: torch.fx.Node) -> bool:
             else:
                 # For non-Node values (scalars), they don't have metadata
                 val = stored_value
+        elif node.target is atomic_add:
+            # atomic_add(target, index, value, sem)
+            _, _, value, *_ = node.args
+            if isinstance(value, torch.fx.Node):
+                val = value.meta["val"]
+            else:
+                val = value
         else:
             val = node.meta["val"]

diff --git a/test/test_atomic_add.expected b/test/test_atomic_add.expected
index bc69cdc46..734b1f37f 100644
--- a/test/test_atomic_add.expected
+++ b/test/test_atomic_add.expected
@@ -30,6 +30,38 @@ def atomic_add_2d_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default
     _launcher(_helion_atomic_add_2d_kernel, (triton.cdiv(y.size(0), _BLOCK_SIZE_0) * triton.cdiv(y.size(1), _BLOCK_SIZE_1),), y, x, y.size(0), y.size(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return x

+--- assertExpectedJournal(TestAtomicOperations.test_atomic_add_1d_tensor)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_atomic_add_1d_tensor_kernel(x, y, z, x_stride_0, x_stride_1, y_stride_0, y_stride_1, z_stride_0, m, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < m
+    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
+    x_tile = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None], other=0)
+    y_tile = tl.load(y + (indices_0[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_0[:, None], other=0)
+    v_0 = x_tile * y_tile
+    z_vec = tl.cast(tl.sum(v_0, 0), tl.float32)
+    iota = tl.arange(0, 64)
+    tl.atomic_add(z + iota * z_stride_0, z_vec, mask=None, sem='relaxed')
+
+def atomic_add_1d_tensor_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+    """Test atomic_add where the index is a 1D tensor"""
+    m, n = x.shape
+    n = 64
+    z = torch.zeros([n], dtype=x.dtype, device=x.device)
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_1 = 64
+    _launcher(_helion_atomic_add_1d_tensor_kernel, (triton.cdiv(m, _BLOCK_SIZE_0),), x, y, z, x.stride(0), x.stride(1), y.stride(0), y.stride(1), z.stride(0), m, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+    return z
+
 --- assertExpectedJournal(TestAtomicOperations.test_atomic_add_float)
 from __future__ import annotations

diff --git a/test/test_atomic_add.py b/test/test_atomic_add.py
index 1b9d93e61..66b170d0b 100644
--- a/test/test_atomic_add.py
+++ b/test/test_atomic_add.py
@@ -58,6 +58,25 @@ def atomic_add_w_tile_attr(x: torch.Tensor) -> torch.Tensor:
     return y


+@helion.kernel()
+def atomic_add_1d_tensor_kernel(
+    x: torch.Tensor, y: torch.Tensor
+) -> torch.Tensor:
+    """Test atomic_add where the index is a 1D tensor"""
+    m, n = x.shape
+    n = hl.specialize(n)
+
+    z = torch.zeros([n], dtype=x.dtype, device=x.device)
+
+    for tile_m in hl.tile(m):
+        x_tile = x[tile_m, :].to(torch.float32)
+        y_tile = y[tile_m, :].to(torch.float32)
+        z_vec = torch.sum(x_tile * y_tile, dim=0).to(x.dtype)
+        hl.atomic_add(z, [hl.arange(0, n)], z_vec)
+
+    return z
+
+
 class TestAtomicOperations(RefEagerTestBase, TestCase):
     def test_basic_atomic_add(self):
         x = torch.zeros(10, device=DEVICE)
@@ -74,6 +93,22 @@ def test_basic_atomic_add(self):
         torch.testing.assert_close(result, expected)
         self.assertExpectedJournal(code)

+    def test_atomic_add_1d_tensor(self):
+        M, N = 32, 64
+        x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
+        y = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
+        args = (x, y)
+
+        code, result = code_and_output(
+            atomic_add_1d_tensor_kernel,
+            args,
+            block_sizes=[32],
+        )
+
+        expected = (x * y).sum(dim=0)
+        torch.testing.assert_close(result, expected)
+        self.assertExpectedJournal(code)
+
     def test_overlapping_atomic_add(self):
         # Test with overlapping indices
         x = torch.zeros(5, device=DEVICE)