24 changes: 24 additions & 0 deletions helion/_compiler/indexing_strategy.py
@@ -9,6 +9,7 @@
import sympy
import torch
from torch._inductor.utils import triton_type
from torch._prims_common import compute_required_storage_length

from .. import exc
from .._compat import get_tensor_descriptor_fn_name
@@ -519,6 +520,27 @@ def compute_shape(
        assert len(input_size) == 0, "invalid subscript"
        return output_size

    @staticmethod
    def _needs_int64(fake_value: torch.Tensor) -> bool:
        storage_offset = fake_value.storage_offset()
        try:
            required = compute_required_storage_length(
                fake_value.shape,
                fake_value.stride(),
                storage_offset,
            )
        except Exception:  # pragma: no cover - defensive fallback
            return False

        if not isinstance(required, int):
            return False

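        # abs() also catches a large negative storage_offset, which the
        # required-length check below would not flag on its own.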
        if abs(storage_offset) > torch.iinfo(torch.int32).max:
            return True

        max_offset = required - 1
        return max_offset > torch.iinfo(torch.int32).max

    @staticmethod
    def create(
        state: CodegenState,
@@ -533,6 +555,8 @@ def create(
        output_size = SubscriptIndexing.compute_shape(fake_value, index)
        env = CompileEnvironment.current()
        dtype = env.triton_index_type()
        if dtype == "tl.int32" and SubscriptIndexing._needs_int64(fake_value):
            raise exc.IndexOffsetOutOfRangeForInt32(env.settings.index_dtype)

        def _is_size_one(size: int | torch.SymInt) -> bool:
            return env.known_equal(size, 1)
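Note (not part of the diff): a minimal sketch of the comparison _needs_int64 performs, using the same compute_required_storage_length helper imported above. The 51200 x 51200 shape matches the failing case exercised by the test below; the meta device avoids a real ~5 GB allocation.

import torch
from torch._prims_common import compute_required_storage_length

# A contiguous 51200 x 51200 bf16 tensor holds 2,621,440,000 elements.
t = torch.empty(51200, 51200, dtype=torch.bfloat16, device="meta")
required = compute_required_storage_length(t.shape, t.stride(), t.storage_offset())
max_offset = required - 1  # largest element offset an index expression must reach
print(max_offset > torch.iinfo(torch.int32).max)  # True: int32 offsets would overflow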
7 changes: 7 additions & 0 deletions helion/exc.py
@@ -107,6 +107,13 @@ class InvalidIndexingType(BaseError):
    message = "Expected tile/int/None/tensor/etc in tensor[...], got {0!s}."


class IndexOffsetOutOfRangeForInt32(BaseError):
    message = (
        "Tensor indexing offsets exceed the int32 range, but the kernel index_dtype is {0}. "
        "Use @helion.kernel(index_dtype=torch.int64) to enable larger offsets."
    )


class DataDependentOutputShapeNotSupported(BaseError):
    message = (
        "{op_desc} is not supported in Helion device loops because it produces "
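Note (not part of the diff): the fix suggested by the new message looks like the following hypothetical user kernel, assuming the usual import alias hl for helion.language.

import torch
import helion
import helion.language as hl

# Opting into 64-bit indexing as the error message suggests; index arithmetic
# in the generated code widens to tl.int64, lifting the int32 offset limit.
@helion.kernel(index_dtype=torch.int64)
def add_large(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(out.size()):
        out[tile] = x[tile] + y[tile]
    return out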
68 changes: 68 additions & 0 deletions test/test_indexing.py
@@ -11,6 +11,7 @@
from helion._testing import RefEagerTestBase
from helion._testing import TestCase
from helion._testing import code_and_output
from helion._testing import skipIfLowVRAM
from helion._testing import skipIfNormalMode
from helion._testing import skipIfRefEager
from helion._testing import skipIfRocm
@@ -241,6 +242,73 @@ def test_block_size_access(x: torch.Tensor) -> torch.Tensor:
        expected = torch.full_like(x, 1, dtype=torch.int32)
        torch.testing.assert_close(result, expected)

    @skipIfLowVRAM("Test allocates ~15GB across multiple CUDA tensors")
    def test_int32_offset_out_of_range_error(self):
        repro_config = helion.Config(
            block_sizes=[32, 32],
            flatten_loops=[False],
            indexing="pointer",
            l2_groupings=[1],
            loop_orders=[[0, 1]],
            num_stages=3,
            num_warps=4,
            pid_type="flat",
            range_flattens=[None],
            range_multi_buffers=[None],
            range_num_stages=[],
            range_unroll_factors=[0],
            range_warp_specializes=[],
        )

        def make_kernel(*, index_dtype: torch.dtype):
            kwargs = dict(config=repro_config, static_shapes=False)
            kwargs["index_dtype"] = index_dtype
            decorator = helion.kernel(**kwargs)

            @decorator
            def repro_bf16_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                x, y = torch.broadcast_tensors(x, y)
                out = torch.empty(
                    x.shape,
                    dtype=torch.promote_types(x.dtype, y.dtype),
                    device=x.device,
                )
                for tile in hl.tile(out.size()):
                    out[tile] = x[tile] + y[tile]
                return out

            return repro_bf16_add

        def run_case(shape, *, index_dtype, expect_int64=False, expect_error=False):
            kernel = make_kernel(index_dtype=index_dtype)
            x = torch.randn(*shape, device=DEVICE, dtype=torch.bfloat16)
            y = torch.randn(*shape, device=DEVICE, dtype=torch.bfloat16)
            torch.cuda.synchronize()
            if expect_error:
                with self.assertRaisesRegex(
                    helion.exc.IndexOffsetOutOfRangeForInt32,
                    f"index_dtype is {index_dtype}",
                ):
                    code_and_output(kernel, (x, y))
                torch.cuda.synchronize()
                return

            code, out = code_and_output(kernel, (x, y))
            torch.cuda.synchronize()
            checker = self.assertIn if expect_int64 else self.assertNotIn
            checker("tl.int64", code)
            torch.cuda.synchronize()
            ref_out = torch.add(x, y)
            torch.cuda.synchronize()
            torch.testing.assert_close(out, ref_out, rtol=1e-2, atol=1e-2)

        small_shape = (128, 128)
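        # 51200 * 51200 = 2,621,440,000 elements > 2**31 - 1 = 2,147,483,647,
        # so int32 element offsets would overflow for the large shape.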
        large_shape = (51200, 51200)

        run_case(small_shape, index_dtype=torch.int32)
        run_case(large_shape, index_dtype=torch.int32, expect_error=True)
        run_case(large_shape, index_dtype=torch.int64, expect_int64=True)

    def test_assign_int(self):
        @helion.kernel
        def fn(x: torch.Tensor) -> torch.Tensor: