From 4388ab1a8fdea83d647d8bf52d0a9980381f51fc Mon Sep 17 00:00:00 2001 From: Will Feng Date: Thu, 30 Oct 2025 13:17:18 -0700 Subject: [PATCH 1/2] test --- test/test_debug_utils.expected | 23 ++++++ test/test_debug_utils.py | 132 +++++++++++++++++++++++++++++++++ 2 files changed, 155 insertions(+) create mode 100644 test/test_debug_utils.expected create mode 100644 test/test_debug_utils.py diff --git a/test/test_debug_utils.expected b/test/test_debug_utils.expected new file mode 100644 index 000000000..93057aad2 --- /dev/null +++ b/test/test_debug_utils.expected @@ -0,0 +1,23 @@ +This file is automatically generated by assertExpectedJournal calls in test_debug_utils.py. +Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set. + +--- assertExpectedJournal(TestDebugUtils.test_print_repro_env_var) +# === HELION KERNEL REPRO === +import helion +import helion.language as hl +import torch +from torch._dynamo.testing import rand_strided + +@helion.kernel(config=helion.Config(block_sizes=[2, 2], flatten_loops=[False], indexing=['pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=[''], loop_orders=[[0, 1]], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=True) +def kernel1(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + m, n = x.shape + for tile_m, tile_n in hl.tile([m, n]): + out[tile_m, tile_n] = x[tile_m, tile_n] + 1 + return out + +def helion_repro_caller(): + torch.manual_seed(0) + x = rand_strided((2, 2), (2, 1), dtype=torch.float32, device=DEVICE) + return kernel1(x) +# === END HELION KERNEL REPRO === diff --git a/test/test_debug_utils.py b/test/test_debug_utils.py new file mode 100644 index 000000000..53628db71 --- /dev/null +++ b/test/test_debug_utils.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import linecache +import os +import unittest + +import pytest +import torch + +import helion +from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled +from helion._testing import TestCase +import helion.language as hl + + +@pytest.fixture(autouse=True) +def _store_capfd_on_class(request, capfd): + """ + Expose pytest's capfd fixture as `self._capfd` inside the TestDebugUtils class + (works for unittest.TestCase-style tests). + """ + if request.cls is not None: + request.cls._capfd = capfd + + +class TestDebugUtils(RefEagerTestDisabled, TestCase): + def test_print_repro_env_var(self): + """Ensure HELION_PRINT_REPRO=1 emits an executable repro script.""" + original = os.environ.get("HELION_PRINT_REPRO") + os.environ["HELION_PRINT_REPRO"] = "1" + try: + + @helion.kernel( + config=helion.Config( + block_sizes=[2, 2], + flatten_loops=[False], + indexing=["pointer", "pointer"], + l2_groupings=[1], + load_eviction_policies=[""], + loop_orders=[[0, 1]], + num_stages=1, + num_warps=4, + pid_type="flat", + range_flattens=[None], + range_multi_buffers=[None], + range_num_stages=[0], + range_unroll_factors=[0], + ), + static_shapes=True, + ) + def kernel1(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + m, n = x.shape + for tile_m, tile_n in hl.tile([m, n]): + out[tile_m, tile_n] = x[tile_m, tile_n] + 1 + return out + + torch.manual_seed(0) + x = torch.randn([2, 2], dtype=torch.float32, device=DEVICE) + + if hasattr(self, "_capfd"): + self._capfd.readouterr() + + result = kernel1(x) + torch.testing.assert_close(result, x + 1) + + if not hasattr(self, "_capfd"): + return # Cannot test without capture + + captured = "".join(self._capfd.readouterr()) + + # Extract repro script + lines = captured.splitlines() + start = next( + i + for i, line in enumerate(lines) + if "# === HELION KERNEL REPRO ===" in line + ) + end = next( + i + for i, line in enumerate(lines[start:], start) + if "# === END HELION KERNEL REPRO ===" in line + ) + repro_script = "\n".join(lines[start : end + 1]) + + # Normalize range_warp_specializes=[None] to [] for comparison + normalized_script = repro_script.replace( + "range_warp_specializes=[None]", "range_warp_specializes=[]" + ) + + # Verify repro script matches expected script + self.assertExpectedJournal(normalized_script) + + # Extract the actual code (without the comment markers) for execution + repro_lines = repro_script.splitlines() + code_start = 1 if repro_lines[0].startswith("# === HELION") else 0 + code_end = len(repro_lines) - ( + 1 if repro_lines[-1].startswith("# === END") else 0 + ) + repro_code = "\n".join(repro_lines[code_start:code_end]) + + # Setup linecache so inspect.getsource() works on exec'd code + filename = "" + linecache.cache[filename] = ( + len(repro_code), + None, + [f"{line}\n" for line in repro_code.splitlines()], + filename, + ) + + # Execute the repro script + namespace = {} + exec(compile(repro_code, filename, "exec"), namespace) + + # Call the generated helper and verify it runs successfully + helper = namespace["helion_repro_caller"] + repro_result = helper() + + # Verify the output + torch.testing.assert_close(repro_result, x + 1) + + linecache.cache.pop(filename, None) + finally: + if original is None: + os.environ.pop("HELION_PRINT_REPRO", None) + else: + os.environ["HELION_PRINT_REPRO"] = original + + +if __name__ == "__main__": + unittest.main() From 5968d9a8d2fb68a4ee6a0af7ac3db6c74f0fbf02 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Thu, 30 Oct 2025 14:03:37 -0700 Subject: [PATCH 2/2] up --- README.md | 5 +++ docs/api/config.md | 2 +- docs/api/kernel.md | 1 + docs/api/settings.md | 8 +++- docs/index.md | 5 +++ helion/runtime/kernel.py | 81 ++++++++++++++++++++++++++++++++++++++ helion/runtime/settings.py | 4 ++ 7 files changed, 104 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 91aac1e36..3aeb3a6f2 100644 --- a/README.md +++ b/README.md @@ -301,6 +301,11 @@ To view the generated Triton code, set the environment variable `HELION_PRINT_OU helpful for debugging and understanding Helion's compilation process. One can also use `foo_kernel.bind(args).to_triton_code(config)` to get the Triton code as a string. +To emit a repro script that includes the Helion kernel definition, the config decorator, and a +`helion_repro_caller()` helper that recreates the runtime inputs before invoking the Helion kernel, set +`HELION_PRINT_REPRO=1` or include `print_repro=True` in the `@helion.kernel` decorator. This prints +the repro script to `stderr`, which is helpful for debugging and for sharing minimal repro on GitHub issue tracker. + Within an `hl.tile`/`hl.grid` device loop, if you want to print intermediate results using `print("x", ...)` syntax, or pause execution using Python's built-in `breakpoint()`, set either `TRITON_INTERPRET=1` (runs Triton's CPU interpreter) or `HELION_INTERPRET=1` (runs the Helion kernel in eager mode). diff --git a/docs/api/config.md b/docs/api/config.md index 188572822..046651eaa 100644 --- a/docs/api/config.md +++ b/docs/api/config.md @@ -27,7 +27,7 @@ The `Config` class represents kernel optimization parameters that control how He |--------|--------|----------| | **Purpose** | Control execution performance | Control compilation behavior | | **Autotuning** | ✅ Automatically optimized | ❌ Never autotuned | -| **Examples** | `block_sizes`, `num_warps`, `indexing` | `print_output_code`, `autotune_effort` | +| **Examples** | `block_sizes`, `num_warps`, `indexing` | `print_output_code`, `print_repro`, `autotune_effort` | | **When to use** | Performance optimization | Development, debugging, environment setup | diff --git a/docs/api/kernel.md b/docs/api/kernel.md index b8fe90c67..21c814bab 100644 --- a/docs/api/kernel.md +++ b/docs/api/kernel.md @@ -161,6 +161,7 @@ Settings control **how the kernel is compiled** and the development environment: autotune_effort="none", # Skip autotuning for development autotune_effort="quick", # Smaller autotuning budget when search is enabled print_output_code=True, # Debug: show generated Triton code + print_repro=True, # Debug: show Helion kernel code, config, and caller code as a standalone repro script static_shapes=True, # Compilation optimization strategy autotune_log_level=logging.DEBUG # Verbose autotuning output ) diff --git a/docs/api/settings.md b/docs/api/settings.md index 5f63d63a2..299728628 100644 --- a/docs/api/settings.md +++ b/docs/api/settings.md @@ -61,7 +61,8 @@ import helion.language as hl @helion.kernel( autotune_effort="none", # Skip autotuning - print_output_code=True, # Debug output + print_output_code=True, # Debug: show generated Triton code + print_repro=True, # Debug: show Helion kernel code, config, and caller code as a standalone repro script ) def my_kernel(x: torch.Tensor) -> torch.Tensor: result = torch.zeros_like(x) @@ -190,6 +191,10 @@ See :class:`helion.autotuner.LocalAutotuneCache` for details on cache keys and b Print generated Triton code to stderr. Default is ``False``. Controlled by ``HELION_PRINT_OUTPUT_CODE=1``. +.. autoattribute:: Settings.print_repro + + Print Helion kernel code, config, and caller code to stderr as a standalone repro script. Default is ``False``. Controlled by ``HELION_PRINT_REPRO=1``. + .. autoattribute:: Settings.output_origin_lines Annotate generated Triton code with ``# src[:]`` comments indicating the originating Helion statements. @@ -259,6 +264,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"PatternSearch"``, ``"Differe | ``HELION_SKIP_CACHE`` | ``LocalAutotuneCache`` | When set to ``1``, ignore cached autotuning entries and rerun searches. | | ``HELION_ASSERT_CACHE_HIT`` | ``AutotuneCacheBase`` | When set to ``1``, require a cache hit; raises ``CacheAssertionError`` on cache miss with detailed diagnostics. | | ``HELION_PRINT_OUTPUT_CODE`` | ``print_output_code`` | Print generated Triton code to stderr for inspection. | +| ``HELION_PRINT_REPRO`` | ``print_repro`` | Print Helion kernel code, config, and caller code to stderr as a standalone repro script. | | ``HELION_OUTPUT_ORIGIN_LINES`` | ``output_origin_lines`` | Include ``# src[...]`` comments in generated Triton code; set to ``0`` to disable. | | ``HELION_IGNORE_WARNINGS`` | ``ignore_warnings`` | Comma-separated warning names defined in ``helion.exc`` to suppress. | | ``HELION_ALLOW_WARP_SPECIALIZE`` | ``allow_warp_specialize`` | Permit warp-specialized code generation for ``tl.range``. | diff --git a/docs/index.md b/docs/index.md index 911b106aa..f41bbbacb 100644 --- a/docs/index.md +++ b/docs/index.md @@ -241,6 +241,11 @@ To view the generated Triton code, set the environment variable `HELION_PRINT_OU helpful for debugging and understanding Helion's compilation process. One can also use `foo_kernel.bind(args).to_triton_code(config)` to get the Triton code as a string. +To emit a repro script that includes the Helion kernel definition, the config decorator, and a +`helion_repro_caller()` helper that recreates the runtime inputs before invoking the Helion kernel, set +`HELION_PRINT_REPRO=1` or include `print_repro=True` in the `@helion.kernel` decorator. This prints +the repro script to `stderr`, which is helpful for debugging and for sharing minimal repro on GitHub issue tracker. + To force autotuning, bypassing provided configurations, set `HELION_FORCE_AUTOTUNE=1` or invoke `foo_kernel.autotune(args, force=True)`. diff --git a/helion/runtime/kernel.py b/helion/runtime/kernel.py index 44ea26e35..7bfd38d2d 100644 --- a/helion/runtime/kernel.py +++ b/helion/runtime/kernel.py @@ -9,6 +9,7 @@ import operator import re import sys +import textwrap import types from typing import TYPE_CHECKING from typing import Callable @@ -641,8 +642,88 @@ def __call__(self, *args: object) -> _R: self.format_kernel_decorator(self._config, self.settings) ] = 1 + if self.settings.print_repro: + self._print_repro(args) + return self._run(*args) + def _print_repro( + self, args: tuple[object, ...], config: Config | None = None + ) -> None: + effective_config = config or self._config + assert effective_config is not None + + # Get kernel source + try: + raw_source = inspect.getsource(self.kernel.fn) + source_lines = textwrap.dedent(raw_source).splitlines() + # Skip decorator lines (including multi-line decorators) + start_idx = 0 + while start_idx < len(source_lines) and not source_lines[ + start_idx + ].lstrip().startswith("def "): + start_idx += 1 + kernel_body = "\n".join(source_lines[start_idx:]) + except (OSError, TypeError): + kernel_body = f"# Source unavailable for {self.kernel.fn.__module__}.{self.kernel.fn.__qualname__}" + + # Format decorator + decorator = self.format_kernel_decorator(effective_config, self.settings) + + # Build output + output_lines = [ + "# === HELION KERNEL REPRO ===", + "import helion", + "import helion.language as hl", + "import torch", + "from torch._dynamo.testing import rand_strided", + "", + decorator, + kernel_body, + ] + + # Generate caller function + if args: + + def _render_input_arg_assignment(name: str, value: object) -> list[str]: + if isinstance(value, torch.Tensor): + shape = tuple(int(d) for d in value.shape) + stride = tuple(int(s) for s in value.stride()) + device = str(value.device) + dtype = str(value.dtype) + + lines = [ + f"{name} = rand_strided({shape!r}, {stride!r}, dtype={dtype}, device={device!r})" + ] + + if value.requires_grad: + lines.append(f"{name}.requires_grad_(True)") + return lines + + return [f"{name} = {value!r}"] + + sig_param_names = list(self.kernel.signature.parameters.keys()) + assert len(args) == len(sig_param_names) + + output_lines.extend(["", "def helion_repro_caller():"]) + output_lines.append(" torch.manual_seed(0)") + arg_names = [] + + for i, value in enumerate(args): + var_name = sig_param_names[i] + arg_names.append(var_name) + + # Add assignment lines with indentation + for line in _render_input_arg_assignment(var_name, value): + output_lines.append(f" {line}") + + # Add return statement + call_args = ", ".join(arg_names) + output_lines.append(f" return {self.kernel.name}({call_args})") + + output_lines.append("# === END HELION KERNEL REPRO ===") + print("\n".join(output_lines), file=sys.stderr) + class _KernelDecorator(Protocol): def __call__( diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py index d43ccace4..54f329c7a 100644 --- a/helion/runtime/settings.py +++ b/helion/runtime/settings.py @@ -315,6 +315,9 @@ class _Settings: _env_get_bool, "HELION_PRINT_OUTPUT_CODE", False ) ) + print_repro: bool = dataclasses.field( + default_factory=functools.partial(_env_get_bool, "HELION_PRINT_REPRO", False) + ) output_origin_lines: bool = dataclasses.field( default_factory=functools.partial( _env_get_bool, "HELION_OUTPUT_ORIGIN_LINES", True @@ -386,6 +389,7 @@ class Settings(_Settings): "Set HELION_AUTOTUNE_IGNORE_ERRORS=1 to enable globally." ), "print_output_code": "If True, print the output code of the kernel to stderr.", + "print_repro": "If True, print Helion kernel code, config, and caller code to stderr as a standalone repro script.", "output_origin_lines": ( "If True, annotate generated Triton code with source-origin comments. " "Set HELION_OUTPUT_ORIGIN_LINES=0 to disable."