Closed
Description
Describe the bug
I get the following error when running with HELION_PRINT_OUTPUT_CODE=1:
Testing helion correctness...
Using default config: @helion.kernel(config=helion.Config(block_sizes=[32, 32], indexing='pointer', num_stages=3, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 0], range_warp_specializes=[]), static_shapes=False)
Helion compiler triton codegen error for @helion.kernel(config=helion.Config(block_sizes=[32, 32], indexing='pointer', num_stages=3, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 0], range_warp_specializes=[]), static_shapes=False)
Traceback (most recent call last):
  File "/data/users/williamwen/helion/helion/runtime/kernel.py", line 434, in compile_config
    triton_code = self.to_triton_code(
                  ^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/helion/helion/runtime/kernel.py", line 409, in to_triton_code
    root = generate_ast(self.host_function, config, emit_repro_caller)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/helion/helion/_compiler/generate_ast.py", line 424, in generate_ast
    call_def = [func.codegen_call_function()]
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/helion/helion/_compiler/host_function.py", line 282, in codegen_call_function
    inits.append(statement_from_string(f"{name} = {rhs}"))
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/helion/helion/_compiler/ast_extension.py", line 206, in statement_from_string
    (statement,) = ast.parse(modified_template).body
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/py312-env/lib/python3.12/ast.py", line 53, in parse
    return compile(source, filename, mode, flags,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<unknown>", line 1
    grad_out = rand_strided(size=(,), stride=(,), dtype=torch.float32, device='cuda:0')
                                  ^
SyntaxError: invalid syntax
Traceback (most recent call last):
  File "/data/users/williamwen/helion/examples/kl_div.py", line 371, in <module>
    main()
  File "/data/users/williamwen/helion/examples/kl_div.py", line 365, in main
    check_kl_div_kernel(B, T, V, reduction, log_target, eps)
  File "/data/users/williamwen/helion/examples/kl_div.py", line 316, in check_kl_div_kernel
    run_example(helion_wrapper, baseline_wrapper, create_inputs(), bwd=True)
  File "/data/users/williamwen/helion/helion/_testing.py", line 519, in run_example
    out.backward(grad_output, retain_graph=True)
  File "/data/users/williamwen/pytorch/torch/_tensor.py", line 625, in backward
    torch.autograd.backward(
  File "/data/users/williamwen/pytorch/torch/autograd/__init__.py", line 354, in backward
    _engine_run_backward(
  File "/data/users/williamwen/pytorch/torch/autograd/graph.py", line 841, in _engine_run_backward
    return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/pytorch/torch/autograd/function.py", line 315, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/helion/examples/kl_div.py", line 209, in backward
    grad_y_pred, grad_y_true = kl_div_backward(
                               ^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/helion/helion/runtime/kernel.py", line 286, in __call__
    return self.bind(args)(*args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/helion/helion/runtime/kernel.py", line 618, in __call__
    self.set_config(config)
  File "/data/users/williamwen/helion/helion/runtime/kernel.py", line 523, in set_config
    self._run = self.compile_config(config)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/helion/helion/runtime/kernel.py", line 434, in compile_config
    triton_code = self.to_triton_code(
                  ^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/helion/helion/runtime/kernel.py", line 409, in to_triton_code
    root = generate_ast(self.host_function, config, emit_repro_caller)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/helion/helion/_compiler/generate_ast.py", line 424, in generate_ast
    call_def = [func.codegen_call_function()]
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/helion/helion/_compiler/host_function.py", line 282, in codegen_call_function
    inits.append(statement_from_string(f"{name} = {rhs}"))
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/helion/helion/_compiler/ast_extension.py", line 206, in statement_from_string
    (statement,) = ast.parse(modified_template).body
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/py312-env/lib/python3.12/ast.py", line 53, in parse
    return compile(source, filename, mode, flags,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<unknown>", line 1
    grad_out = rand_strided(size=(,), stride=(,), dtype=torch.float32, device='cuda:0')
                                  ^
SyntaxError: invalid syntax
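The statement that fails to parse is the generated line shown in the traceback: both size and stride are rendered as (,), which is not a valid Python tuple literal. A plausible but unconfirmed explanation is that grad_out is a 0-dim (scalar) tensor, so both tuples are empty. The stdlib-only sketch below checks that the emitted line is rejected by ast.parse and shows one way to render a tuple literal that stays valid for zero and one elements. The helper names format_tuple_literal and parses are hypothetical; this is not Helion's code or its actual fix.

```python
import ast
from collections.abc import Sequence


def format_tuple_literal(items: Sequence[str]) -> str:
    """Render source-code fragments as a valid Python tuple literal.

    Hypothetical helper (not part of Helion). Edge cases:
      - zero elements -> "()"   (the case that produced "(,)" here, assumed)
      - one element   -> "(x,)" (trailing comma keeps it a tuple)
    """
    if not items:
        return "()"
    if len(items) == 1:
        return f"({items[0]},)"
    return f"({', '.join(items)})"


def parses(statement: str) -> bool:
    """Return True if the statement is syntactically valid Python."""
    try:
        ast.parse(statement)
        return True
    except SyntaxError:
        return False


# The statement copied from the traceback: empty size/stride became "(,)".
broken = "grad_out = rand_strided(size=(,), stride=(,), dtype=torch.float32, device='cuda:0')"

# Assumption: grad_out is 0-dim, so its size and stride are both empty.
size: list[str] = []
stride: list[str] = []
fixed = (
    f"grad_out = rand_strided(size={format_tuple_literal(size)}, "
    f"stride={format_tuple_literal(stride)}, dtype=torch.float32, device='cuda:0')"
)

print(parses(broken))  # False
print(parses(fixed))   # True
```

For concrete integer sizes, repr(tuple(sizes)) already yields (), (3,) or (3, 4); a helper like this is only needed when the elements are source-code fragments such as symbolic sizes.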
To Reproduce
Check out #802, then run:
HELION_PRINT_OUTPUT_CODE=1 HELION_USE_DEFAULT_CONFIG=1 python examples/kl_div.py
Expected behavior
No crash; the generated Triton code should be printed.
Versions
PyTorch and Helion, both built from source.
Additional context
N/A