From b317f035b42f5fed90d9c467b8ff789e8241b58a Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Wed, 8 Oct 2025 14:38:30 -0700
Subject: [PATCH] Raise IndexOffsetOutOfRangeForInt32 when indexing offsets overflow int32

---
 helion/_compiler/indexing_strategy.py | 24 ++++++++++
 helion/exc.py                         |  7 +++
 test/test_indexing.py                 | 68 +++++++++++++++++++++++++++
 3 files changed, 99 insertions(+)

diff --git a/helion/_compiler/indexing_strategy.py b/helion/_compiler/indexing_strategy.py
index a2e0974e0..eaaa0e77e 100644
--- a/helion/_compiler/indexing_strategy.py
+++ b/helion/_compiler/indexing_strategy.py
@@ -9,6 +9,7 @@
 import sympy
 import torch
 from torch._inductor.utils import triton_type
+from torch._prims_common import compute_required_storage_length
 
 from .. import exc
 from .._compat import get_tensor_descriptor_fn_name
@@ -519,6 +520,27 @@ def compute_shape(
         assert len(input_size) == 0, "invalid subscript"
         return output_size
 
+    @staticmethod
+    def _needs_int64(fake_value: torch.Tensor) -> bool:
+        storage_offset = fake_value.storage_offset()
+        try:
+            required = compute_required_storage_length(
+                fake_value.shape,
+                fake_value.stride(),
+                storage_offset,
+            )
+        except Exception:  # pragma: no cover - defensive fallback
+            return False
+
+        if not isinstance(required, int):
+            return False
+
+        if abs(storage_offset) > torch.iinfo(torch.int32).max:
+            return True
+
+        max_offset = required - 1
+        return max_offset > torch.iinfo(torch.int32).max
+
     @staticmethod
     def create(
         state: CodegenState,
@@ -533,6 +555,8 @@ def create(
         output_size = SubscriptIndexing.compute_shape(fake_value, index)
         env = CompileEnvironment.current()
         dtype = env.triton_index_type()
+        if dtype == "tl.int32" and SubscriptIndexing._needs_int64(fake_value):
+            raise exc.IndexOffsetOutOfRangeForInt32(env.settings.index_dtype)
 
         def _is_size_one(size: int | torch.SymInt) -> bool:
             return env.known_equal(size, 1)
diff --git a/helion/exc.py b/helion/exc.py
index 12d284476..c951c82a8 100644
--- a/helion/exc.py
+++ b/helion/exc.py
@@ -107,6 +107,13 @@ class InvalidIndexingType(BaseError):
     message = "Expected tile/int/None/tensor/etc in tensor[...], got {0!s}."
 
 
+class IndexOffsetOutOfRangeForInt32(BaseError):
+    message = (
+        "Tensor indexing offsets exceed the int32 range, but the kernel index_dtype is {0}. "
+        "Use @helion.kernel(index_dtype=torch.int64) to enable larger offsets."
+    )
+
+
 class DataDependentOutputShapeNotSupported(BaseError):
     message = (
         "{op_desc} is not supported in Helion device loops because it produces "
diff --git a/test/test_indexing.py b/test/test_indexing.py
index 98607e772..7177f428c 100644
--- a/test/test_indexing.py
+++ b/test/test_indexing.py
@@ -11,6 +11,7 @@
 from helion._testing import RefEagerTestBase
 from helion._testing import TestCase
 from helion._testing import code_and_output
+from helion._testing import skipIfLowVRAM
 from helion._testing import skipIfNormalMode
 from helion._testing import skipIfRefEager
 from helion._testing import skipIfRocm
@@ -241,6 +242,73 @@ def test_block_size_access(x: torch.Tensor) -> torch.Tensor:
         expected = torch.full_like(x, 1, dtype=torch.int32)
         torch.testing.assert_close(result, expected)
 
+    @skipIfLowVRAM("Test allocates ~15GB across multiple CUDA tensors")
+    def test_int32_offset_out_of_range_error(self):
+        repro_config = helion.Config(
+            block_sizes=[32, 32],
+            flatten_loops=[False],
+            indexing="pointer",
+            l2_groupings=[1],
+            loop_orders=[[0, 1]],
+            num_stages=3,
+            num_warps=4,
+            pid_type="flat",
+            range_flattens=[None],
+            range_multi_buffers=[None],
+            range_num_stages=[],
+            range_unroll_factors=[0],
+            range_warp_specializes=[],
+        )
+
+        def make_kernel(*, index_dtype: torch.dtype):
+            decorator = helion.kernel(
+                config=repro_config, static_shapes=False, index_dtype=index_dtype
+            )
+
+            @decorator
+            def repro_bf16_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+                x, y = torch.broadcast_tensors(x, y)
+                out = torch.empty(
+                    x.shape,
+                    dtype=torch.promote_types(x.dtype, y.dtype),
+                    device=x.device,
+                )
+                for tile in hl.tile(out.size()):
+                    out[tile] = x[tile] + y[tile]
+                return out
+
+            return repro_bf16_add
+
+        def run_case(shape, *, index_dtype, expect_int64=False, expect_error=False):
+            kernel = make_kernel(index_dtype=index_dtype)
+            x = torch.randn(*shape, device=DEVICE, dtype=torch.bfloat16)
+            y = torch.randn(*shape, device=DEVICE, dtype=torch.bfloat16)
+            torch.cuda.synchronize()
+            if expect_error:
+                with self.assertRaisesRegex(
+                    helion.exc.IndexOffsetOutOfRangeForInt32,
+                    f"index_dtype is {index_dtype}",
+                ):
+                    code_and_output(kernel, (x, y))
+                torch.cuda.synchronize()
+                return
+
+            code, out = code_and_output(kernel, (x, y))
+            torch.cuda.synchronize()
+            checker = self.assertIn if expect_int64 else self.assertNotIn
+            checker("tl.int64", code)
+            torch.cuda.synchronize()
+            ref_out = torch.add(x, y)
+            torch.cuda.synchronize()
+            torch.testing.assert_close(out, ref_out, rtol=1e-2, atol=1e-2)
+
+        small_shape = (128, 128)
+        large_shape = (51200, 51200)
+
+        run_case(small_shape, index_dtype=torch.int32)
+        run_case(large_shape, index_dtype=torch.int32, expect_error=True)
+        run_case(large_shape, index_dtype=torch.int64, expect_int64=True)
+
     def test_assign_int(self):
         @helion.kernel
         def fn(x: torch.Tensor) -> torch.Tensor:
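
A quick way to sanity-check the _needs_int64 arithmetic outside of Helion
(a standalone sketch, not part of the diff above; needs_int64 and INT32_MAX
are illustrative names, and torch._prims_common.compute_required_storage_length
is a private PyTorch helper whose signature may change between releases):

    import torch
    from torch._prims_common import compute_required_storage_length

    INT32_MAX = torch.iinfo(torch.int32).max  # 2**31 - 1 == 2147483647

    def needs_int64(t: torch.Tensor) -> bool:
        # The largest element offset a pointer-indexed kernel can compute
        # for `t` is (required storage length - 1), in units of elements.
        required = compute_required_storage_length(
            t.shape, t.stride(), t.storage_offset()
        )
        return abs(t.storage_offset()) > INT32_MAX or required - 1 > INT32_MAX

    # Meta tensors carry shape/stride/offset metadata without allocating
    # memory, so even the 51200 x 51200 case is free to construct here.
    small = torch.empty(128, 128, device="meta", dtype=torch.bfloat16)
    large = torch.empty(51200, 51200, device="meta", dtype=torch.bfloat16)

    print(needs_int64(small))  # False: max offset is 16383
    print(needs_int64(large))  # True: 51200 * 51200 - 1 = 2621439999 > INT32_MAX

This also motivates the test shapes: a (51200, 51200) bf16 tensor holds about
2.62e9 elements (~5.2 GB), so flat offsets overflow int32 even though each
dimension alone fits comfortably.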