Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions helion/runtime/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,9 +356,14 @@ def __init__(
_maybe_skip_dtype_check_in_meta_registrations(),
patch_inductor_lowerings(),
):
self.host_function: HostFunction = HostFunction(
self.kernel.fn, self.fake_args, constexpr_args
)
try:
self.host_function: HostFunction = HostFunction(
self.kernel.fn, self.fake_args, constexpr_args
)
except Exception:
config = self.env.config_spec.default_config()
self.maybe_log_repro(log.warning, args, config=config)
raise

@property
def settings(self) -> Settings:
Expand Down Expand Up @@ -456,6 +461,7 @@ def compile_config(
self.format_kernel_decorator(config, self.settings),
exc_info=True,
)
self.maybe_log_repro(log.warning, self.fake_args, config=config)
raise
if allow_print:
log.info("Output code: \n%s", triton_code)
Expand Down
90 changes: 90 additions & 0 deletions test/test_debug_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,96 @@ def mock_do_bench(*args, **kwargs):
self.assertIn("kernel", captured)
self.assertIn("helion_repro_caller()", captured)

def test_print_repro_on_device_ir_lowering_error(self):
    """Check that HELION_PRINT_REPRO=1 logs a repro script when device IR lowering raises."""
    begin_marker = "# === HELION KERNEL REPRO ==="
    end_marker = "# === END HELION KERNEL REPRO ==="

    with self._with_print_repro_enabled():

        @helion.kernel(config=helion.Config(block_sizes=[32], num_warps=4))
        def kernel_with_compile_error(x: torch.Tensor) -> torch.Tensor:
            out = torch.empty_like(x)
            n = x.shape[0]
            for tile_n in hl.tile([n]):
                # torch.nonzero has a data-dependent output shape, so using it
                # inside the device loop makes compilation fail.
                torch.nonzero(x[tile_n])
                out[tile_n] = x[tile_n]
            return out

        torch.manual_seed(0)
        sample = torch.randn([128], dtype=torch.float32, device=DEVICE)

        with self.capture_logs() as log_capture:
            # Calling the kernel is expected to fail while lowering the device IR.
            with self.assertRaises(RuntimeError):
                kernel_with_compile_error(sample)

            # First captured log message that carries the repro header, if any.
            repro_script = next(
                (
                    entry.message
                    for entry in log_capture.records
                    if begin_marker in entry.message
                ),
                None,
            )

            # A repro script must have been emitted on the failure path.
            self.assertIsNotNone(
                repro_script,
                "Expected repro script to be printed when device IR lowering fails",
            )
            for expected_fragment in (
                begin_marker,
                end_marker,
                "kernel_with_compile_error",
                "helion_repro_caller()",
            ):
                self.assertIn(expected_fragment, repro_script)

def test_print_repro_on_triton_codegen_error(self):
    """Check that HELION_PRINT_REPRO=1 logs a repro script when loading the generated code raises."""
    begin_marker = "# === HELION KERNEL REPRO ==="
    end_marker = "# === END HELION KERNEL REPRO ==="

    with self._with_print_repro_enabled():

        @helion.kernel(config=helion.Config(block_sizes=[32], num_warps=4))
        def kernel_with_triton_error(x: torch.Tensor) -> torch.Tensor:
            out = torch.empty_like(x)
            n = x.shape[0]
            for tile_n in hl.tile([n]):
                out[tile_n] = x[tile_n] + 1
            return out

        torch.manual_seed(0)
        sample = torch.randn([128], dtype=torch.float32, device=DEVICE)

        # PyCodeCache.load is patched below to simulate a Triton codegen error.
        from torch._inductor.codecache import PyCodeCache

        original_load = PyCodeCache.load

        def mock_load(code, *args, **kwargs):
            # Only code mentioning this test's kernel fails; everything else
            # is forwarded to the real loader.
            if "kernel_with_triton_error" not in code:
                return original_load(code, *args, **kwargs)
            raise RuntimeError("Simulated Triton codegen error")

        with self.capture_logs() as log_capture, mock.patch.object(
            PyCodeCache, "load", mock_load
        ):
            # The patched loader should make the kernel call fail.
            with self.assertRaises(RuntimeError):
                kernel_with_triton_error(sample)

            # First captured log message that carries the repro header, if any.
            repro_script = next(
                (
                    entry.message
                    for entry in log_capture.records
                    if begin_marker in entry.message
                ),
                None,
            )

            # A repro script must have been emitted on the failure path.
            self.assertIsNotNone(
                repro_script,
                "Expected repro script to be printed when Triton codegen fails",
            )
            for expected_fragment in (
                begin_marker,
                end_marker,
                "kernel_with_triton_error",
                "helion_repro_caller()",
            ):
                self.assertIn(expected_fragment, repro_script)


# When this file is executed directly as a script, hand control to unittest.main().
if __name__ == "__main__":
unittest.main()
Loading