From 3600ff8888d2a17dd35e3aae3f33b2cdf3a3379b Mon Sep 17 00:00:00 2001 From: Will Feng Date: Fri, 31 Oct 2025 16:06:02 -0700 Subject: [PATCH] Make HELION_PRINT_REPRO=1 take effect in more error cases --- helion/autotuner/base_search.py | 14 +++ helion/autotuner/logger.py | 3 + helion/runtime/kernel.py | 17 +++- test/test_autotuner.py | 1 + test/test_debug_utils.expected | 16 +-- test/test_debug_utils.py | 172 +++++++++++++++++++++----------- 6 files changed, 150 insertions(+), 73 deletions(-) diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py index a50db185e..a9e841f68 100644 --- a/helion/autotuner/base_search.py +++ b/helion/autotuner/base_search.py @@ -188,6 +188,7 @@ def _compute_baseline(self) -> tuple[object, bool, Sequence[object] | None]: baseline_config, prefix=f"Generated Triton code for {decorator}:", ) + self.kernel.maybe_log_repro(self.log.error, new_args, baseline_config) raise exc.InvalidConfig( "Default config failed while computing baseline.\n" f"Default config: {decorator}\n" @@ -340,6 +341,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float: return res except Exception as e: if match_unrecoverable_runtime_error(e): + self.kernel.maybe_log_repro(self.log.error, self.args, config) raise exc.TritonUnrecoverableRuntimeError( reason=str(e), decorator=self.kernel.format_kernel_decorator( @@ -358,6 +360,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float: config, prefix=f"Generated Triton code for {decorator}:", ) + self.kernel.maybe_log_repro(self.log.error, self.args, config) raise exc.TritonError( error=f"{type(e).__qualname__}: {e}", decorator=decorator, @@ -372,6 +375,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float: prefix=f"Generated Triton code for {decorator}:", ) self.log.warning(format_triton_compile_failure(config, e, self.kernel)) + self.kernel.maybe_log_repro(self.log.warning, self.args, config) else: decorator = self.kernel.format_kernel_decorator(config, self.settings) log_generated_triton_code_debug( @@ -381,6 +385,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float: prefix=f"Generated Triton code for {decorator}:", ) self.log.debug(f"Benchmarking failed: {type(e).__name__}: {e}") + self.kernel.maybe_log_repro(self.log.debug, self.args, config) return inf def start_precompile_and_check_for_hangs( @@ -1198,6 +1203,9 @@ def _consume_result(self, *, raise_on_raise: bool) -> None: self.config, prefix=f"Generated Triton code for {decorator}:", ) + self.search.kernel.maybe_log_repro( + self.search.log.error, self.search.args, self.config + ) raise exc.TritonError( error=f"{type(exc_obj).__qualname__}: {exc_obj}", decorator=decorator, @@ -1223,8 +1231,14 @@ def _consume_result(self, *, raise_on_raise: bool) -> None: ) if classification == "warn": self.search.log.warning(formatted) + self.search.kernel.maybe_log_repro( + self.search.log.warning, self.search.args, self.config + ) elif not ignore_errors: self.search.log.debug(formatted) + self.search.kernel.maybe_log_repro( + self.search.log.debug, self.search.args, self.config + ) self._remote_error_handled = True diff --git a/helion/autotuner/logger.py b/helion/autotuner/logger.py index f2edd147b..e56542911 100644 --- a/helion/autotuner/logger.py +++ b/helion/autotuner/logger.py @@ -58,6 +58,9 @@ def __call__( if level >= self.level: self._logger.log(level, " ".join(map(_maybe_call, msg))) + def error(self, *msg: str | Callable[[], str]) -> None: + return self(*msg, level=logging.ERROR) + def warning(self, *msg: str | Callable[[], str]) -> None: return self(*msg, level=logging.WARNING) diff --git a/helion/runtime/kernel.py b/helion/runtime/kernel.py index 46157500d..afcd7108c 100644 --- a/helion/runtime/kernel.py +++ b/helion/runtime/kernel.py @@ -645,14 +645,19 @@ def __call__(self, *args: object) -> _R: self.format_kernel_decorator(self._config, self.settings) ] = 1 - if self.settings.print_repro: - self._print_repro(args) + self.maybe_log_repro(log.warning, args) return self._run(*args) - def _print_repro( - self, args: tuple[object, ...], config: Config | None = None + def maybe_log_repro( + self, + log_func: Callable[[str], None], + args: Sequence[object], + config: Config | None = None, ) -> None: + if not self.settings.print_repro: + return + effective_config = config or self._config assert effective_config is not None @@ -723,9 +728,11 @@ def _render_input_arg_assignment(name: str, value: object) -> list[str]: # Add return statement call_args = ", ".join(arg_names) output_lines.append(f" return {self.kernel.name}({call_args})") + output_lines.extend(["", "helion_repro_caller()"]) output_lines.append("# === END HELION KERNEL REPRO ===") - print("\n".join(output_lines), file=sys.stderr) + repro_text = "\n".join(output_lines) + log_func(repro_text) class _KernelDecorator(Protocol): diff --git a/test/test_autotuner.py b/test/test_autotuner.py index 877577a2d..908fb4dae 100644 --- a/test/test_autotuner.py +++ b/test/test_autotuner.py @@ -82,6 +82,7 @@ def _make_search( search.kernel = SimpleNamespace( format_kernel_decorator=lambda config, s: "decorator", to_triton_code=lambda config: "code", + maybe_log_repro=lambda log_func, args, config=None: None, ) search.args = args search.counters = collections.Counter() diff --git a/test/test_debug_utils.expected b/test/test_debug_utils.expected index 93057aad2..52a62565e 100644 --- a/test/test_debug_utils.expected +++ b/test/test_debug_utils.expected @@ -8,16 +8,18 @@ import helion.language as hl import torch from torch._dynamo.testing import rand_strided -@helion.kernel(config=helion.Config(block_sizes=[2, 2], flatten_loops=[False], indexing=['pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=[''], loop_orders=[[0, 1]], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=True) -def kernel1(x: torch.Tensor) -> torch.Tensor: +@helion.kernel(config=helion.Config(block_sizes=[32], indexing=['pointer', 'pointer'], load_eviction_policies=[''], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=True) +def kernel(x: torch.Tensor) -> torch.Tensor: out = torch.empty_like(x) - m, n = x.shape - for tile_m, tile_n in hl.tile([m, n]): - out[tile_m, tile_n] = x[tile_m, tile_n] + 1 + n = x.shape[0] + for tile_n in hl.tile([n]): + out[tile_n] = x[tile_n] + 1 return out def helion_repro_caller(): torch.manual_seed(0) - x = rand_strided((2, 2), (2, 1), dtype=torch.float32, device=DEVICE) - return kernel1(x) + x = rand_strided((128,), (1,), dtype=torch.float32, device=DEVICE) + return kernel(x) + +helion_repro_caller() # === END HELION KERNEL REPRO === diff --git a/test/test_debug_utils.py b/test/test_debug_utils.py index 53628db71..a2d25f51a 100644 --- a/test/test_debug_utils.py +++ b/test/test_debug_utils.py @@ -1,8 +1,10 @@ from __future__ import annotations +import contextlib import linecache import os import unittest +from unittest import mock import pytest import torch @@ -24,65 +26,80 @@ def _store_capfd_on_class(request, capfd): request.cls._capfd = capfd +@pytest.fixture(autouse=True) +def _store_caplog_on_class(request, caplog): + """ + Expose pytest's caplog fixture as `self._caplog` inside the TestDebugUtils class + (works for unittest.TestCase-style tests). + """ + if request.cls is not None: + request.cls._caplog = caplog + + class TestDebugUtils(RefEagerTestDisabled, TestCase): - def test_print_repro_env_var(self): - """Ensure HELION_PRINT_REPRO=1 emits an executable repro script.""" + @contextlib.contextmanager + def _with_print_repro_enabled(self): + """Context manager to temporarily set HELION_PRINT_REPRO=1.""" original = os.environ.get("HELION_PRINT_REPRO") os.environ["HELION_PRINT_REPRO"] = "1" try: + yield + finally: + if original is None: + os.environ.pop("HELION_PRINT_REPRO", None) + else: + os.environ["HELION_PRINT_REPRO"] = original + + def _clear_captures(self): + """Clear pytest capture fixtures if available.""" + if hasattr(self, "_capfd"): + self._capfd.readouterr() + if hasattr(self, "_caplog"): + self._caplog.clear() + + def _create_kernel(self, **kwargs): + """Create a simple 1D kernel for testing. + + Args: + **kwargs: Arguments to pass to @helion.kernel decorator. + """ - @helion.kernel( - config=helion.Config( - block_sizes=[2, 2], - flatten_loops=[False], - indexing=["pointer", "pointer"], - l2_groupings=[1], - load_eviction_policies=[""], - loop_orders=[[0, 1]], - num_stages=1, - num_warps=4, - pid_type="flat", - range_flattens=[None], - range_multi_buffers=[None], - range_num_stages=[0], - range_unroll_factors=[0], - ), + @helion.kernel(**kwargs) + def kernel(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.shape[0] + for tile_n in hl.tile([n]): + out[tile_n] = x[tile_n] + 1 + return out + + return kernel + + def test_print_repro_env_var(self): + """Ensure HELION_PRINT_REPRO=1 emits an executable repro script.""" + with self._with_print_repro_enabled(): + kernel = self._create_kernel( + config=helion.Config(block_sizes=[32], num_warps=4), static_shapes=True, ) - def kernel1(x: torch.Tensor) -> torch.Tensor: - out = torch.empty_like(x) - m, n = x.shape - for tile_m, tile_n in hl.tile([m, n]): - out[tile_m, tile_n] = x[tile_m, tile_n] + 1 - return out torch.manual_seed(0) - x = torch.randn([2, 2], dtype=torch.float32, device=DEVICE) + x = torch.randn([128], dtype=torch.float32, device=DEVICE) - if hasattr(self, "_capfd"): - self._capfd.readouterr() + self._clear_captures() - result = kernel1(x) + result = kernel(x) torch.testing.assert_close(result, x + 1) - if not hasattr(self, "_capfd"): - return # Cannot test without capture - - captured = "".join(self._capfd.readouterr()) + # Extract repro script from logs (use records to get the raw message without formatting) + assert hasattr(self, "_caplog"), "caplog fixture not available" + repro_script = None + for record in self._caplog.records: + if "# === HELION KERNEL REPRO ===" in record.message: + repro_script = record.message + break - # Extract repro script - lines = captured.splitlines() - start = next( - i - for i, line in enumerate(lines) - if "# === HELION KERNEL REPRO ===" in line - ) - end = next( - i - for i, line in enumerate(lines[start:], start) - if "# === END HELION KERNEL REPRO ===" in line - ) - repro_script = "\n".join(lines[start : end + 1]) + if repro_script is None: + self.fail("No repro script found in logs") # Normalize range_warp_specializes=[None] to [] for comparison normalized_script = repro_script.replace( @@ -92,26 +109,18 @@ def kernel1(x: torch.Tensor) -> torch.Tensor: # Verify repro script matches expected script self.assertExpectedJournal(normalized_script) - # Extract the actual code (without the comment markers) for execution - repro_lines = repro_script.splitlines() - code_start = 1 if repro_lines[0].startswith("# === HELION") else 0 - code_end = len(repro_lines) - ( - 1 if repro_lines[-1].startswith("# === END") else 0 - ) - repro_code = "\n".join(repro_lines[code_start:code_end]) - # Setup linecache so inspect.getsource() works on exec'd code filename = "" linecache.cache[filename] = ( - len(repro_code), + len(repro_script), None, - [f"{line}\n" for line in repro_code.splitlines()], + [f"{line}\n" for line in repro_script.splitlines()], filename, ) # Execute the repro script namespace = {} - exec(compile(repro_code, filename, "exec"), namespace) + exec(compile(repro_script, filename, "exec"), namespace) # Call the generated helper and verify it runs successfully helper = namespace["helion_repro_caller"] @@ -121,11 +130,52 @@ def kernel1(x: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(repro_result, x + 1) linecache.cache.pop(filename, None) - finally: - if original is None: - os.environ.pop("HELION_PRINT_REPRO", None) - else: - os.environ["HELION_PRINT_REPRO"] = original + + def test_print_repro_on_autotune_error(self): + """Ensure HELION_PRINT_REPRO=1 prints repro when configs fail during autotuning. + + This test mocks do_bench to fail on the second config, guaranteeing the repro + printing code path is exercised for "warn" level errors. + """ + with self._with_print_repro_enabled(): + kernel = self._create_kernel( + configs=[ + helion.Config(block_sizes=[32], num_warps=4), + helion.Config(block_sizes=[64], num_warps=8), + ], + autotune_precompile=False, + ) + + torch.manual_seed(0) + x = torch.randn([128], dtype=torch.float32, device=DEVICE) + + self._clear_captures() + + # Mock do_bench to fail on the second config with PTXASError (warn level) + from torch._inductor.runtime.triton_compat import PTXASError + from triton.testing import do_bench as original_do_bench + + call_count = [0] + + def mock_do_bench(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 2: # Fail on second config + raise PTXASError("Mocked PTXAS error") + return original_do_bench(*args, **kwargs) + + with mock.patch("helion.autotuner.base_search.do_bench", mock_do_bench): + # Autotune will try both configs, second one will fail and print repro + kernel.autotune([x], force=False) + + # Extract repro script from stderr + assert hasattr(self, "_capfd"), "capfd fixture not available" + captured = "".join(self._capfd.readouterr()) + + # Verify that a repro script was printed for the failing config + self.assertIn("# === HELION KERNEL REPRO ===", captured) + self.assertIn("# === END HELION KERNEL REPRO ===", captured) + self.assertIn("kernel", captured) + self.assertIn("helion_repro_caller()", captured) if __name__ == "__main__":