Skip to content

Commit 2f765f1

Browse files
committed
up
1 parent b82a244 commit 2f765f1

File tree

9 files changed

+108
-6
lines changed

9 files changed

+108
-6
lines changed

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,11 @@ To view the generated Triton code, set the environment variable `HELION_PRINT_OU
301301
helpful for debugging and understanding Helion's compilation process. One can also use
302302
`foo_kernel.bind(args).to_triton_code(config)` to get the Triton code as a string.
303303

304+
To emit a repro script that includes the Helion kernel definition, the config decorator, and a
305+
`helion_repro_caller()` helper that recreates the runtime inputs before invoking the Helion kernel, set
306+
`HELION_PRINT_REPRO=1` or include `print_repro=True` in the `@helion.kernel` decorator. This prints
307+
the repro script to `stderr`, which is helpful for debugging and for sharing minimal repro on GitHub issue tracker.
308+
304309
Within an `hl.tile`/`hl.grid` device loop, if you want to print intermediate results using `print("x", ...)` syntax,
305310
or pause execution using Python's built-in `breakpoint()`, set either `TRITON_INTERPRET=1` (runs Triton's CPU interpreter)
306311
or `HELION_INTERPRET=1` (runs the Helion kernel in eager mode).

docs/api/config.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ The `Config` class represents kernel optimization parameters that control how He
2727
|--------|--------|----------|
2828
| **Purpose** | Control execution performance | Control compilation behavior |
2929
| **Autotuning** | ✅ Automatically optimized | ❌ Never autotuned |
30-
| **Examples** | `block_sizes`, `num_warps`, `indexing` | `print_output_code`, `autotune_effort` |
30+
| **Examples** | `block_sizes`, `num_warps`, `indexing` | `print_output_code`, `print_repro`, `autotune_effort` |
3131
| **When to use** | Performance optimization | Development, debugging, environment setup |
3232

3333

docs/api/kernel.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ Settings control **how the kernel is compiled** and the development environment:
161161
autotune_effort="none", # Skip autotuning for development
162162
autotune_effort="quick", # Smaller autotuning budget when search is enabled
163163
print_output_code=True, # Debug: show generated Triton code
164+
print_repro=True, # Debug: show Helion kernel code, config, and caller code as a standalone repro script
164165
static_shapes=True, # Compilation optimization strategy
165166
autotune_log_level=logging.DEBUG # Verbose autotuning output
166167
)

docs/api/settings.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ import helion.language as hl
6161

6262
@helion.kernel(
6363
autotune_effort="none", # Skip autotuning
64-
print_output_code=True, # Debug output
64+
print_output_code=True, # Debug: show generated Triton code
65+
print_repro=True, # Debug: show Helion kernel code, config, and caller code as a standalone repro script
6566
)
6667
def my_kernel(x: torch.Tensor) -> torch.Tensor:
6768
result = torch.zeros_like(x)
@@ -190,6 +191,10 @@ See :class:`helion.autotuner.LocalAutotuneCache` for details on cache keys and b
190191
191192
Print generated Triton code to stderr. Default is ``False``. Controlled by ``HELION_PRINT_OUTPUT_CODE=1``.
192193
194+
.. autoattribute:: Settings.print_repro
195+
196+
Print Helion kernel code, config, and caller code to stderr as a standalone repro script. Default is ``False``. Controlled by ``HELION_PRINT_REPRO=1``.
197+
193198
.. autoattribute:: Settings.output_origin_lines
194199
195200
Annotate generated Triton code with ``# src[<file>:<line>]`` comments indicating the originating Helion statements.
@@ -259,6 +264,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"PatternSearch"``, ``"Differe
259264
| ``HELION_SKIP_CACHE`` | ``LocalAutotuneCache`` | When set to ``1``, ignore cached autotuning entries and rerun searches. |
260265
| ``HELION_ASSERT_CACHE_HIT`` | ``AutotuneCacheBase`` | When set to ``1``, require a cache hit; raises ``CacheAssertionError`` on cache miss with detailed diagnostics. |
261266
| ``HELION_PRINT_OUTPUT_CODE`` | ``print_output_code`` | Print generated Triton code to stderr for inspection. |
267+
| ``HELION_PRINT_REPRO`` | ``print_repro`` | Print Helion kernel code, config, and caller code to stderr as a standalone repro script. |
262268
| ``HELION_OUTPUT_ORIGIN_LINES`` | ``output_origin_lines`` | Include ``# src[...]`` comments in generated Triton code; set to ``0`` to disable. |
263269
| ``HELION_IGNORE_WARNINGS`` | ``ignore_warnings`` | Comma-separated warning names defined in ``helion.exc`` to suppress. |
264270
| ``HELION_ALLOW_WARP_SPECIALIZE`` | ``allow_warp_specialize`` | Permit warp-specialized code generation for ``tl.range``. |

docs/index.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,11 @@ To view the generated Triton code, set the environment variable `HELION_PRINT_OU
241241
helpful for debugging and understanding Helion's compilation process. One can also use
242242
`foo_kernel.bind(args).to_triton_code(config)` to get the Triton code as a string.
243243

