-
Notifications
You must be signed in to change notification settings - Fork 77
Closed
Labels
Description
Repro:
import torch
import helion
import helion.language as hl
@helion.kernel
def empty_kernel(x: torch.Tensor) -> torch.Tensor:
    """
    Minimal repro: a kernel whose device loop does no useful work.

    The only statement inside the ``hl.tile`` loop assigns a local constant
    that is never read, so nothing in the loop contributes to ``output``.

    This kernel has no actual computation - it will fail to compile to Triton
    because the generated Triton function will have an empty body.

    Args:
        x: Input tensor. Only ``x.size(0)`` (to drive the tile loop) and
            ``torch.zeros_like(x)`` (to allocate the result) are used.

    Returns:
        The tensor allocated with ``torch.zeros_like(x)``; the loop never
        writes to it.
    """
    # All computation is commented out or missing
    output = torch.zeros_like(x)
    for tile in hl.tile(x.size(0)):
        # Dead store: `a` is never read, so the loop body has no effect on
        # `output`.
        a = 1
    return output
def main():
    """Run the empty-kernel repro on a small CUDA tensor and print the result."""
    print("Testing empty helion kernel...")
    sample = torch.randn(4, 4, device='cuda', dtype=torch.bfloat16)
    # Calling the kernel triggers Helion compilation, which is where this
    # repro is expected to fail (the report describes an IndentationError
    # raised while compiling the kernel to Triton code).
    print("Result:", empty_kernel(sample))
if __name__ == "__main__":
    main()

Error log:
$ python minimal_repro_broken.py
Testing empty helion kernel...
[0s] Autotune random seed: 1277959479
Helion compiler triton codegen error for @helion.kernel(config=helion.Config(block_sizes=[4], indexing=[], load_eviction_policies=[], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=True)
Traceback (most recent call last):
File "/data/users/willfeng/helion/helion/runtime/kernel.py", line 449, in compile_config
triton_code = self.to_triton_code(
^^^^^^^^^^^^^^^^^^^^
File "/data/users/willfeng/helion/helion/runtime/kernel.py", line 420, in to_triton_code
root = generate_ast(self.host_function, config, emit_repro_caller)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/willfeng/helion/helion/_compiler/generate_ast.py", line 465, in generate_ast
codegen.add_statement(codegen.visit(stmt))
^^^^^^^^^^^^^^^^^^^
File "/data/users/willfeng/helion/helion/_compiler/ast_extension.py", line 277, in visit
return visitor(node)
^^^^^^^^^^^^^
File "/data/users/willfeng/helion/helion/_compiler/generate_ast.py", line 329, in visit_For
raise exc.EmptyDeviceLoopAfterDCE()
helion.exc.EmptyDeviceLoopAfterDCE: Device loop is empty after dead-code elimination. The kernel contains no operations that affect the output.
While processing:
File "/data/users/willfeng/helion/minimal_repro_broken.py", line 22, in empty_kernel
for tile in hl.tile(x.size(0)):
Traceback (most recent call last):
File "/data/users/willfeng/helion/helion/autotuner/base_search.py", line 177, in _compute_baseline
baseline_output = self.kernel.compile_config(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/willfeng/helion/helion/runtime/kernel.py", line 449, in compile_config
triton_code = self.to_triton_code(
^^^^^^^^^^^^^^^^^^^^
File "/data/users/willfeng/helion/helion/runtime/kernel.py", line 420, in to_triton_code
root = generate_ast(self.host_function, config, emit_repro_caller)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/willfeng/helion/helion/_compiler/generate_ast.py", line 465, in generate_ast
codegen.add_statement(codegen.visit(stmt))
^^^^^^^^^^^^^^^^^^^
File "/data/users/willfeng/helion/helion/_compiler/ast_extension.py", line 277, in visit
return visitor(node)
^^^^^^^^^^^^^
File "/data/users/willfeng/helion/helion/_compiler/generate_ast.py", line 329, in visit_For
raise exc.EmptyDeviceLoopAfterDCE()
helion.exc.EmptyDeviceLoopAfterDCE: Device loop is empty after dead-code elimination. The kernel contains no operations that affect the output.
While processing:
File "/data/users/willfeng/helion/minimal_repro_broken.py", line 22, in empty_kernel
for tile in hl.tile(x.size(0)):
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/data/users/willfeng/helion/minimal_repro_broken.py", line 39, in <module>
main()
File "/data/users/willfeng/helion/minimal_repro_broken.py", line 35, in main
result = empty_kernel(x)
^^^^^^^^^^^^^^^
File "/data/users/willfeng/helion/helion/runtime/kernel.py", line 293, in __call__
return self.bind(args)(*args)
^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/willfeng/helion/helion/runtime/kernel.py", line 640, in __call__
self.autotune(args, force=False)
File "/data/users/willfeng/helion/helion/runtime/kernel.py", line 520, in autotune
config = self.settings.autotuner_fn(self, args, **kwargs).autotune(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/willfeng/helion/helion/runtime/settings.py", line 226, in default_autotuner_fn
return LocalAutotuneCache(autotuner_cls(bound_kernel, args, **kwargs)) # pyright: ignore[reportArgumentType]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/willfeng/helion/helion/autotuner/pattern_search.py", line 45, in __init__
super().__init__(kernel, args)
File "/data/users/willfeng/helion/helion/autotuner/base_search.py", line 666, in __init__
super().__init__(kernel, args)
File "/data/users/willfeng/helion/helion/autotuner/base_search.py", line 125, in __init__
) = self._compute_baseline()
^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/willfeng/helion/helion/autotuner/base_search.py", line 192, in _compute_baseline
raise exc.InvalidConfig(
helion.exc.InvalidConfig: Default config failed while computing baseline.
Default config: @helion.kernel(config=helion.Config(block_sizes=[4], indexing=[], load_eviction_policies=[], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=True)
Enable HELION_AUTOTUNE_LOG_LEVEL=DEBUG to log generated Triton code.