8 changes: 8 additions & 0 deletions helion/_compiler/roll_reduction.py
@@ -11,6 +11,7 @@
from ..language._tracing_ops import _get_symnode
from ..language._tracing_ops import _host_tensor
from ..language._tracing_ops import _if
from ..language.memory_ops import atomic_add
from ..language.memory_ops import store
from ..language.reduce_ops import _reduce
from .compile_environment import CompileEnvironment
@@ -107,6 +108,13 @@ def should_go_in_inner_graph(self, node: torch.fx.Node) -> bool:
            else:
                # For non-Node values (scalars), they don't have metadata
                val = stored_value
        elif node.target is atomic_add:
            # atomic_add(target, index, value, sem)
            _, _, value, *_ = node.args
            if isinstance(value, torch.fx.Node):
                val = value.meta["val"]
            else:
                val = value
Comment on lines +111 to +117 (Contributor):
This case could be merged with the one above, but I'll merge this since I'm working on a PR that will conflict.

        else:
            val = node.meta["val"]

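The review comment above suggests folding the new atomic_add branch into the store branch. A minimal sketch of what that merge could look like, assuming store likewise passes the stored value as its third positional argument; this is only an illustration, not the change planned in the reviewer's conflicting PR:

        elif node.target in (store, atomic_add):
            # store(target, index, value, extra_mask) and atomic_add(target, index, value, sem)
            # are assumed here to both carry the stored value at args[2]
            _, _, value, *_ = node.args
            if isinstance(value, torch.fx.Node):
                val = value.meta["val"]
            else:
                # Non-Node values (scalars) have no metadata
                val = value
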
32 changes: 32 additions & 0 deletions test/test_atomic_add.expected
@@ -30,6 +30,38 @@ def atomic_add_2d_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default
    _launcher(_helion_atomic_add_2d_kernel, (triton.cdiv(y.size(0), _BLOCK_SIZE_0) * triton.cdiv(y.size(1), _BLOCK_SIZE_1),), y, x, y.size(0), y.size(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
    return x

--- assertExpectedJournal(TestAtomicOperations.test_atomic_add_1d_tensor)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_atomic_add_1d_tensor_kernel(x, y, z, x_stride_0, x_stride_1, y_stride_0, y_stride_1, z_stride_0, m, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0 * _BLOCK_SIZE_0
    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    mask_0 = indices_0 < m
    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
    x_tile = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None], other=0)
    y_tile = tl.load(y + (indices_0[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_0[:, None], other=0)
    v_0 = x_tile * y_tile
    z_vec = tl.cast(tl.sum(v_0, 0), tl.float32)
    iota = tl.arange(0, 64)
    tl.atomic_add(z + iota * z_stride_0, z_vec, mask=None, sem='relaxed')

def atomic_add_1d_tensor_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
    """Test atomic_add where the index is a 1D tensor"""
    m, n = x.shape
    n = 64
    z = torch.zeros([n], dtype=x.dtype, device=x.device)
    _BLOCK_SIZE_0 = 32
    _RDIM_SIZE_1 = 64
    _launcher(_helion_atomic_add_1d_tensor_kernel, (triton.cdiv(m, _BLOCK_SIZE_0),), x, y, z, x.stride(0), x.stride(1), y.stride(0), y.stride(1), z.stride(0), m, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
    return z

--- assertExpectedJournal(TestAtomicOperations.test_atomic_add_float)
from __future__ import annotations

35 changes: 35 additions & 0 deletions test/test_atomic_add.py
@@ -58,6 +58,25 @@ def atomic_add_w_tile_attr(x: torch.Tensor) -> torch.Tensor:
    return y


@helion.kernel()
def atomic_add_1d_tensor_kernel(
    x: torch.Tensor, y: torch.Tensor
) -> torch.Tensor:
    """Test atomic_add where the index is a 1D tensor"""
    m, n = x.shape
    n = hl.specialize(n)

    z = torch.zeros([n], dtype=x.dtype, device=x.device)

    for tile_m in hl.tile(m):
        x_tile = x[tile_m, :].to(torch.float32)
        y_tile = y[tile_m, :].to(torch.float32)
        z_vec = torch.sum(x_tile * y_tile, dim=0).to(x.dtype)
        hl.atomic_add(z, [hl.arange(0, n)], z_vec)

    return z


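For intuition (not part of the PR), a plain-PyTorch sketch of what this kernel computes: each row tile contributes the partial column sums of x * y, and the atomic adds accumulate those partials into z, so z ends up equal to (x * y).sum(dim=0). The tile size block_m below is an arbitrary value chosen for the sketch.

import torch

m, n, block_m = 32, 64, 8  # block_m is an arbitrary tile size for this sketch
x = torch.randn(m, n)
y = torch.randn(m, n)
z = torch.zeros(n)
for start in range(0, m, block_m):
    rows = slice(start, start + block_m)
    z += (x[rows] * y[rows]).sum(dim=0)  # stands in for hl.atomic_add into z
torch.testing.assert_close(z, (x * y).sum(dim=0))
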
class TestAtomicOperations(RefEagerTestBase, TestCase):
    def test_basic_atomic_add(self):
        x = torch.zeros(10, device=DEVICE)
@@ -74,6 +93,22 @@ def test_basic_atomic_add(self):
        torch.testing.assert_close(result, expected)
        self.assertExpectedJournal(code)

    def test_atomic_add_1d_tensor(self):
        M, N = 32, 64
        x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
        y = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
        args = (x, y)

        code, result = code_and_output(
            atomic_add_1d_tensor_kernel,
            args,
            block_sizes=[32],
        )

        expected = (x * y).sum(dim=0)
        torch.testing.assert_close(result, expected)
        self.assertExpectedJournal(code)

    def test_overlapping_atomic_add(self):
        # Test with overlapping indices
        x = torch.zeros(5, device=DEVICE)