diff --git a/helion/_compiler/indexing_strategy.py b/helion/_compiler/indexing_strategy.py
index a2e0974e0..7f329c391 100644
--- a/helion/_compiler/indexing_strategy.py
+++ b/helion/_compiler/indexing_strategy.py
@@ -9,6 +9,7 @@
 import sympy
 import torch
 from torch._inductor.utils import triton_type
+from torch._prims_common import compute_required_storage_length
 
 from .. import exc
 from .._compat import get_tensor_descriptor_fn_name
@@ -519,6 +520,31 @@ def compute_shape(
         assert len(input_size) == 0, "invalid subscript"
         return output_size
 
+    @staticmethod
+    def _needs_int64(fake_value: torch.Tensor) -> bool:
+        storage_offset = fake_value.storage_offset()
+
+        if not isinstance(storage_offset, int):
+            return False
+
+        try:
+            required = compute_required_storage_length(
+                fake_value.shape,
+                fake_value.stride(),
+                storage_offset,
+            )
+        except Exception:
+            return False
+
+        if not isinstance(required, int):
+            return False
+
+        if abs(storage_offset) > torch.iinfo(torch.int32).max:
+            return True
+
+        max_offset = required - 1
+        return max_offset > torch.iinfo(torch.int32).max
+
     @staticmethod
     def create(
         state: CodegenState,
@@ -533,6 +559,8 @@
         output_size = SubscriptIndexing.compute_shape(fake_value, index)
         env = CompileEnvironment.current()
         dtype = env.triton_index_type()
+        if dtype == "tl.int32" and SubscriptIndexing._needs_int64(fake_value):
+            raise exc.IndexOffsetOutOfRangeForInt32(env.settings.index_dtype)
 
         def _is_size_one(size: int | torch.SymInt) -> bool:
             return env.known_equal(size, 1)
diff --git a/helion/exc.py b/helion/exc.py
index f7deeb773..f9a231996 100644
--- a/helion/exc.py
+++ b/helion/exc.py
@@ -107,6 +107,13 @@ class InvalidIndexingType(BaseError):
     message = "Expected tile/int/None/tensor/etc in tensor[...], got {0!s}."
 
 
+class IndexOffsetOutOfRangeForInt32(BaseError):
+    message = (
+        "Kernel index_dtype is {0}, but tensor indexing offsets exceed the int32 range. "
+        "Use @helion.kernel(index_dtype=torch.int64) to enable larger offsets."
+    )
+
+
 class DataDependentOutputShapeNotSupported(BaseError):
     message = (
         "{op_desc} is not supported in Helion device loops because it produces "
diff --git a/test/test_indexing.py b/test/test_indexing.py
index 98607e772..215ba4c0b 100644
--- a/test/test_indexing.py
+++ b/test/test_indexing.py
@@ -11,6 +11,7 @@
 from helion._testing import RefEagerTestBase
 from helion._testing import TestCase
 from helion._testing import code_and_output
+from helion._testing import skipIfLowVRAM
 from helion._testing import skipIfNormalMode
 from helion._testing import skipIfRefEager
 from helion._testing import skipIfRocm
@@ -241,6 +242,93 @@ def test_block_size_access(x: torch.Tensor) -> torch.Tensor:
         expected = torch.full_like(x, 1, dtype=torch.int32)
         torch.testing.assert_close(result, expected)
 
+    @skipIfRefEager(
+        "IndexOffsetOutOfRangeForInt32 error is not raised in ref eager mode"
+    )
+    @skipIfLowVRAM("Test requires high VRAM")
+    def test_int32_offset_out_of_range_error(self):
+        repro_config = helion.Config(
+            block_sizes=[32, 32],
+            flatten_loops=[False],
+            indexing="pointer",
+            l2_groupings=[1],
+            loop_orders=[[0, 1]],
+            num_stages=3,
+            num_warps=4,
+            pid_type="flat",
+            range_flattens=[None],
+            range_multi_buffers=[None],
+            range_num_stages=[],
+            range_unroll_factors=[0],
+            range_warp_specializes=[],
+        )
+
+        def make_kernel(*, index_dtype: torch.dtype):
+            kwargs = {"config": repro_config, "static_shapes": True}
+            kwargs["index_dtype"] = index_dtype
+            decorator = helion.kernel(**kwargs)
+
+            @decorator
+            def repro_bf16_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+                x, y = torch.broadcast_tensors(x, y)
+                out = torch.empty(
+                    x.shape,
+                    dtype=torch.promote_types(x.dtype, y.dtype),
+                    device=x.device,
+                )
+                for tile in hl.tile(out.size()):
+                    out[tile] = x[tile] + y[tile]
+                return out
+
+            return repro_bf16_add
+
+        def run_case(
+            shape, *, index_dtype, expect_int64_in_code=False, expect_error=False
+        ):
+            kernel = make_kernel(index_dtype=index_dtype)
+            x = torch.randn(*shape, device=DEVICE, dtype=torch.bfloat16)
+            y = torch.randn(*shape, device=DEVICE, dtype=torch.bfloat16)
+            torch.cuda.synchronize()
+            if expect_error:
+                with self.assertRaisesRegex(
+                    helion.exc.IndexOffsetOutOfRangeForInt32,
+                    f"index_dtype is {index_dtype}",
+                ):
+                    code_and_output(kernel, (x, y))
+                torch.cuda.synchronize()
+                return
+
+            code, out = code_and_output(kernel, (x, y))
+            torch.cuda.synchronize()
+            checker = self.assertIn if expect_int64_in_code else self.assertNotIn
+            checker("tl.int64", code)
+            torch.cuda.synchronize()
+            ref_out = torch.add(x, y)
+            torch.cuda.synchronize()
+            torch.testing.assert_close(out, ref_out, rtol=1e-2, atol=1e-2)
+
+        small_shape = (128, 128)
+        large_shape = (51200, 51200)
+
+        run_case(
+            small_shape,
+            index_dtype=torch.int32,
+            expect_int64_in_code=False,
+            expect_error=False,
+        )
+        run_case(
+            large_shape,
+            index_dtype=torch.int32,
+            expect_int64_in_code=False,
+            expect_error=True,
+        )
+        run_case(
+            large_shape,
+            index_dtype=torch.int64,
+            expect_int64_in_code=True,
+            expect_error=False,
+        )
+
     def test_assign_int(self):
         @helion.kernel
         def fn(x: torch.Tensor) -> torch.Tensor:
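
Usage sketch (not part of the diff): the snippet below illustrates the user-facing behavior the new check is meant to produce, mirroring the test above. A bfloat16 elementwise add over a (51200, 51200) tensor has a maximum flattened element offset of 51200 * 51200 - 1 (about 2.6e9), which exceeds the int32 maximum of 2**31 - 1, so compiling with index_dtype=torch.int32 should raise helion.exc.IndexOffsetOutOfRangeForInt32, while index_dtype=torch.int64 should compile with tl.int64 offsets. The kernel name sketch_add, the trimmed-down config, the plain "cuda" device string, and the direct call that triggers compilation are illustrative assumptions, not taken from the diff; running it needs roughly 20 GB of device memory, which is why the test carries @skipIfLowVRAM.

import torch
import helion
import helion.language as hl


def make_sketch_add(index_dtype: torch.dtype):
    # Mirrors make_kernel() in the test: same kernel body, parameterized index_dtype.
    # A pinned config avoids autotuning; the test pins a fuller repro_config.
    @helion.kernel(
        config=helion.Config(block_sizes=[32, 32], indexing="pointer"),
        static_shapes=True,
        index_dtype=index_dtype,
    )
    def sketch_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        for tile in hl.tile(out.size()):
            out[tile] = x[tile] + y[tile]
        return out

    return sketch_add


x = torch.randn(51200, 51200, device="cuda", dtype=torch.bfloat16)
y = torch.randn_like(x)

try:
    make_sketch_add(torch.int32)(x, y)  # offsets reach ~2.6e9 > 2**31 - 1
except helion.exc.IndexOffsetOutOfRangeForInt32:
    pass  # expected: int32 indexing is rejected during codegen on the first call

out = make_sketch_add(torch.int64)(x, y)  # 64-bit offsets, as the error message suggests
torch.testing.assert_close(out, x + y, rtol=1e-2, atol=1e-2)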