
Unable to run the exponent example or other pointwise kernels #1039

@narendasan

Description



I am trying to write a simple pointwise kernel, but I run into errors even when running the exp kernel example.

Describe the bug
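I am trying to run the simple kernel examples and hit errors when loading and storing tiles. The output below is from the exp example (test.py line 19: out[tile] = torch.exp(x[tile])). During autotuning, Triton codegen fails with an AssertionError in compact_shape while lowering the load of x[tile]: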

[0s] Starting autotuning process, this may take a while...
[0s] Starting PatternSearch with initial_population=100, copies=5, max_generations=20
Helion compiler triton codegen error for @helion.kernel(config=helion.Config(block_sizes=[1, 128], flatten_loops=[True], indexing=['block_ptr', 'pointer'], l2_groupings=[4], load_eviction_policies=['first'], loop_orders=[[1, 0]], num_stages=8, num_warps=32, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=True)
Traceback (most recent call last):
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/inductor_lowering.py", line 1337, in run_node
    result = lowering.codegen(self, n)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/inductor_lowering.py", line 755, in codegen
    return self.api_func._codegen(
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/language/memory_ops.py", line 277, in _
    return strategy.codegen_load(
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/indexing_strategy.py", line 203, in codegen_load
    return PointerIndexingStrategy().codegen_load(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/indexing_strategy.py", line 153, in codegen_load
    indexing = SubscriptIndexing.create(state, fake_tensor, subscript, extra_mask)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/indexing_strategy.py", line 703, in create
    expand = tile_strategy.expand_str(output_size, output_idx)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/tile_dispatch.py", line 163, in expand_str
    compacted_shapes = self._compact_shape(shape)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/tile_dispatch.py", line 133, in _compact_shape
    compacted_shapes = strategy.compact_shape(compacted_shapes)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/tile_strategy.py", line 525, in compact_shape
    assert shape.block_ids[0] == self.block_ids[0]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/runtime/kernel.py", line 448, in compile_config
    triton_code = self.to_triton_code(
                  ^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/runtime/kernel.py", line 419, in to_triton_code
    root = generate_ast(self.host_function, config, emit_repro_caller)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/generate_ast.py", line 463, in generate_ast
    codegen.add_statement(codegen.visit(stmt))
                          ^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/ast_extension.py", line 277, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/generate_ast.py", line 290, in visit_For
    codegen_call_with_graph(
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/inductor_lowering.py", line 1410, in codegen_call_with_graph
    return GraphInterpreter(graph, cg).run(*new_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/torch/fx/interpreter.py", line 174, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/inductor_lowering.py", line 1380, in run_node
    raise InductorLoweringError(
helion.exc.InductorLoweringError: Error in codegen for node load (<function load at 0x7d93a127f740>):
While processing:
  File "/home/naren/Downloads/pointwise_silu_repro (2)/python/test.py", line 19, in helion_silu
    out[tile] = torch.exp(x[tile])
                          ^^^^^^^


While executing %load : [num_users=1] = call_function[target=helion.language.memory_ops.load](args = (%x, [%block_size_0, %block_size_1], None, None), kwargs = {})
Original traceback:
  File "/home/naren/Downloads/pointwise_silu_repro (2)/python/test.py", line 19, in helion_silu
    out[tile] = torch.exp(x[tile])

Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)
Traceback (most recent call last):
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/inductor_lowering.py", line 1337, in run_node
    result = lowering.codegen(self, n)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/inductor_lowering.py", line 755, in codegen
    return self.api_func._codegen(
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/language/memory_ops.py", line 277, in _
    return strategy.codegen_load(
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/indexing_strategy.py", line 203, in codegen_load
    return PointerIndexingStrategy().codegen_load(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/indexing_strategy.py", line 153, in codegen_load
    indexing = SubscriptIndexing.create(state, fake_tensor, subscript, extra_mask)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/indexing_strategy.py", line 703, in create
    expand = tile_strategy.expand_str(output_size, output_idx)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/tile_dispatch.py", line 163, in expand_str
    compacted_shapes = self._compact_shape(shape)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/tile_dispatch.py", line 133, in _compact_shape
    compacted_shapes = strategy.compact_shape(compacted_shapes)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/tile_strategy.py", line 525, in compact_shape
    assert shape.block_ids[0] == self.block_ids[0]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/naren/Downloads/pointwise_silu_repro (2)/python/test.py", line 25, in <module>
    helion_silu(*example_inputs)
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/runtime/kernel.py", line 292, in __call__
    return self.bind(args)(*args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/runtime/kernel.py", line 636, in __call__
    self.autotune(args, force=False)
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/runtime/kernel.py", line 519, in autotune
    config = self.settings.autotuner_fn(self, args, **kwargs).autotune(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/autotuner/base_cache.py", line 234, in autotune
    config = self.autotuner.autotune()
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/autotuner/base_search.py", line 536, in autotune
    best = self._autotune()
           ^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/autotuner/pattern_search.py", line 64, in _autotune
    self.parallel_benchmark_population(self.population, desc="Initial population")
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/autotuner/base_search.py", line 717, in parallel_benchmark_population
    self.parallel_benchmark([m.config for m in members], desc=desc),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/autotuner/base_search.py", line 459, in parallel_benchmark
    fns = [self.kernel.compile_config(c, allow_print=False) for c in configs]
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/autotuner/base_search.py", line 459, in <listcomp>
    fns = [self.kernel.compile_config(c, allow_print=False) for c in configs]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/runtime/kernel.py", line 448, in compile_config
    triton_code = self.to_triton_code(
                  ^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/runtime/kernel.py", line 419, in to_triton_code
    root = generate_ast(self.host_function, config, emit_repro_caller)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/generate_ast.py", line 463, in generate_ast
    codegen.add_statement(codegen.visit(stmt))
                          ^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/ast_extension.py", line 277, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/generate_ast.py", line 290, in visit_For
    codegen_call_with_graph(
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/inductor_lowering.py", line 1410, in codegen_call_with_graph
    return GraphInterpreter(graph, cg).run(*new_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/torch/fx/interpreter.py", line 174, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/home/naren/Downloads/pointwise_silu_repro (2)/.venv/lib/python3.11/site-packages/helion/_compiler/inductor_lowering.py", line 1380, in run_node
    raise InductorLoweringError(
helion.exc.InductorLoweringError: Error in codegen for node load (<function load at 0x7d93a127f740>):
While processing:
  File "/home/naren/Downloads/pointwise_silu_repro (2)/python/test.py", line 19, in helion_silu
    out[tile] = torch.exp(x[tile])
                          ^^^^^^^


While executing %load : [num_users=1] = call_function[target=helion.language.memory_ops.load](args = (%x, [%block_size_0, %block_size_1], None, None), kwargs = {})
Original traceback:
  File "/home/naren/Downloads/pointwise_silu_repro (2)/python/test.py", line 19, in helion_silu
    out[tile] = torch.exp(x[tile])

Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)

I can work around these errors by using explicit hl.load and hl.store calls, but then I hit Triton codegen errors instead.

To Reproduce

Run the exp example at https://helionlang.com/examples/exp.html

or, alternatively, this minimal SiLU kernel:

import torch
import helion
import helion.language as hl


@helion.kernel(static_shapes=False)
def helion_silu(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x, dtype=x.dtype, device=x.device)
    for tile in hl.tile(out.size()):
        out[tile] = x[tile] * torch.sigmoid(x[tile])
    return out


# Perform autotuning for the kernel (triggered by the first call)
example_inputs = [torch.randn((1, 100), dtype=torch.float16, device="cuda")]
helion_silu(*example_inputs)

Expected behavior

The example runs to completion: autotuning finishes and the kernel returns a result.
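Once the kernels run, their output can be checked against eager PyTorch. A minimal sketch, assuming the dtype-default tolerances of torch.testing.assert_close are acceptable; eager_silu, eager_exp and check_against_eager are hypothetical helpers for illustration, not part of Helion:

```python
import torch


def eager_silu(x: torch.Tensor) -> torch.Tensor:
    # Same math as the SiLU kernel body above: x * sigmoid(x).
    return x * torch.sigmoid(x)


def eager_exp(x: torch.Tensor) -> torch.Tensor:
    # Same math as the exp example's kernel body.
    return torch.exp(x)


def check_against_eager(kernel_out: torch.Tensor, ref_fn, x: torch.Tensor) -> None:
    # Hypothetical helper: raises AssertionError if the kernel output
    # differs from the eager reference beyond dtype-default tolerances.
    torch.testing.assert_close(kernel_out, ref_fn(x))


# Usage once the kernel runs (requires a CUDA device):
#   x = torch.randn((1, 100), dtype=torch.float16, device="cuda")
#   check_against_eager(helion_silu(x), eager_silu, x)
```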

Versions

dllist==2.0.0
filecheck==1.0.3
filelock==3.20.0
fsspec==2025.9.0
helion==0.2.1
iniconfig==2.3.0
jinja2==3.1.6
markdown-it-py==4.0.0
markupsafe==3.0.2
mdurl==0.1.2
mpmath==1.3.0
networkx==3.5
numpy==2.3.4
nvidia-cublas==13.0.0.19
nvidia-cuda-cupti==13.0.48
nvidia-cuda-nvrtc==13.0.48
nvidia-cuda-runtime==13.0.48
nvidia-cuda-runtime-cu13==0.0.0a0
nvidia-cudnn-cu13==9.13.0.50
nvidia-cufft==12.0.0.15
nvidia-cufile==1.15.0.42
nvidia-curand==10.4.0.35
nvidia-cusolver==12.0.3.29
nvidia-cusparse==12.6.2.49
nvidia-cusparselt-cu13==0.8.0
nvidia-nccl-cu13==2.27.7
nvidia-nvjitlink==13.0.39
nvidia-nvshmem-cu13==3.3.24
nvidia-nvtx==13.0.39
packaging==25.0
pluggy==1.6.0
-e file:///home/naren/Downloads/pointwise_silu_repro%20(2)
psutil==6.0.0
pygments==2.19.2
pytest==8.4.2
pytorch-triton==3.5.0+git7416ffcb
rich==14.1.0
sympy==1.14.0
tensorrt==10.13.3.9
tensorrt-cu13==10.13.3.9
tensorrt-cu13-bindings==10.13.3.9
tensorrt-cu13-libs==10.13.3.9
torch==2.10.0.dev20251027+cu130
torch-tensorrt==2.10.0.dev20251026+cu130
tqdm==4.67.1
typing-extensions==4.15.0

Additional context
Could this be caused by the nightly PyTorch stack I am using (torch 2.10.0.dev20251027+cu130, pytorch-triton 3.5.0+git7416ffcb)?
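To check which stack a given environment is on, here is a minimal sketch using only the standard library's importlib.metadata. The package list and the is_prerelease heuristic (treating version strings containing ".dev" or "+git" as nightly/dev builds) are assumptions for illustration, not part of Helion or PyTorch:

```python
from importlib import metadata

# Distributions relevant to this report. Nightly PyTorch wheels depend on
# "pytorch-triton"; stable releases depend on "triton" instead.
PACKAGES = ("torch", "pytorch-triton", "triton", "helion")


def collect_versions(packages=PACKAGES):
    """Return {distribution: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def is_prerelease(version):
    """Heuristic: PEP 440 dev releases contain '.dev'; git builds carry '+git'."""
    return version is not None and (".dev" in version or "+git" in version)


if __name__ == "__main__":
    for name, version in collect_versions().items():
        flag = " (nightly/dev build)" if is_prerelease(version) else ""
        print(f"{name}: {version or 'not installed'}{flag}")
```

With the versions listed above, torch and pytorch-triton would both be flagged as dev builds, while helion 0.2.1 would not.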