244+
To emit a repro script that includes the Helion kernel definition, the config decorator, and a
245+
`helion_repro_caller()` helper that recreates the runtime inputs before invoking the Helion kernel, set
246+
`HELION_PRINT_REPRO=1` or include `print_repro=True` in the `@helion.kernel` decorator. This prints
247+
the repro script to `stderr`, which is helpful for debugging and for sharing minimal repro on GitHub issue tracker.
248+
244249
To force autotuning, bypassing provided configurations, set `HELION_FORCE_AUTOTUNE=1` or invoke `foo_kernel.autotune(args,
245250
force=True)`.
246251

helion/runtime/kernel.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import operator
1010
import re
1111
import sys
12+
import textwrap
1213
import types
1314
from typing import TYPE_CHECKING
1415
from typing import Callable
@@ -641,8 +642,88 @@ def __call__(self, *args: object) -> _R:
641642
self.format_kernel_decorator(self._config, self.settings)
642643
] = 1
643644

645+
if self.settings.print_repro:
646+
self._print_repro(args)
647+
644648
return self._run(*args)
645649

650+
def _print_repro(
651+
self, args: tuple[object, ...], config: Config | None = None
652+
) -> None:
653+
effective_config = config or self._config
654+
assert effective_config is not None
655+
656+
# Get kernel source
657+
try:
658+
raw_source = inspect.getsource(self.kernel.fn)
659+
source_lines = textwrap.dedent(raw_source).splitlines()
660+
# Skip decorator lines
661+
start_idx = 0
662+
while start_idx < len(source_lines) and source_lines[
663+
start_idx
664+
].lstrip().startswith("@"):
665+
start_idx += 1
666+
kernel_body = "\n".join(source_lines[start_idx:])
667+
except (OSError, TypeError):
668+
kernel_body = f"# Source unavailable for {self.kernel.fn.__module__}.{self.kernel.fn.__qualname__}"
669+
670+
# Format decorator
671+
decorator = self.format_kernel_decorator(effective_config, self.settings)
672+
673+
# Build output
674+
output_lines = [
675+
"# === HELION KERNEL REPRO ===",
676+
"import helion",
677+
"import helion.language as hl",
678+
"import torch",
679+
"from torch._dynamo.testing import rand_strided",
680+
"",
681+
decorator,
682+
kernel_body,
683+
]
684+
685+
# Generate caller function
686+
if args:
687+
688+
def _render_input_arg_assignment(name: str, value: object) -> list[str]:
689+
if isinstance(value, torch.Tensor):
690+
shape = tuple(int(d) for d in value.shape)
691+
stride = tuple(int(s) for s in value.stride())
692+
device = str(value.device)
693+
dtype = str(value.dtype)
694+
695+
lines = [
696+
f"{name} = rand_strided({shape!r}, {stride!r}, dtype={dtype}, device={device!r})"
697+
]
698+
699+
if value.requires_grad:
700+
lines.append(f"{name}.requires_grad_(True)")
701+
return lines
702+
703+
return [f"{name} = {value!r}"]
704+
705+
sig_param_names = list(self.kernel.signature.parameters.keys())
706+
assert len(args) == len(sig_param_names)
707+
708+
output_lines.extend(["", "def helion_repro_caller():"])
709+
output_lines.append(" torch.manual_seed(0)")
710+
arg_names = []
711+
712+
for i, value in enumerate(args):
713+
var_name = sig_param_names[i]
714+
arg_names.append(var_name)
715+
716+
# Add assignment lines with indentation
717+
for line in _render_input_arg_assignment(var_name, value):
718+
output_lines.append(f" {line}")
719+
720+
# Add return statement
721+
call_args = ", ".join(arg_names)
722+
output_lines.append(f" return {self.kernel.name}({call_args})")
723+
724+
output_lines.append("# === END HELION KERNEL REPRO ===")
725+
print("\n".join(output_lines), file=sys.stderr)
726+
646727

