Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions helion/autotuner/base_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def _compute_baseline(self) -> tuple[object, bool, Sequence[object] | None]:
baseline_config,
prefix=f"Generated Triton code for {decorator}:",
)
self.kernel.maybe_log_repro(self.log.error, new_args, baseline_config)
raise exc.InvalidConfig(
"Default config failed while computing baseline.\n"
f"Default config: {decorator}\n"
Expand Down Expand Up @@ -340,6 +341,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
return res
except Exception as e:
if match_unrecoverable_runtime_error(e):
self.kernel.maybe_log_repro(self.log.error, self.args, config)
raise exc.TritonUnrecoverableRuntimeError(
reason=str(e),
decorator=self.kernel.format_kernel_decorator(
Expand All @@ -358,6 +360,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
config,
prefix=f"Generated Triton code for {decorator}:",
)
self.kernel.maybe_log_repro(self.log.error, self.args, config)
raise exc.TritonError(
error=f"{type(e).__qualname__}: {e}",
decorator=decorator,
Expand All @@ -372,6 +375,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
prefix=f"Generated Triton code for {decorator}:",
)
self.log.warning(format_triton_compile_failure(config, e, self.kernel))
self.kernel.maybe_log_repro(self.log.warning, self.args, config)
else:
decorator = self.kernel.format_kernel_decorator(config, self.settings)
log_generated_triton_code_debug(
Expand All @@ -381,6 +385,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
prefix=f"Generated Triton code for {decorator}:",
)
self.log.debug(f"Benchmarking failed: {type(e).__name__}: {e}")
self.kernel.maybe_log_repro(self.log.debug, self.args, config)
return inf

def start_precompile_and_check_for_hangs(
Expand Down Expand Up @@ -1198,6 +1203,9 @@ def _consume_result(self, *, raise_on_raise: bool) -> None:
self.config,
prefix=f"Generated Triton code for {decorator}:",
)
self.search.kernel.maybe_log_repro(
self.search.log.error, self.search.args, self.config
)
raise exc.TritonError(
error=f"{type(exc_obj).__qualname__}: {exc_obj}",
decorator=decorator,
Expand All @@ -1223,8 +1231,14 @@ def _consume_result(self, *, raise_on_raise: bool) -> None:
)
if classification == "warn":
self.search.log.warning(formatted)
self.search.kernel.maybe_log_repro(
self.search.log.warning, self.search.args, self.config
)
elif not ignore_errors:
self.search.log.debug(formatted)
self.search.kernel.maybe_log_repro(
self.search.log.debug, self.search.args, self.config
)
self._remote_error_handled = True


Expand Down
3 changes: 3 additions & 0 deletions helion/autotuner/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ def __call__(
if level >= self.level:
self._logger.log(level, " ".join(map(_maybe_call, msg)))

    def error(self, *msg: str | Callable[[], str]) -> None:
        """Log *msg* at ERROR level by delegating to ``__call__`` (which joins
        the parts and skips callables unless the level is enabled)."""
        return self(*msg, level=logging.ERROR)

    def warning(self, *msg: str | Callable[[], str]) -> None:
        """Log *msg* at WARNING level by delegating to ``__call__`` (which joins
        the parts and skips callables unless the level is enabled)."""
        return self(*msg, level=logging.WARNING)

Expand Down
17 changes: 12 additions & 5 deletions helion/runtime/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,14 +645,19 @@ def __call__(self, *args: object) -> _R:
self.format_kernel_decorator(self._config, self.settings)
] = 1

if self.settings.print_repro:
self._print_repro(args)
self.maybe_log_repro(log.warning, args)

return self._run(*args)

def _print_repro(
self, args: tuple[object, ...], config: Config | None = None
def maybe_log_repro(
self,
log_func: Callable[[str], None],
args: Sequence[object],
config: Config | None = None,
) -> None:
if not self.settings.print_repro:
return

effective_config = config or self._config
assert effective_config is not None

Expand Down Expand Up @@ -723,9 +728,11 @@ def _render_input_arg_assignment(name: str, value: object) -> list[str]:
# Add return statement
call_args = ", ".join(arg_names)
output_lines.append(f" return {self.kernel.name}({call_args})")
output_lines.extend(["", "helion_repro_caller()"])

output_lines.append("# === END HELION KERNEL REPRO ===")
print("\n".join(output_lines), file=sys.stderr)
repro_text = "\n".join(output_lines)
log_func(repro_text)


class _KernelDecorator(Protocol):
Expand Down
1 change: 1 addition & 0 deletions test/test_autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def _make_search(
search.kernel = SimpleNamespace(
format_kernel_decorator=lambda config, s: "decorator",
to_triton_code=lambda config: "code",
maybe_log_repro=lambda log_func, args, config=None: None,
)
search.args = args
search.counters = collections.Counter()
Expand Down
16 changes: 9 additions & 7 deletions test/test_debug_utils.expected
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,18 @@ import helion.language as hl
import torch
from torch._dynamo.testing import rand_strided

@helion.kernel(config=helion.Config(block_sizes=[2, 2], flatten_loops=[False], indexing=['pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=[''], loop_orders=[[0, 1]], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=True)
def kernel1(x: torch.Tensor) -> torch.Tensor:
@helion.kernel(config=helion.Config(block_sizes=[32], indexing=['pointer', 'pointer'], load_eviction_policies=[''], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=True)
def kernel(x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
m, n = x.shape
for tile_m, tile_n in hl.tile([m, n]):
out[tile_m, tile_n] = x[tile_m, tile_n] + 1
n = x.shape[0]
for tile_n in hl.tile([n]):
out[tile_n] = x[tile_n] + 1
return out

def helion_repro_caller():
torch.manual_seed(0)
x = rand_strided((2, 2), (2, 1), dtype=torch.float32, device=DEVICE)
return kernel1(x)
x = rand_strided((128,), (1,), dtype=torch.float32, device=DEVICE)
return kernel(x)

helion_repro_caller()
# === END HELION KERNEL REPRO ===
172 changes: 111 additions & 61 deletions test/test_debug_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import contextlib
import linecache
import os
import unittest
from unittest import mock

import pytest
import torch
Expand All @@ -24,65 +26,80 @@ def _store_capfd_on_class(request, capfd):
request.cls._capfd = capfd


@pytest.fixture(autouse=True)
def _store_caplog_on_class(request, caplog):
    """
    Expose pytest's caplog fixture as `self._caplog` inside the TestDebugUtils class
    (works for unittest.TestCase-style tests).
    """
    test_cls = request.cls
    if test_cls is None:
        # Not running inside a test class; nothing to attach to.
        return
    test_cls._caplog = caplog


class TestDebugUtils(RefEagerTestDisabled, TestCase):
def test_print_repro_env_var(self):
"""Ensure HELION_PRINT_REPRO=1 emits an executable repro script."""
@contextlib.contextmanager
def _with_print_repro_enabled(self):
"""Context manager to temporarily set HELION_PRINT_REPRO=1."""
original = os.environ.get("HELION_PRINT_REPRO")
os.environ["HELION_PRINT_REPRO"] = "1"
try:
yield
finally:
if original is None:
os.environ.pop("HELION_PRINT_REPRO", None)
else:
os.environ["HELION_PRINT_REPRO"] = original

def _clear_captures(self):
"""Clear pytest capture fixtures if available."""
if hasattr(self, "_capfd"):
self._capfd.readouterr()
if hasattr(self, "_caplog"):
self._caplog.clear()

def _create_kernel(self, **kwargs):
"""Create a simple 1D kernel for testing.

Args:
**kwargs: Arguments to pass to @helion.kernel decorator.
"""

@helion.kernel(
config=helion.Config(
block_sizes=[2, 2],
flatten_loops=[False],
indexing=["pointer", "pointer"],
l2_groupings=[1],
load_eviction_policies=[""],
loop_orders=[[0, 1]],
num_stages=1,
num_warps=4,
pid_type="flat",
range_flattens=[None],
range_multi_buffers=[None],
range_num_stages=[0],
range_unroll_factors=[0],
),
@helion.kernel(**kwargs)
def kernel(x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
n = x.shape[0]
for tile_n in hl.tile([n]):
out[tile_n] = x[tile_n] + 1
return out

return kernel

def test_print_repro_env_var(self):
"""Ensure HELION_PRINT_REPRO=1 emits an executable repro script."""
with self._with_print_repro_enabled():
kernel = self._create_kernel(
config=helion.Config(block_sizes=[32], num_warps=4),
static_shapes=True,
)
def kernel1(x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
m, n = x.shape
for tile_m, tile_n in hl.tile([m, n]):
out[tile_m, tile_n] = x[tile_m, tile_n] + 1
return out

torch.manual_seed(0)
x = torch.randn([2, 2], dtype=torch.float32, device=DEVICE)
x = torch.randn([128], dtype=torch.float32, device=DEVICE)

if hasattr(self, "_capfd"):
self._capfd.readouterr()
self._clear_captures()

result = kernel1(x)
result = kernel(x)
torch.testing.assert_close(result, x + 1)

if not hasattr(self, "_capfd"):
return # Cannot test without capture

captured = "".join(self._capfd.readouterr())
# Extract repro script from logs (use records to get the raw message without formatting)
assert hasattr(self, "_caplog"), "caplog fixture not available"
repro_script = None
for record in self._caplog.records:
if "# === HELION KERNEL REPRO ===" in record.message:
repro_script = record.message
break

# Extract repro script
lines = captured.splitlines()
start = next(
i
for i, line in enumerate(lines)
if "# === HELION KERNEL REPRO ===" in line
)
end = next(
i
for i, line in enumerate(lines[start:], start)
if "# === END HELION KERNEL REPRO ===" in line
)
repro_script = "\n".join(lines[start : end + 1])
if repro_script is None:
self.fail("No repro script found in logs")

# Normalize range_warp_specializes=[None] to [] for comparison
normalized_script = repro_script.replace(
Expand All @@ -92,26 +109,18 @@ def kernel1(x: torch.Tensor) -> torch.Tensor:
# Verify repro script matches expected script
self.assertExpectedJournal(normalized_script)

# Extract the actual code (without the comment markers) for execution
repro_lines = repro_script.splitlines()
code_start = 1 if repro_lines[0].startswith("# === HELION") else 0
code_end = len(repro_lines) - (
1 if repro_lines[-1].startswith("# === END") else 0
)
repro_code = "\n".join(repro_lines[code_start:code_end])

# Setup linecache so inspect.getsource() works on exec'd code
filename = "<helion_repro_test>"
linecache.cache[filename] = (
len(repro_code),
len(repro_script),
None,
[f"{line}\n" for line in repro_code.splitlines()],
[f"{line}\n" for line in repro_script.splitlines()],
filename,
)

# Execute the repro script
namespace = {}
exec(compile(repro_code, filename, "exec"), namespace)
exec(compile(repro_script, filename, "exec"), namespace)

# Call the generated helper and verify it runs successfully
helper = namespace["helion_repro_caller"]
Expand All @@ -121,11 +130,52 @@ def kernel1(x: torch.Tensor) -> torch.Tensor:
torch.testing.assert_close(repro_result, x + 1)

linecache.cache.pop(filename, None)
finally:
if original is None:
os.environ.pop("HELION_PRINT_REPRO", None)
else:
os.environ["HELION_PRINT_REPRO"] = original

    def test_print_repro_on_autotune_error(self):
        """Ensure HELION_PRINT_REPRO=1 prints repro when configs fail during autotuning.

        This test mocks do_bench to fail on the second config, guaranteeing the repro
        printing code path is exercised for "warn" level errors.
        """
        with self._with_print_repro_enabled():
            # Two candidate configs so autotuning benchmarks more than one;
            # precompile is disabled so do_bench is hit directly.
            kernel = self._create_kernel(
                configs=[
                    helion.Config(block_sizes=[32], num_warps=4),
                    helion.Config(block_sizes=[64], num_warps=8),
                ],
                autotune_precompile=False,
            )

            torch.manual_seed(0)
            x = torch.randn([128], dtype=torch.float32, device=DEVICE)

            self._clear_captures()

            # Mock do_bench to fail on the second config with PTXASError (warn level)
            from torch._inductor.runtime.triton_compat import PTXASError
            from triton.testing import do_bench as original_do_bench

            # Mutable counter so the closure can track how many configs ran.
            call_count = [0]

            def mock_do_bench(*args, **kwargs):
                call_count[0] += 1
                if call_count[0] == 2:  # Fail on second config
                    raise PTXASError("Mocked PTXAS error")
                return original_do_bench(*args, **kwargs)

            with mock.patch("helion.autotuner.base_search.do_bench", mock_do_bench):
                # Autotune will try both configs, second one will fail and print repro
                kernel.autotune([x], force=False)

            # Extract repro script from stderr
            assert hasattr(self, "_capfd"), "capfd fixture not available"
            captured = "".join(self._capfd.readouterr())

            # Verify that a repro script was printed for the failing config
            self.assertIn("# === HELION KERNEL REPRO ===", captured)
            self.assertIn("# === END HELION KERNEL REPRO ===", captured)
            self.assertIn("kernel", captured)
            self.assertIn("helion_repro_caller()", captured)


if __name__ == "__main__":
Expand Down
Loading