Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions helion/_compiler/generate_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,8 @@ def visit_For(self, node: ast.For) -> ast.AST | None:
if persistent_body is not None:
self.device_function.body = persistent_body # pyright: ignore[reportAttributeAccessIssue]
self.device_function.dead_code_elimination()
if not self.device_function.preamble and not self.device_function.body:
raise exc.EmptyDeviceLoopAfterDCE
return self.device_function.codegen_function_call()
return None
return self.generic_visit(node)
Expand Down
7 changes: 7 additions & 0 deletions helion/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,3 +479,10 @@ class NestedKernelCallsNotSupported(BaseError):
"If you need to share code between kernels, consider extracting the shared logic "
"into a regular Python function that can be called from within both kernels."
)


class EmptyDeviceLoopAfterDCE(BaseError):
message = (
"Device loop is empty after dead-code elimination. "
"The kernel contains no operations that affect the output."
)
16 changes: 16 additions & 0 deletions test/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,22 @@ def kernel_with_dot_mismatch(
):
code_and_output(kernel_with_dot_mismatch, (q, k))

def test_empty_device_loop_after_dce(self):
@helion.kernel()
def empty_kernel(x: torch.Tensor) -> torch.Tensor:
# All computation is dead code
output = torch.zeros_like(x)
for _tile in hl.tile(x.size(0)):
# Do nothing that affects the output
_a = 1
return output

with self.assertRaisesRegex(
helion.exc.EmptyDeviceLoopAfterDCE,
r"Device loop is empty after dead-code elimination",
):
code_and_output(empty_kernel, (torch.randn(4, 4, device=DEVICE),))


if __name__ == "__main__":
unittest.main()
Loading