Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/api/exceptions.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ These exceptions occur when Helion language functions are used incorrectly with

Raised when tuple unpacking fails for single tile.

.. autoclass:: InvalidTileRange

Raised when ``hl.tile`` is given a range where the begin exceeds the end.

.. autoclass:: OverpackedTile

Raised when tile is wrapped in container when indexing.
Expand Down
7 changes: 7 additions & 0 deletions helion/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,13 @@ class OverpackedTile(BaseError):
)


class InvalidTileRange(BaseError):
message = (
"hl.tile() expects the begin of the range to be less than or equal to the end. "
"Got begin={0!s}, end={1!s}."
)


class AssignmentMultipleTargets(NotAllowedOnDevice):
message = "Assignment with multiple targets (a=b=1) is not allowed inside the `hl.tile` or `hl.grid` loop."

Expand Down
7 changes: 7 additions & 0 deletions helion/language/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,11 @@ def _(
)
block_size_list = Tile._tiles_to_sizes(block_size_list)

if unpack:
target = getattr(parent, "target", None)
if isinstance(target, (ast.Tuple, ast.List)) and len(target.elts) > 1:
raise exc.FailedToUnpackTile from None

results = []
for begin_part, end_part, bs in zip(
begin_list,
Expand All @@ -339,6 +344,8 @@ def _(
if isinstance(begin_part, Tile) or isinstance(end_part, Tile):
raise exc.TileOfTile
size = end_part - begin_part # type: ignore[operator]
if isinstance(size, int) and size < 0:
raise exc.InvalidTileRange(begin_part, end_part)
if isinstance(size, torch.Tensor):
size = None # data dependent size
if bs is None:
Expand Down
28 changes: 28 additions & 0 deletions test/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,34 @@ def fn(x: torch.Tensor) -> torch.Tensor:
with self.assertRaises(helion.exc.OverpackedTile):
code_and_output(fn, (torch.randn(100, 100, device=DEVICE),))

def test_tile_invalid_range_unpack(self):
@helion.kernel()
def fn(x: torch.Tensor) -> torch.Tensor:
m = x.size(0)
m = hl.specialize(m)
d = x.size(2)
for _tile_m, _tile_d in hl.tile(m, d):
pass
return x

with self.assertRaises(helion.exc.FailedToUnpackTile):
code_and_output(fn, (torch.randn(192, 4, 128, device=DEVICE),))

def test_tile_invalid_range_single_dim(self):
@helion.kernel()
def fn(x: torch.Tensor) -> torch.Tensor:
start = hl.specialize(x.size(0))
end = x.size(2)
for _tile_m in hl.tile(start, end):
pass
return x

with self.assertRaisesRegex(
helion.exc.InvalidTileRange,
r"begin=192, end=128",
):
code_and_output(fn, (torch.randn(192, 4, 128, device=DEVICE),))

def test_invalid_config_insufficient_block_sizes(self):
"""Test that InvalidConfig shows helpful message for missing block sizes."""

Expand Down
Loading