5 changes: 5 additions & 0 deletions README.md
@@ -301,6 +301,11 @@ To view the generated Triton code, set the environment variable `HELION_PRINT_OU
helpful for debugging and understanding Helion's compilation process. One can also use
`foo_kernel.bind(args).to_triton_code(config)` to get the Triton code as a string.

To emit a repro script that includes the Helion kernel definition, the config decorator, and a
`helion_repro_caller()` helper that recreates the runtime inputs before invoking the Helion kernel, set
`HELION_PRINT_REPRO=1` or include `print_repro=True` in the `@helion.kernel` decorator. This prints
the repro script to `stderr`, which is helpful for debugging and for sharing a minimal repro on the GitHub issue tracker.
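
Because the repro goes to `stderr` alongside any other log output, it can be handy to pull it back out programmatically. The sketch below is one possible way to do that, relying on the begin/end marker comments that the repro printer emits. `extract_repro` and the sample log text are hypothetical names used for illustration; they are not part of Helion.

```python
from __future__ import annotations

# Marker comments that delimit the repro script in the printed output.
BEGIN_MARKER = "# === HELION KERNEL REPRO ==="
END_MARKER = "# === END HELION KERNEL REPRO ==="


def extract_repro(stderr_text: str) -> str | None:
    """Return the first repro script found in captured stderr, or None."""
    lines = stderr_text.splitlines()
    try:
        start = next(i for i, line in enumerate(lines) if BEGIN_MARKER in line)
        end = next(
            i for i, line in enumerate(lines[start:], start) if END_MARKER in line
        )
    except StopIteration:
        # One of the markers is missing, so there is nothing to extract.
        return None
    return "\n".join(lines[start : end + 1])


# Hypothetical stderr capture with unrelated log lines around the repro.
sample = "\n".join(
    [
        "some unrelated warning",
        BEGIN_MARKER,
        "import helion",
        END_MARKER,
        "more log output",
    ]
)
print(extract_repro(sample))
```

The extracted text keeps both marker lines; since they are ordinary comments, the result can be saved to a `.py` file as-is.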

Within an `hl.tile`/`hl.grid` device loop, if you want to print intermediate results using `print("x", ...)` syntax,
or pause execution using Python's built-in `breakpoint()`, set either `TRITON_INTERPRET=1` (runs Triton's CPU interpreter)
or `HELION_INTERPRET=1` (runs the Helion kernel in eager mode).
2 changes: 1 addition & 1 deletion docs/api/config.md
@@ -27,7 +27,7 @@ The `Config` class represents kernel optimization parameters that control how He
|--------|--------|----------|
| **Purpose** | Control execution performance | Control compilation behavior |
| **Autotuning** | ✅ Automatically optimized | ❌ Never autotuned |
- | **Examples** | `block_sizes`, `num_warps`, `indexing` | `print_output_code`, `autotune_effort` |
+ | **Examples** | `block_sizes`, `num_warps`, `indexing` | `print_output_code`, `print_repro`, `autotune_effort` |
| **When to use** | Performance optimization | Development, debugging, environment setup |


1 change: 1 addition & 0 deletions docs/api/kernel.md
@@ -161,6 +161,7 @@ Settings control **how the kernel is compiled** and the development environment:
    autotune_effort="none",  # Skip autotuning for development
    autotune_effort="quick",  # Smaller autotuning budget when search is enabled
    print_output_code=True,  # Debug: show generated Triton code
    print_repro=True,  # Debug: show Helion kernel code, config, and caller code as a standalone repro script
    static_shapes=True,  # Compilation optimization strategy
    autotune_log_level=logging.DEBUG  # Verbose autotuning output
)
8 changes: 7 additions & 1 deletion docs/api/settings.md
@@ -61,7 +61,8 @@ import helion.language as hl

@helion.kernel(
    autotune_effort="none",  # Skip autotuning
-   print_output_code=True,  # Debug output
+   print_output_code=True,  # Debug: show generated Triton code
+   print_repro=True,  # Debug: show Helion kernel code, config, and caller code as a standalone repro script
)
def my_kernel(x: torch.Tensor) -> torch.Tensor:
    result = torch.zeros_like(x)
@@ -190,6 +191,10 @@ See :class:`helion.autotuner.LocalAutotuneCache` for details on cache keys and b

Print generated Triton code to stderr. Default is ``False``. Controlled by ``HELION_PRINT_OUTPUT_CODE=1``.

.. autoattribute:: Settings.print_repro

Print Helion kernel code, config, and caller code to stderr as a standalone repro script. Default is ``False``. Controlled by ``HELION_PRINT_REPRO=1``.

.. autoattribute:: Settings.output_origin_lines

Annotate generated Triton code with ``# src[<file>:<line>]`` comments indicating the originating Helion statements.
@@ -259,6 +264,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"PatternSearch"``, ``"Differe
| ``HELION_SKIP_CACHE`` | ``LocalAutotuneCache`` | When set to ``1``, ignore cached autotuning entries and rerun searches. |
| ``HELION_ASSERT_CACHE_HIT`` | ``AutotuneCacheBase`` | When set to ``1``, require a cache hit; raises ``CacheAssertionError`` on cache miss with detailed diagnostics. |
| ``HELION_PRINT_OUTPUT_CODE`` | ``print_output_code`` | Print generated Triton code to stderr for inspection. |
| ``HELION_PRINT_REPRO`` | ``print_repro`` | Print Helion kernel code, config, and caller code to stderr as a standalone repro script. |
| ``HELION_OUTPUT_ORIGIN_LINES`` | ``output_origin_lines`` | Include ``# src[...]`` comments in generated Triton code; set to ``0`` to disable. |
| ``HELION_IGNORE_WARNINGS`` | ``ignore_warnings`` | Comma-separated warning names defined in ``helion.exc`` to suppress. |
| ``HELION_ALLOW_WARP_SPECIALIZE`` | ``allow_warp_specialize`` | Permit warp-specialized code generation for ``tl.range``. |
5 changes: 5 additions & 0 deletions docs/index.md
@@ -241,6 +241,11 @@ To view the generated Triton code, set the environment variable `HELION_PRINT_OU
helpful for debugging and understanding Helion's compilation process. One can also use
`foo_kernel.bind(args).to_triton_code(config)` to get the Triton code as a string.

To emit a repro script that includes the Helion kernel definition, the config decorator, and a
`helion_repro_caller()` helper that recreates the runtime inputs before invoking the Helion kernel, set
`HELION_PRINT_REPRO=1` or include `print_repro=True` in the `@helion.kernel` decorator. This prints
the repro script to `stderr`, which is helpful for debugging and for sharing a minimal repro on the GitHub issue tracker.

To force autotuning, bypassing provided configurations, set `HELION_FORCE_AUTOTUNE=1` or invoke `foo_kernel.autotune(args,
force=True)`.

81 changes: 81 additions & 0 deletions helion/runtime/kernel.py
@@ -9,6 +9,7 @@
import operator
import re
import sys
import textwrap
import types
from typing import TYPE_CHECKING
from typing import Callable
@@ -641,8 +642,88 @@ def __call__(self, *args: object) -> _R:
self.format_kernel_decorator(self._config, self.settings)
] = 1

        if self.settings.print_repro:
            self._print_repro(args)

        return self._run(*args)

    def _print_repro(
Contributor

@oulgen oulgen Oct 29, 2025

I don't think you need any of this; I added a mode to `to_triton_code` that prints a repro.

There's even a unit test for it.

Contributor Author

@yf225 yf225 Oct 30, 2025

Yeah, so I actually want to get the Helion kernel (with config and caller), not the generated Triton kernel.

What internal folks usually provide is a buck command that's many layers deep, and it's very hard to extract a clean Helion kernel repro from it. With this env var set, Helion prints the kernel with the right input tensors, so that I can just focus on debugging the minimal repro.

        self, args: tuple[object, ...], config: Config | None = None
    ) -> None:
        effective_config = config or self._config
        assert effective_config is not None

        # Get kernel source
        try:
            raw_source = inspect.getsource(self.kernel.fn)
            source_lines = textwrap.dedent(raw_source).splitlines()
            # Skip decorator lines (including multi-line decorators)
            start_idx = 0
            while start_idx < len(source_lines) and not source_lines[
                start_idx
            ].lstrip().startswith("def "):
                start_idx += 1
            kernel_body = "\n".join(source_lines[start_idx:])
        except (OSError, TypeError):
            kernel_body = f"# Source unavailable for {self.kernel.fn.__module__}.{self.kernel.fn.__qualname__}"

        # Format decorator
        decorator = self.format_kernel_decorator(effective_config, self.settings)

        # Build output
        output_lines = [
            "# === HELION KERNEL REPRO ===",
            "import helion",
            "import helion.language as hl",
            "import torch",
            "from torch._dynamo.testing import rand_strided",
            "",
            decorator,
            kernel_body,
        ]

        # Generate caller function
        if args:

            def _render_input_arg_assignment(name: str, value: object) -> list[str]:
                if isinstance(value, torch.Tensor):
                    shape = tuple(int(d) for d in value.shape)
                    stride = tuple(int(s) for s in value.stride())
                    device = str(value.device)
                    dtype = str(value.dtype)

                    lines = [
                        f"{name} = rand_strided({shape!r}, {stride!r}, dtype={dtype}, device={device!r})"
                    ]

                    if value.requires_grad:
                        lines.append(f"{name}.requires_grad_(True)")
                    return lines

                return [f"{name} = {value!r}"]

            sig_param_names = list(self.kernel.signature.parameters.keys())
            assert len(args) == len(sig_param_names)

            output_lines.extend(["", "def helion_repro_caller():"])
            output_lines.append("    torch.manual_seed(0)")
            arg_names = []

            for i, value in enumerate(args):
                var_name = sig_param_names[i]
                arg_names.append(var_name)

                # Add assignment lines with indentation
                for line in _render_input_arg_assignment(var_name, value):
                    output_lines.append(f"    {line}")

            # Add return statement
            call_args = ", ".join(arg_names)
            output_lines.append(f"    return {self.kernel.name}({call_args})")

        output_lines.append("# === END HELION KERNEL REPRO ===")
        print("\n".join(output_lines), file=sys.stderr)
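
The decorator-skipping step in `_print_repro` can be tried in isolation. The following is a minimal sketch that takes the source as a plain string instead of calling `inspect.getsource`, so it runs without Helion or a GPU. `strip_decorators` and the sample source are illustrative names, not Helion APIs.

```python
import textwrap


def strip_decorators(raw_source: str) -> str:
    """Drop every line before the first one that starts with ``def ``."""
    # dedent() removes the common leading whitespace of a nested definition.
    lines = textwrap.dedent(raw_source).splitlines()
    start = 0
    while start < len(lines) and not lines[start].lstrip().startswith("def "):
        start += 1
    return "\n".join(lines[start:])


# Illustrative input: an indented function with a multi-line decorator.
raw = '''
    @some_decorator(
        option=True,
    )
    def kernel1(x):
        return x + 1
'''
print(strip_decorators(raw))
```

Like the loop in the diff, this sketch only looks for a line beginning with `def `, so a function declared with `async def` would never match and the result would be empty.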


class _KernelDecorator(Protocol):
def __call__(
4 changes: 4 additions & 0 deletions helion/runtime/settings.py
@@ -315,6 +315,9 @@ class _Settings:
            _env_get_bool, "HELION_PRINT_OUTPUT_CODE", False
        )
    )
    print_repro: bool = dataclasses.field(
        default_factory=functools.partial(_env_get_bool, "HELION_PRINT_REPRO", False)
    )
    output_origin_lines: bool = dataclasses.field(
        default_factory=functools.partial(
            _env_get_bool, "HELION_OUTPUT_ORIGIN_LINES", True
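
The `default_factory=functools.partial(...)` pattern used for `print_repro` can be illustrated with a small standalone dataclass. This is only a sketch: the body of `env_get_bool` below is an assumption (the real `_env_get_bool` is not shown in this diff), and `DemoSettings` is a hypothetical stand-in for `_Settings`.

```python
import dataclasses
import functools
import os


def env_get_bool(name: str, default: bool) -> bool:
    """Hypothetical parser: unset -> default, otherwise accept common truthy strings."""
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() in {"1", "true", "yes", "on"}


@dataclasses.dataclass
class DemoSettings:
    # default_factory runs each time an instance is created without an explicit
    # value, so the environment is read at construction time, not import time.
    print_repro: bool = dataclasses.field(
        default_factory=functools.partial(env_get_bool, "HELION_PRINT_REPRO", False)
    )


os.environ.pop("HELION_PRINT_REPRO", None)
unset_value = DemoSettings().print_repro  # falls back to the default

os.environ["HELION_PRINT_REPRO"] = "1"
env_value = DemoSettings().print_repro  # picked up from the environment
explicit_value = DemoSettings(print_repro=False).print_repro  # explicit argument wins

os.environ.pop("HELION_PRINT_REPRO", None)
print(unset_value, env_value, explicit_value)
```

An explicitly passed value bypasses the factory entirely, which is consistent with the documentation offering both `HELION_PRINT_REPRO=1` and `print_repro=True` as ways to turn the setting on.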
@@ -386,6 +389,7 @@ class Settings(_Settings):
"Set HELION_AUTOTUNE_IGNORE_ERRORS=1 to enable globally."
),
"print_output_code": "If True, print the output code of the kernel to stderr.",
"print_repro": "If True, print Helion kernel code, config, and caller code to stderr as a standalone repro script.",
"output_origin_lines": (
"If True, annotate generated Triton code with source-origin comments. "
"Set HELION_OUTPUT_ORIGIN_LINES=0 to disable."
23 changes: 23 additions & 0 deletions test/test_debug_utils.expected
@@ -0,0 +1,23 @@
This file is automatically generated by assertExpectedJournal calls in test_debug_utils.py.
Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.

--- assertExpectedJournal(TestDebugUtils.test_print_repro_env_var)
# === HELION KERNEL REPRO ===
import helion
import helion.language as hl
import torch
from torch._dynamo.testing import rand_strided

@helion.kernel(config=helion.Config(block_sizes=[2, 2], flatten_loops=[False], indexing=['pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=[''], loop_orders=[[0, 1]], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=True)
def kernel1(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    m, n = x.shape
    for tile_m, tile_n in hl.tile([m, n]):
        out[tile_m, tile_n] = x[tile_m, tile_n] + 1
    return out

def helion_repro_caller():
    torch.manual_seed(0)
    x = rand_strided((2, 2), (2, 1), dtype=torch.float32, device=DEVICE)
    return kernel1(x)
# === END HELION KERNEL REPRO ===
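
The `(2, 1)` stride recorded in the expected caller is the row-major stride of a contiguous `(2, 2)` tensor, which is what `torch.randn([2, 2])` in the test produces. As a rough illustration of where those numbers come from, the sketch below computes contiguous strides in plain Python. `contiguous_strides` is a hypothetical helper, not a Helion or PyTorch function, and it assumes every dimension is non-zero.

```python
from __future__ import annotations


def contiguous_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
    """Row-major strides, in elements, for a densely packed array of `shape`."""
    strides = []
    running = 1
    # Walk from the innermost dimension outwards: the innermost stride is 1,
    # and each outer stride is the product of all inner dimension sizes.
    for dim in reversed(shape):
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))


print(contiguous_strides((2, 2)))
```

Recording the actual strides matters for non-contiguous inputs such as transposed views, where the values differ from this contiguous layout.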
132 changes: 132 additions & 0 deletions test/test_debug_utils.py
@@ -0,0 +1,132 @@
from __future__ import annotations

import linecache
import os
import unittest

import pytest
import torch

import helion
from helion._testing import DEVICE
from helion._testing import RefEagerTestDisabled
from helion._testing import TestCase
import helion.language as hl


@pytest.fixture(autouse=True)
def _store_capfd_on_class(request, capfd):
    """
    Expose pytest's capfd fixture as `self._capfd` inside the TestDebugUtils class
    (works for unittest.TestCase-style tests).
    """
    if request.cls is not None:
        request.cls._capfd = capfd


class TestDebugUtils(RefEagerTestDisabled, TestCase):
    def test_print_repro_env_var(self):
        """Ensure HELION_PRINT_REPRO=1 emits an executable repro script."""
        original = os.environ.get("HELION_PRINT_REPRO")
        os.environ["HELION_PRINT_REPRO"] = "1"
        try:

            @helion.kernel(
                config=helion.Config(
                    block_sizes=[2, 2],
                    flatten_loops=[False],
                    indexing=["pointer", "pointer"],
                    l2_groupings=[1],
                    load_eviction_policies=[""],
                    loop_orders=[[0, 1]],
                    num_stages=1,
                    num_warps=4,
                    pid_type="flat",
                    range_flattens=[None],
                    range_multi_buffers=[None],
                    range_num_stages=[0],
                    range_unroll_factors=[0],
                ),
                static_shapes=True,
            )
            def kernel1(x: torch.Tensor) -> torch.Tensor:
                out = torch.empty_like(x)
                m, n = x.shape
                for tile_m, tile_n in hl.tile([m, n]):
                    out[tile_m, tile_n] = x[tile_m, tile_n] + 1
                return out

            torch.manual_seed(0)
            x = torch.randn([2, 2], dtype=torch.float32, device=DEVICE)

            if hasattr(self, "_capfd"):
                self._capfd.readouterr()

            result = kernel1(x)
            torch.testing.assert_close(result, x + 1)

            if not hasattr(self, "_capfd"):
                return  # Cannot test without capture

            captured = "".join(self._capfd.readouterr())

            # Extract repro script
            lines = captured.splitlines()
            start = next(
                i
                for i, line in enumerate(lines)
                if "# === HELION KERNEL REPRO ===" in line
            )
            end = next(
                i
                for i, line in enumerate(lines[start:], start)
                if "# === END HELION KERNEL REPRO ===" in line
            )
            repro_script = "\n".join(lines[start : end + 1])

            # Normalize range_warp_specializes=[None] to [] for comparison
            normalized_script = repro_script.replace(
                "range_warp_specializes=[None]", "range_warp_specializes=[]"
            )

            # Verify repro script matches expected script
            self.assertExpectedJournal(normalized_script)

            # Extract the actual code (without the comment markers) for execution
            repro_lines = repro_script.splitlines()
            code_start = 1 if repro_lines[0].startswith("# === HELION") else 0
            code_end = len(repro_lines) - (
                1 if repro_lines[-1].startswith("# === END") else 0
            )
            repro_code = "\n".join(repro_lines[code_start:code_end])

            # Setup linecache so inspect.getsource() works on exec'd code
            filename = "<helion_repro_test>"
            linecache.cache[filename] = (
                len(repro_code),
                None,
                [f"{line}\n" for line in repro_code.splitlines()],
                filename,
            )

            # Execute the repro script
            namespace = {}
            exec(compile(repro_code, filename, "exec"), namespace)

            # Call the generated helper and verify it runs successfully
            helper = namespace["helion_repro_caller"]
            repro_result = helper()

            # Verify the output
            torch.testing.assert_close(repro_result, x + 1)

            linecache.cache.pop(filename, None)
        finally:
            if original is None:
                os.environ.pop("HELION_PRINT_REPRO", None)
            else:
                os.environ["HELION_PRINT_REPRO"] = original


if __name__ == "__main__":
    unittest.main()
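
The `linecache` setup in the test is needed because `inspect.getsource()` normally reads source from a file on disk, and code run through `exec` has none. The sketch below shows the same trick on a trivial function using only the standard library. The pseudo-filename and the `add_one` function are made up for illustration.

```python
import inspect
import linecache

source = "def add_one(x):\n    return x + 1\n"
filename = "<demo_exec_source>"  # hypothetical pseudo-filename

# Register the source under the pseudo-filename.
# Entry layout is (size, mtime, lines, fullname); mtime=None keeps
# linecache.checkcache() from evicting the entry as a stale file.
linecache.cache[filename] = (
    len(source),
    None,
    source.splitlines(keepends=True),
    filename,
)

# Compile with the same filename so the code object points at the cache entry.
namespace = {}
exec(compile(source, filename, "exec"), namespace)
add_one = namespace["add_one"]

# inspect can now recover the source of the exec'd function.
recovered = inspect.getsource(add_one)
print(recovered)

# Clean up so the fake entry does not linger in the process-wide cache.
linecache.cache.pop(filename, None)
```

This matters for the repro test because the kernel's source has to be retrievable when the repro script is executed; without the cache entry, the `OSError` fallback in `_print_repro` would replace the kernel body with a "Source unavailable" comment.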