Description
I am trying to write a simple pointwise kernel, but I am running into issues even when running the exp kernel example.
Describe the bug
I am trying to run simple kernel examples and am hitting errors when loading and storing tiles. Autotuning fails with the following output:
[0s] Starting autotuning process, this may take a while...
[0s] Starting PatternSearch with initial_population=100, copies=5, max_generations=20
Helion compiler triton codegen error for @helion.kernel(config=helion.Config(block_sizes=[1, 128], flatten_loops=[True], indexing=['block_ptr', 'pointer'], l2_groupings=[4], load_eviction_policies=['first'], loop_orders=[[1, 0]], num_stages=8, num_warps=32, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=True)
Traceback (most recent call last):
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/inductor_lowering.py", line 1337, in run_node
    result = lowering.codegen(self, n)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/inductor_lowering.py", line 755, in codegen
    return self.api_func._codegen(
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/language/memory_ops.py", line 277, in _
    return strategy.codegen_load(
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/indexing_strategy.py", line 203, in codegen_load
    return PointerIndexingStrategy().codegen_load(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/indexing_strategy.py", line 153, in codegen_load
    indexing = SubscriptIndexing.create(state, fake_tensor, subscript, extra_mask)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/indexing_strategy.py", line 703, in create
    expand = tile_strategy.expand_str(output_size, output_idx)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/tile_dispatch.py", line 163, in expand_str
    compacted_shapes = self._compact_shape(shape)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/tile_dispatch.py", line 133, in _compact_shape
    compacted_shapes = strategy.compact_shape(compacted_shapes)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/tile_strategy.py", line 525, in compact_shape
    assert shape.block_ids[0] == self.block_ids[0]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/runtime/kernel.py", line 448, in compile_config
    triton_code = self.to_triton_code(
                  ^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/runtime/kernel.py", line 419, in to_triton_code
    root = generate_ast(self.host_function, config, emit_repro_caller)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/generate_ast.py", line 463, in generate_ast
    codegen.add_statement(codegen.visit(stmt))
                          ^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/ast_extension.py", line 277, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/generate_ast.py", line 290, in visit_For
    codegen_call_with_graph(
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/inductor_lowering.py", line 1410, in codegen_call_with_graph
    return GraphInterpreter(graph, cg).run(*new_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/torch/fx/interpreter.py", line 174, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/inductor_lowering.py", line 1380, in run_node
    raise InductorLoweringError(
helion.exc.InductorLoweringError: Error in codegen for node load (<function load at 0x7d93a127f740>):
While processing:
  File "/home/naren/Downloads/pointwise_silu_repro (2)/python/test.py", line 19, in helion_silu
    out[tile] = torch.exp(x[tile])
                          ^^^^^^^
While executing %load : [num_users=1] = call_function[target=helion.language.memory_ops.load](args = (%x, [%block_size_0, %block_size_1], None, None), kwargs = {})
Original traceback:
  File "/home/naren/Downloads/pointwise_silu_repro (2)/python/test.py", line 19, in helion_silu
    out[tile] = torch.exp(x[tile])
Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)
Traceback (most recent call last):
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/inductor_lowering.py", line 1337, in run_node
    result = lowering.codegen(self, n)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/inductor_lowering.py", line 755, in codegen
    return self.api_func._codegen(
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/language/memory_ops.py", line 277, in _
    return strategy.codegen_load(
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/indexing_strategy.py", line 203, in codegen_load
    return PointerIndexingStrategy().codegen_load(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/indexing_strategy.py", line 153, in codegen_load
    indexing = SubscriptIndexing.create(state, fake_tensor, subscript, extra_mask)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/indexing_strategy.py", line 703, in create
    expand = tile_strategy.expand_str(output_size, output_idx)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/tile_dispatch.py", line 163, in expand_str
    compacted_shapes = self._compact_shape(shape)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/tile_dispatch.py", line 133, in _compact_shape
    compacted_shapes = strategy.compact_shape(compacted_shapes)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/tile_strategy.py", line 525, in compact_shape
    assert shape.block_ids[0] == self.block_ids[0]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/home/naren/Downloads/pointwise_silu_repro (2)/python/test.py", line 25, in <module>
    helion_silu(*example_inputs)
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/runtime/kernel.py", line 292, in __call__
    return self.bind(args)(*args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/runtime/kernel.py", line 636, in __call__
    self.autotune(args, force=False)
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/runtime/kernel.py", line 519, in autotune
    config = self.settings.autotuner_fn(self, args, **kwargs).autotune(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/autotuner/base_cache.py", line 234, in autotune
    config = self.autotuner.autotune()
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/autotuner/base_search.py", line 536, in autotune
    best = self._autotune()
           ^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/autotuner/pattern_search.py", line 64, in _autotune
    self.parallel_benchmark_population(self.population, desc="Initial population")
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/autotuner/base_search.py", line 717, in parallel_benchmark_population
    self.parallel_benchmark([m.config for m in members], desc=desc),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/autotuner/base_search.py", line 459, in parallel_benchmark
    fns = [self.kernel.compile_config(c, allow_print=False) for c in configs]
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/autotuner/base_search.py", line 459, in <listcomp>
    fns = [self.kernel.compile_config(c, allow_print=False) for c in configs]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/runtime/kernel.py", line 448, in compile_config
    triton_code = self.to_triton_code(
                  ^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/runtime/kernel.py", line 419, in to_triton_code
    root = generate_ast(self.host_function, config, emit_repro_caller)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/generate_ast.py", line 463, in generate_ast
    codegen.add_statement(codegen.visit(stmt))
                          ^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/ast_extension.py", line 277, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/generate_ast.py", line 290, in visit_For
    codegen_call_with_graph(
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/inductor_lowering.py", line 1410, in codegen_call_with_graph
    return GraphInterpreter(graph, cg).run(*new_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/torch/fx/interpreter.py", line 174, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/inductor_lowering.py", line 1380, in run_node
    raise InductorLoweringError(
helion.exc.InductorLoweringError: Error in codegen for node load (<function load at 0x7d93a127f740>):
While processing:
  File "/home/naren/Downloads/pointwise_silu_repro (2)/python/test.py", line 19, in helion_silu
    out[tile] = torch.exp(x[tile])
                          ^^^^^^^
While executing %load : [num_users=1] = call_function[target=helion.language.memory_ops.load](args = (%x, [%block_size_0, %block_size_1], None, None), kwargs = {})
Original traceback:
  File "/home/naren/Downloads/pointwise_silu_repro (2)/python/test.py", line 19, in helion_silu
    out[tile] = torch.exp(x[tile])
Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)
I can work around these issues with explicit hl.load and hl.store calls, but then I hit Triton codegen errors.
To Reproduce
Steps to reproduce the behavior.
Run the exp example script: https://helionlang.com/examples/exp.html
or, alternatively, this script:
import torch
import helion
import helion.language as hl
@helion.kernel(static_shapes=False)
def helion_silu(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x, dtype=x.dtype, device=x.device)
    for tile in hl.tile(out.size()):
        out[tile] = x[tile] * torch.sigmoid(x[tile])
    return out
# Perform autotuning for the kernel
example_inputs = [torch.randn((1, 100), dtype=torch.float16, device="cuda")]
helion_silu(*example_inputs)
Expected behavior
I expect this example to run without errors.
Versions
PyTorch/Triton/Helion versions and any other relevant library version.
dllist==2.0.0
filecheck==1.0.3
filelock==3.20.0
fsspec==2025.9.0
helion==0.2.1
iniconfig==2.3.0
jinja2==3.1.6
markdown-it-py==4.0.0
markupsafe==3.0.2
mdurl==0.1.2
mpmath==1.3.0
networkx==3.5
numpy==2.3.4
nvidia-cublas==13.0.0.19
nvidia-cuda-cupti==13.0.48
nvidia-cuda-nvrtc==13.0.48
nvidia-cuda-runtime==13.0.48
nvidia-cuda-runtime-cu13==0.0.0a0
nvidia-cudnn-cu13==9.13.0.50
nvidia-cufft==12.0.0.15
nvidia-cufile==1.15.0.42
nvidia-curand==10.4.0.35
nvidia-cusolver==12.0.3.29
nvidia-cusparse==12.6.2.49
nvidia-cusparselt-cu13==0.8.0
nvidia-nccl-cu13==2.27.7
nvidia-nvjitlink==13.0.39
nvidia-nvshmem-cu13==3.3.24
nvidia-nvtx==13.0.39
packaging==25.0
pluggy==1.6.0
-e file:///home/naren/Downloads/pointwise_silu_repro%20(2)
psutil==6.0.0
pygments==2.19.2
pytest==8.4.2
pytorch-triton==3.5.0+git7416ffcb
rich==14.1.0
sympy==1.14.0
tensorrt==10.13.3.9
tensorrt-cu13==10.13.3.9
tensorrt-cu13-bindings==10.13.3.9
tensorrt-cu13-libs==10.13.3.9
torch==2.10.0.dev20251027+cu130
torch-tensorrt==2.10.0.dev20251026+cu130
tqdm==4.67.1
typing-extensions==4.15.0
Additional context
Could this be caused by my using the nightly PyTorch stack (torch==2.10.0.dev20251027+cu130, pytorch-triton==3.5.0+git7416ffcb)?