647728
class _KernelDecorator(Protocol):
648729
def __call__(

helion/runtime/settings.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,9 @@ class _Settings:
314314
_env_get_bool, "HELION_PRINT_OUTPUT_CODE", False
315315
)
316316
)
317+
print_repro: bool = dataclasses.field(
318+
default_factory=functools.partial(_env_get_bool, "HELION_PRINT_REPRO", False)
319+
)
317320
output_origin_lines: bool = dataclasses.field(
318321
default_factory=functools.partial(
319322
_env_get_bool, "HELION_OUTPUT_ORIGIN_LINES", True
@@ -384,6 +387,7 @@ class Settings(_Settings):
384387
"Set HELION_AUTOTUNE_IGNORE_ERRORS=1 to enable globally."
385388
),
386389
"print_output_code": "If True, print the output code of the kernel to stderr.",
390+
"print_repro": "If True, print Helion kernel code, config, and caller code to stderr as a standalone repro script.",
387391
"output_origin_lines": (
388392
"If True, annotate generated Triton code with source-origin comments. "
389393
"Set HELION_OUTPUT_ORIGIN_LINES=0 to disable."

test/test_debug_utils.expected

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import torch
88
from torch._dynamo.testing import rand_strided
99

1010
@helion.kernel(config=helion.Config(block_sizes=[2, 2], flatten_loops=[False], indexing=['pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=[''], loop_orders=[[0, 1]], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[None]), static_shapes=True)
11-
def repro_kernel(x: torch.Tensor) -> torch.Tensor:
11+
def kernel1(x: torch.Tensor) -> torch.Tensor:
1212
out = torch.empty_like(x)
1313
m, n = x.shape
1414
for tile_m, tile_n in hl.tile([m, n]):
@@ -18,4 +18,4 @@ def repro_kernel(x: torch.Tensor) -> torch.Tensor:
1818
def helion_repro_caller():
1919
torch.manual_seed(0)
2020
x = rand_strided((2, 2), (2, 1), dtype=torch.float32, device=DEVICE)
21-
return repro_kernel(x)
21+
return kernel1(x)

test/test_debug_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def test_print_repro_env_var(self):
3232
try:
3333

3434
@helion.kernel(autotune_effort="none")
35-
def repro_kernel(x: torch.Tensor) -> torch.Tensor:
35+
def kernel1(x: torch.Tensor) -> torch.Tensor:
3636
out = torch.empty_like(x)
3737
m, n = x.shape
3838
for tile_m, tile_n in hl.tile([m, n]):
@@ -45,7 +45,7 @@ def repro_kernel(x: torch.Tensor) -> torch.Tensor:
4545
if hasattr(self, "_capfd"):
4646
self._capfd.readouterr()
4747

48-
result = repro_kernel(x)
48+
result = kernel1(x)
4949
torch.testing.assert_close(result, x + 1)
5050

5151
if not hasattr(self, "_capfd"):

0 commit comments

Comments
 (0)