HELION_PRINT_OUTPUT_CODE=1 fails with SyntaxError #803

@williamwen42

Describe the bug

Running with HELION_PRINT_OUTPUT_CODE=1 fails with the following error:

Testing helion correctness...
Using default config: @helion.kernel(config=helion.Config(block_sizes=[32, 32], indexing='pointer', num_stages=3, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 0], range_warp_specializes=[]), static_shapes=False)
Helion compiler triton codegen error for @helion.kernel(config=helion.Config(block_sizes=[32, 32], indexing='pointer', num_stages=3, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 0], range_warp_specializes=[]), static_shapes=False)
Traceback (most recent call last):
  File "/data/users/williamwen/helion/helion/runtime/kernel.py", line 434, in compile_config
    triton_code = self.to_triton_code(
                  ^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/helion/helion/runtime/kernel.py", line 409, in to_triton_code
    root = generate_ast(self.host_function, config, emit_repro_caller)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/helion/helion/_compiler/generate_ast.py", line 424, in generate_ast
    call_def = [func.codegen_call_function()]
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/helion/helion/_compiler/host_function.py", line 282, in codegen_call_function
    inits.append(statement_from_string(f"{name} = {rhs}"))
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/helion/helion/_compiler/ast_extension.py", line 206, in statement_from_string
    (statement,) = ast.parse(modified_template).body
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/py312-env/lib/python3.12/ast.py", line 53, in parse
    return compile(source, filename, mode, flags,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<unknown>", line 1
    grad_out = rand_strided(size=(,), stride=(,), dtype=torch.float32, device='cuda:0')
                                  ^
SyntaxError: invalid syntax
Traceback (most recent call last):
  File "/data/users/williamwen/helion/examples/kl_div.py", line 371, in <module>
    main()
  File "/data/users/williamwen/helion/examples/kl_div.py", line 365, in main
    check_kl_div_kernel(B, T, V, reduction, log_target, eps)
  File "/data/users/williamwen/helion/examples/kl_div.py", line 316, in check_kl_div_kernel
    run_example(helion_wrapper, baseline_wrapper, create_inputs(), bwd=True)
  File "/data/users/williamwen/helion/helion/_testing.py", line 519, in run_example
    out.backward(grad_output, retain_graph=True)
  File "/data/users/williamwen/pytorch/torch/_tensor.py", line 625, in backward
    torch.autograd.backward(
  File "/data/users/williamwen/pytorch/torch/autograd/__init__.py", line 354, in backward
    _engine_run_backward(
  File "/data/users/williamwen/pytorch/torch/autograd/graph.py", line 841, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/pytorch/torch/autograd/function.py", line 315, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/helion/examples/kl_div.py", line 209, in backward
    grad_y_pred, grad_y_true = kl_div_backward(
                               ^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/helion/helion/runtime/kernel.py", line 286, in __call__
    return self.bind(args)(*args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/helion/helion/runtime/kernel.py", line 618, in __call__
    self.set_config(config)
  File "/data/users/williamwen/helion/helion/runtime/kernel.py", line 523, in set_config
    self._run = self.compile_config(config)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/helion/helion/runtime/kernel.py", line 434, in compile_config
    triton_code = self.to_triton_code(
                  ^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/helion/helion/runtime/kernel.py", line 409, in to_triton_code
    root = generate_ast(self.host_function, config, emit_repro_caller)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/helion/helion/_compiler/generate_ast.py", line 424, in generate_ast
    call_def = [func.codegen_call_function()]
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/helion/helion/_compiler/host_function.py", line 282, in codegen_call_function
    inits.append(statement_from_string(f"{name} = {rhs}"))
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/helion/helion/_compiler/ast_extension.py", line 206, in statement_from_string
    (statement,) = ast.parse(modified_template).body
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/williamwen/py312-env/lib/python3.12/ast.py", line 53, in parse
    return compile(source, filename, mode, flags,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<unknown>", line 1
    grad_out = rand_strided(size=(,), stride=(,), dtype=torch.float32, device='cuda:0')
                                  ^
SyntaxError: invalid syntax
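Both traces fail on the same generated line: `grad_out = rand_strided(size=(,), stride=(,), ...)`. `(,)` is not valid Python, because an empty tuple is written `()`. A plausible but unconfirmed reading is that `grad_out` is a 0-dim tensor, such as the scalar gradient of a reduced loss, and that the repro-caller codegen renders an empty shape as `(,)`. The check below uses only the stdlib `ast` module to show the difference. `parses` is a hypothetical helper, and `FIXED` is a hand-written corrected line, not actual Helion output.

```python
import ast


def parses(source: str) -> bool:
    """Return True if `source` is syntactically valid Python (hypothetical helper)."""
    try:
        ast.parse(source)
    except SyntaxError:
        return False
    return True


# Statement copied verbatim from the traceback: the empty shape is rendered as "(,)".
BROKEN = "grad_out = rand_strided(size=(,), stride=(,), dtype=torch.float32, device='cuda:0')"
# Same statement with the empty tuple written as "()" (hand-corrected, for comparison).
FIXED = "grad_out = rand_strided(size=(), stride=(), dtype=torch.float32, device='cuda:0')"

print(parses(BROKEN))  # False
print(parses(FIXED))   # True
```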

To Reproduce
Check out #802, then run HELION_PRINT_OUTPUT_CODE=1 HELION_USE_DEFAULT_CONFIG=1 python examples/kl_div.py

Expected behavior
No crash; the generated Triton code should be printed.
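Below is a minimal sketch of a formatter that produces a valid tuple literal for any rank. It assumes the size and stride entries are available as source-code strings before they are spliced into the `rand_strided(...)` call. `format_tuple_literal` and `rand_strided_init` are hypothetical names, not Helion's API, and the fix that actually lands upstream may look different. Each generated line is passed through `ast.parse` as a self-check, and `dtype`/`device` are hard-coded to the values in the traceback.

```python
import ast
from typing import Sequence


def format_tuple_literal(items: Sequence[str]) -> str:
    """Render source-code strings as a Python tuple literal that is valid for any length."""
    if len(items) == 0:
        return "()"              # empty tuple: a lone comma is a syntax error
    if len(items) == 1:
        return f"({items[0]},)"  # one element: trailing comma required to make a tuple
    return "(" + ", ".join(items) + ")"


def rand_strided_init(name: str, sizes: Sequence[str], strides: Sequence[str]) -> str:
    """Build a repro-caller line shaped like the one in the traceback (hypothetical helper)."""
    return (
        f"{name} = rand_strided(size={format_tuple_literal(sizes)}, "
        f"stride={format_tuple_literal(strides)}, "
        "dtype=torch.float32, device='cuda:0')"
    )


# A 0-dim tensor (the suspected grad_out case), a 1-D tensor, and a 2-D tensor.
for sizes, strides in [([], []), (["128"], ["1"]), (["4", "128"], ["128", "1"])]:
    stmt = rand_strided_init("grad_out", sizes, strides)
    ast.parse(stmt)  # raises SyntaxError if the generated line is invalid
    print(stmt)
```

Special-casing the one-element form matters as much as the empty form. `(x)` without a trailing comma parses, but it is a parenthesized expression rather than a tuple, so that bug would not surface as a SyntaxError.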

Versions
PyTorch and Helion, both built from source (Python 3.12).

Additional context
N/A
