diff --git a/docs/api/exceptions.md b/docs/api/exceptions.md index c7d18b285..6533c1065 100644 --- a/docs/api/exceptions.md +++ b/docs/api/exceptions.md @@ -115,6 +115,10 @@ These exceptions occur when Helion language functions are used incorrectly with Raised when tuple unpacking fails for single tile. +.. autoclass:: InvalidTileRange + + Raised when ``hl.tile`` is given a range where the begin exceeds the end. + .. autoclass:: OverpackedTile Raised when tile is wrapped in container when indexing. diff --git a/helion/exc.py b/helion/exc.py index 878bbd562..510ec5014 100644 --- a/helion/exc.py +++ b/helion/exc.py @@ -223,6 +223,13 @@ class OverpackedTile(BaseError): ) +class InvalidTileRange(BaseError): + message = ( + "hl.tile() expects the begin of the range to be less than or equal to the end. " + "Got begin={0!s}, end={1!s}." + ) + + class AssignmentMultipleTargets(NotAllowedOnDevice): message = "Assignment with multiple targets (a=b=1) is not allowed inside the `hl.tile` or `hl.grid` loop." diff --git a/helion/language/loops.py b/helion/language/loops.py index f9f1a540c..304c6cfb7 100644 --- a/helion/language/loops.py +++ b/helion/language/loops.py @@ -329,6 +329,11 @@ def _( ) block_size_list = Tile._tiles_to_sizes(block_size_list) + if unpack: + target = getattr(parent, "target", None) + if isinstance(target, (ast.Tuple, ast.List)) and len(target.elts) > 1: + raise exc.FailedToUnpackTile from None + results = [] for begin_part, end_part, bs in zip( begin_list, @@ -339,6 +344,8 @@ def _( if isinstance(begin_part, Tile) or isinstance(end_part, Tile): raise exc.TileOfTile size = end_part - begin_part # type: ignore[operator] + if isinstance(size, int) and size < 0: + raise exc.InvalidTileRange(begin_part, end_part) if isinstance(size, torch.Tensor): size = None # data dependent size if bs is None: diff --git a/test/test_errors.py b/test/test_errors.py index 7d648a8d3..d7d4fdebe 100644 --- a/test/test_errors.py +++ b/test/test_errors.py @@ -135,6 +135,34 @@ def fn(x: torch.Tensor) -> torch.Tensor: with self.assertRaises(helion.exc.OverpackedTile): code_and_output(fn, (torch.randn(100, 100, device=DEVICE),)) + def test_tile_invalid_range_unpack(self): + @helion.kernel() + def fn(x: torch.Tensor) -> torch.Tensor: + m = x.size(0) + m = hl.specialize(m) + d = x.size(2) + for _tile_m, _tile_d in hl.tile(m, d): + pass + return x + + with self.assertRaises(helion.exc.FailedToUnpackTile): + code_and_output(fn, (torch.randn(192, 4, 128, device=DEVICE),)) + + def test_tile_invalid_range_single_dim(self): + @helion.kernel() + def fn(x: torch.Tensor) -> torch.Tensor: + start = hl.specialize(x.size(0)) + end = x.size(2) + for _tile_m in hl.tile(start, end): + pass + return x + + with self.assertRaisesRegex( + helion.exc.InvalidTileRange, + r"begin=192, end=128", + ): + code_and_output(fn, (torch.randn(192, 4, 128, device=DEVICE),)) + def test_invalid_config_insufficient_block_sizes(self): """Test that InvalidConfig shows helpful message for missing block sizes."""