diff --git a/helion/runtime/kernel.py b/helion/runtime/kernel.py index afcd7108c..3ae7db801 100644 --- a/helion/runtime/kernel.py +++ b/helion/runtime/kernel.py @@ -356,9 +356,14 @@ def __init__( _maybe_skip_dtype_check_in_meta_registrations(), patch_inductor_lowerings(), ): - self.host_function: HostFunction = HostFunction( - self.kernel.fn, self.fake_args, constexpr_args - ) + try: + self.host_function: HostFunction = HostFunction( + self.kernel.fn, self.fake_args, constexpr_args + ) + except Exception: + config = self.env.config_spec.default_config() + self.maybe_log_repro(log.warning, args, config=config) + raise @property def settings(self) -> Settings: @@ -456,6 +461,7 @@ def compile_config( self.format_kernel_decorator(config, self.settings), exc_info=True, ) + self.maybe_log_repro(log.warning, self.fake_args, config=config) raise if allow_print: log.info("Output code: \n%s", triton_code) diff --git a/test/test_debug_utils.py b/test/test_debug_utils.py index b6e243f1d..087d2ffd2 100644 --- a/test/test_debug_utils.py +++ b/test/test_debug_utils.py @@ -168,6 +168,96 @@ def mock_do_bench(*args, **kwargs): self.assertIn("kernel", captured) self.assertIn("helion_repro_caller()", captured) + def test_print_repro_on_device_ir_lowering_error(self): + """Ensure HELION_PRINT_REPRO=1 prints repro when compilation fails during device IR lowering.""" + with self._with_print_repro_enabled(): + + @helion.kernel(config=helion.Config(block_sizes=[32], num_warps=4)) + def kernel_with_compile_error(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.shape[0] + for tile_n in hl.tile([n]): + # Using torch.nonzero inside device loop causes compilation error + # because it produces data-dependent output shape + torch.nonzero(x[tile_n]) + out[tile_n] = x[tile_n] + return out + + torch.manual_seed(0) + x = torch.randn([128], dtype=torch.float32, device=DEVICE) + + with self.capture_logs() as log_capture: + # This should trigger a compilation error during device IR lowering + with self.assertRaises(RuntimeError): + kernel_with_compile_error(x) + + # Extract repro script from logs + repro_script = None + for record in log_capture.records: + if "# === HELION KERNEL REPRO ===" in record.message: + repro_script = record.message + break + + # Verify that a repro script was printed when compilation failed + self.assertIsNotNone( + repro_script, + "Expected repro script to be printed when device IR lowering fails", + ) + self.assertIn("# === HELION KERNEL REPRO ===", repro_script) + self.assertIn("# === END HELION KERNEL REPRO ===", repro_script) + self.assertIn("kernel_with_compile_error", repro_script) + self.assertIn("helion_repro_caller()", repro_script) + + def test_print_repro_on_triton_codegen_error(self): + """Ensure HELION_PRINT_REPRO=1 prints repro when Triton codegen fails.""" + with self._with_print_repro_enabled(): + + @helion.kernel(config=helion.Config(block_sizes=[32], num_warps=4)) + def kernel_with_triton_error(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.shape[0] + for tile_n in hl.tile([n]): + out[tile_n] = x[tile_n] + 1 + return out + + torch.manual_seed(0) + x = torch.randn([128], dtype=torch.float32, device=DEVICE) + + # Mock PyCodeCache.load to simulate a Triton codegen error + from torch._inductor.codecache import PyCodeCache + + original_load = PyCodeCache.load + + def mock_load(code, *args, **kwargs): + if "kernel_with_triton_error" in code: + raise RuntimeError("Simulated Triton codegen error") + return original_load(code, *args, **kwargs) + + with ( + self.capture_logs() as log_capture, + mock.patch.object(PyCodeCache, "load", mock_load), + ): + # This should trigger a Triton codegen error + with self.assertRaises(RuntimeError): + kernel_with_triton_error(x) + + # Extract repro script from logs + repro_script = None + for record in log_capture.records: + if "# === HELION KERNEL REPRO ===" in record.message: + repro_script = record.message + break + + # Verify that a repro script was printed when Triton codegen failed + self.assertIsNotNone( + repro_script, + "Expected repro script to be printed when Triton codegen fails", + ) + self.assertIn("# === HELION KERNEL REPRO ===", repro_script) + self.assertIn("# === END HELION KERNEL REPRO ===", repro_script) + self.assertIn("kernel_with_triton_error", repro_script) + self.assertIn("helion_repro_caller()", repro_script) + if __name__ == "__main__": unittest.main()