diff --git a/helion/runtime/kernel.py b/helion/runtime/kernel.py index 3ae7db801..f2a853096 100644 --- a/helion/runtime/kernel.py +++ b/helion/runtime/kernel.py @@ -737,7 +737,7 @@ def _render_input_arg_assignment(name: str, value: object) -> list[str]: output_lines.extend(["", "helion_repro_caller()"]) output_lines.append("# === END HELION KERNEL REPRO ===") - repro_text = "\n".join(output_lines) + repro_text = "\n" + "\n".join(output_lines) log_func(repro_text) diff --git a/test/test_debug_utils.expected b/test/test_debug_utils.expected index 52a62565e..a3f55ec94 100644 --- a/test/test_debug_utils.expected +++ b/test/test_debug_utils.expected @@ -23,3 +23,75 @@ def helion_repro_caller(): helion_repro_caller() # === END HELION KERNEL REPRO === + +--- assertExpectedJournal(TestDebugUtils.test_print_repro_on_autotune_error) +# === HELION KERNEL REPRO === +import helion +import helion.language as hl +import torch +from torch._dynamo.testing import rand_strided + +@helion.kernel(config=helion.Config(block_sizes=[64], indexing=['pointer', 'pointer'], load_eviction_policies=[''], num_stages=1, num_warps=8, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=True) +def kernel(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.shape[0] + for tile_n in hl.tile([n]): + out[tile_n] = x[tile_n] + 1 + return out + +def helion_repro_caller(): + torch.manual_seed(0) + x = rand_strided((128,), (1,), dtype=torch.float32, device=DEVICE) + return kernel(x) + +helion_repro_caller() +# === END HELION KERNEL REPRO === + +--- assertExpectedJournal(TestDebugUtils.test_print_repro_on_device_ir_lowering_error) +# === HELION KERNEL REPRO === +import helion +import helion.language as hl +import torch +from torch._dynamo.testing import rand_strided + +@helion.kernel(config=helion.Config(block_sizes=[32], indexing=[], load_eviction_policies=[], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=True) +def kernel_with_compile_error(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.shape[0] + for tile_n in hl.tile([n]): + # Using torch.nonzero inside device loop causes compilation error + # because it produces data-dependent output shape + torch.nonzero(x[tile_n]) + out[tile_n] = x[tile_n] + return out + +def helion_repro_caller(): + torch.manual_seed(0) + x = rand_strided((128,), (1,), dtype=torch.float32, device=DEVICE) + return kernel_with_compile_error(x) + +helion_repro_caller() +# === END HELION KERNEL REPRO === + +--- assertExpectedJournal(TestDebugUtils.test_print_repro_on_triton_codegen_error) +# === HELION KERNEL REPRO === +import helion +import helion.language as hl +import torch +from torch._dynamo.testing import rand_strided + +@helion.kernel(config=helion.Config(block_sizes=[32], indexing=['pointer', 'pointer'], load_eviction_policies=[''], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=True) +def kernel_with_triton_error(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.shape[0] + for tile_n in hl.tile([n]): + out[tile_n] = x[tile_n] + 1 + return out + +def helion_repro_caller(): + torch.manual_seed(0) + x = rand_strided((128,), (1,), dtype=torch.float32, device=DEVICE) + return kernel_with_triton_error(x) + +helion_repro_caller() +# === END HELION KERNEL REPRO === diff --git a/test/test_debug_utils.py b/test/test_debug_utils.py index 087d2ffd2..cb0ca65ec 100644 --- a/test/test_debug_utils.py +++ b/test/test_debug_utils.py @@ -68,6 +68,41 @@ def kernel(x: torch.Tensor) -> torch.Tensor: return kernel + def _extract_repro_script(self, text: str) -> str: + """Extract the repro code block between markers (including markers). + + Args: + text: The text containing the repro block. Can be a full string or log_capture object. + + Returns: + The extracted repro block including both markers. + """ + # If it's a log capture object, extract the repro script from logs first + if hasattr(text, "records"): + log_capture = text + repro_script = None + for record in log_capture.records: + if "# === HELION KERNEL REPRO ===" in record.message: + repro_script = record.message + break + if repro_script is None: + self.fail("No repro script found in logs") + text = repro_script + + # Extract code block between markers + start_marker = "# === HELION KERNEL REPRO ===" + end_marker = "# === END HELION KERNEL REPRO ===" + start_idx = text.find(start_marker) + end_idx = text.find(end_marker) + + if start_idx == -1: + self.fail("Start marker not found") + if end_idx == -1: + self.fail("End marker not found") + + # Extract content including both markers + return text[start_idx : end_idx + len(end_marker)].strip() + def test_print_repro_env_var(self): """Ensure HELION_PRINT_REPRO=1 emits an executable repro script.""" with self._with_print_repro_enabled(): @@ -83,15 +118,8 @@ def test_print_repro_env_var(self): result = kernel(x) torch.testing.assert_close(result, x + 1) - # Extract repro script from logs (use records to get the raw message without formatting) - repro_script = None - for record in log_capture.records: - if "# === HELION KERNEL REPRO ===" in record.message: - repro_script = record.message - break - - if repro_script is None: - self.fail("No repro script found in logs") + # Extract repro script from logs + repro_script = self._extract_repro_script(log_capture) # Normalize range_warp_specializes=[None] to [] for comparison normalized_script = repro_script.replace( @@ -163,10 +191,14 @@ def mock_do_bench(*args, **kwargs): captured = "".join(output_capture.readouterr()) # Verify that a repro script was printed for the failing config - self.assertIn("# === HELION KERNEL REPRO ===", captured) - self.assertIn("# === END HELION KERNEL REPRO ===", captured) - self.assertIn("kernel", captured) - self.assertIn("helion_repro_caller()", captured) + repro_script = self._extract_repro_script(captured) + + # Normalize range_warp_specializes=[None] to [] for comparison + normalized_script = repro_script.replace( + "range_warp_specializes=[None]", "range_warp_specializes=[]" + ) + + self.assertExpectedJournal(normalized_script) def test_print_repro_on_device_ir_lowering_error(self): """Ensure HELION_PRINT_REPRO=1 prints repro when compilation fails during device IR lowering.""" @@ -192,21 +224,14 @@ def kernel_with_compile_error(x: torch.Tensor) -> torch.Tensor: kernel_with_compile_error(x) # Extract repro script from logs - repro_script = None - for record in log_capture.records: - if "# === HELION KERNEL REPRO ===" in record.message: - repro_script = record.message - break - - # Verify that a repro script was printed when compilation failed - self.assertIsNotNone( - repro_script, - "Expected repro script to be printed when device IR lowering fails", + repro_script = self._extract_repro_script(log_capture) + + # Normalize range_warp_specializes=[None] to [] for comparison + normalized_script = repro_script.replace( + "range_warp_specializes=[None]", "range_warp_specializes=[]" ) - self.assertIn("# === HELION KERNEL REPRO ===", repro_script) - self.assertIn("# === END HELION KERNEL REPRO ===", repro_script) - self.assertIn("kernel_with_compile_error", repro_script) - self.assertIn("helion_repro_caller()", repro_script) + + self.assertExpectedJournal(normalized_script) def test_print_repro_on_triton_codegen_error(self): """Ensure HELION_PRINT_REPRO=1 prints repro when Triton codegen fails.""" @@ -242,21 +267,14 @@ def mock_load(code, *args, **kwargs): kernel_with_triton_error(x) # Extract repro script from logs - repro_script = None - for record in log_capture.records: - if "# === HELION KERNEL REPRO ===" in record.message: - repro_script = record.message - break - - # Verify that a repro script was printed when Triton codegen failed - self.assertIsNotNone( - repro_script, - "Expected repro script to be printed when Triton codegen fails", + repro_script = self._extract_repro_script(log_capture) + + # Normalize range_warp_specializes=[None] to [] for comparison + normalized_script = repro_script.replace( + "range_warp_specializes=[None]", "range_warp_specializes=[]" ) - self.assertIn("# === HELION KERNEL REPRO ===", repro_script) - self.assertIn("# === END HELION KERNEL REPRO ===", repro_script) - self.assertIn("kernel_with_triton_error", repro_script) - self.assertIn("helion_repro_caller()", repro_script) + + self.assertExpectedJournal(normalized_script) if __name__ == "__main__":