2 changes: 1 addition & 1 deletion helion/runtime/kernel.py
@@ -737,7 +737,7 @@ def _render_input_arg_assignment(name: str, value: object) -> list[str]:
output_lines.extend(["", "helion_repro_caller()"])

output_lines.append("# === END HELION KERNEL REPRO ===")
repro_text = "\n".join(output_lines)
repro_text = "\n" + "\n".join(output_lines)
log_func(repro_text)


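The kernel.py change prepends a newline before joining the repro lines, presumably so the repro banner starts on a fresh line instead of trailing the logger's prefix. A minimal sketch of that effect, using plain logging rather than Helion's actual log_func:

import logging

logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s: %(message)s")
log = logging.getLogger("helion.repro")

output_lines = [
    "# === HELION KERNEL REPRO ===",
    "print('repro body')",
    "# === END HELION KERNEL REPRO ===",
]

# Old behavior: the first marker shares a line with the log prefix.
log.info("\n".join(output_lines))

# New behavior: the leading newline pushes the whole block onto its own lines,
# so it can be copied out of the log (or matched by the tests) verbatim.
log.info("\n" + "\n".join(output_lines))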
72 changes: 72 additions & 0 deletions test/test_debug_utils.expected
@@ -23,3 +23,75 @@ def helion_repro_caller():

helion_repro_caller()
# === END HELION KERNEL REPRO ===

--- assertExpectedJournal(TestDebugUtils.test_print_repro_on_autotune_error)
# === HELION KERNEL REPRO ===
import helion
import helion.language as hl
import torch
from torch._dynamo.testing import rand_strided

@helion.kernel(config=helion.Config(block_sizes=[64], indexing=['pointer', 'pointer'], load_eviction_policies=[''], num_stages=1, num_warps=8, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=True)
def kernel(x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
n = x.shape[0]
for tile_n in hl.tile([n]):
out[tile_n] = x[tile_n] + 1
return out

def helion_repro_caller():
torch.manual_seed(0)
x = rand_strided((128,), (1,), dtype=torch.float32, device=DEVICE)
return kernel(x)

helion_repro_caller()
# === END HELION KERNEL REPRO ===

--- assertExpectedJournal(TestDebugUtils.test_print_repro_on_device_ir_lowering_error)
# === HELION KERNEL REPRO ===
import helion
import helion.language as hl
import torch
from torch._dynamo.testing import rand_strided

@helion.kernel(config=helion.Config(block_sizes=[32], indexing=[], load_eviction_policies=[], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=True)
def kernel_with_compile_error(x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
n = x.shape[0]
for tile_n in hl.tile([n]):
# Using torch.nonzero inside device loop causes compilation error
# because it produces data-dependent output shape
torch.nonzero(x[tile_n])
out[tile_n] = x[tile_n]
return out

def helion_repro_caller():
torch.manual_seed(0)
x = rand_strided((128,), (1,), dtype=torch.float32, device=DEVICE)
return kernel_with_compile_error(x)

helion_repro_caller()
# === END HELION KERNEL REPRO ===

--- assertExpectedJournal(TestDebugUtils.test_print_repro_on_triton_codegen_error)
# === HELION KERNEL REPRO ===
import helion
import helion.language as hl
import torch
from torch._dynamo.testing import rand_strided

@helion.kernel(config=helion.Config(block_sizes=[32], indexing=['pointer', 'pointer'], load_eviction_policies=[''], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=True)
def kernel_with_triton_error(x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
n = x.shape[0]
for tile_n in hl.tile([n]):
out[tile_n] = x[tile_n] + 1
return out

def helion_repro_caller():
torch.manual_seed(0)
x = rand_strided((128,), (1,), dtype=torch.float32, device=DEVICE)
return kernel_with_triton_error(x)

helion_repro_caller()
# === END HELION KERNEL REPRO ===
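The expected repro scripts reference a DEVICE name that is not defined inside the generated script itself; assuming it is meant to be supplied by whoever runs the repro, a standalone run would only need a definition along these lines (hypothetical, not part of the generated output):

import torch

# Hypothetical stand-in for the DEVICE placeholder used by the repro scripts;
# pick whatever device the original failure occurred on.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")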
100 changes: 59 additions & 41 deletions test/test_debug_utils.py
@@ -68,6 +68,41 @@ def kernel(x: torch.Tensor) -> torch.Tensor:

return kernel

def _extract_repro_script(self, text: str) -> str:
"""Extract the repro code block between markers (including markers).

Args:
text: The text containing the repro block. Can be a full string or log_capture object.

Returns:
The extracted repro block including both markers.
"""
# If it's a log capture object, extract the repro script from logs first
if hasattr(text, "records"):
log_capture = text
repro_script = None
for record in log_capture.records:
if "# === HELION KERNEL REPRO ===" in record.message:
repro_script = record.message
break
if repro_script is None:
self.fail("No repro script found in logs")
text = repro_script

# Extract code block between markers
start_marker = "# === HELION KERNEL REPRO ==="
end_marker = "# === END HELION KERNEL REPRO ==="
start_idx = text.find(start_marker)
end_idx = text.find(end_marker)

if start_idx == -1:
self.fail("Start marker not found")
if end_idx == -1:
self.fail("End marker not found")

# Extract content including both markers
return text[start_idx : end_idx + len(end_marker)].strip()

def test_print_repro_env_var(self):
"""Ensure HELION_PRINT_REPRO=1 emits an executable repro script."""
with self._with_print_repro_enabled():
@@ -83,15 +118,8 @@ def test_print_repro_env_var(self):
result = kernel(x)
torch.testing.assert_close(result, x + 1)

# Extract repro script from logs (use records to get the raw message without formatting)
repro_script = None
for record in log_capture.records:
if "# === HELION KERNEL REPRO ===" in record.message:
repro_script = record.message
break

if repro_script is None:
self.fail("No repro script found in logs")
# Extract repro script from logs
repro_script = self._extract_repro_script(log_capture)

# Normalize range_warp_specializes=[None] to [] for comparison
normalized_script = repro_script.replace(
@@ -163,10 +191,14 @@ def mock_do_bench(*args, **kwargs):
captured = "".join(output_capture.readouterr())

# Verify that a repro script was printed for the failing config
self.assertIn("# === HELION KERNEL REPRO ===", captured)
self.assertIn("# === END HELION KERNEL REPRO ===", captured)
self.assertIn("kernel", captured)
self.assertIn("helion_repro_caller()", captured)
repro_script = self._extract_repro_script(captured)

# Normalize range_warp_specializes=[None] to [] for comparison
normalized_script = repro_script.replace(
"range_warp_specializes=[None]", "range_warp_specializes=[]"
)

self.assertExpectedJournal(normalized_script)

def test_print_repro_on_device_ir_lowering_error(self):
"""Ensure HELION_PRINT_REPRO=1 prints repro when compilation fails during device IR lowering."""
@@ -192,21 +224,14 @@ def kernel_with_compile_error(x: torch.Tensor) -> torch.Tensor:
kernel_with_compile_error(x)

# Extract repro script from logs
repro_script = None
for record in log_capture.records:
if "# === HELION KERNEL REPRO ===" in record.message:
repro_script = record.message
break

# Verify that a repro script was printed when compilation failed
self.assertIsNotNone(
repro_script,
"Expected repro script to be printed when device IR lowering fails",
repro_script = self._extract_repro_script(log_capture)

# Normalize range_warp_specializes=[None] to [] for comparison
normalized_script = repro_script.replace(
"range_warp_specializes=[None]", "range_warp_specializes=[]"
)
self.assertIn("# === HELION KERNEL REPRO ===", repro_script)
self.assertIn("# === END HELION KERNEL REPRO ===", repro_script)
self.assertIn("kernel_with_compile_error", repro_script)
self.assertIn("helion_repro_caller()", repro_script)

self.assertExpectedJournal(normalized_script)

def test_print_repro_on_triton_codegen_error(self):
"""Ensure HELION_PRINT_REPRO=1 prints repro when Triton codegen fails."""
@@ -242,21 +267,14 @@ def mock_load(code, *args, **kwargs):
kernel_with_triton_error(x)

# Extract repro script from logs
repro_script = None
for record in log_capture.records:
if "# === HELION KERNEL REPRO ===" in record.message:
repro_script = record.message
break

# Verify that a repro script was printed when Triton codegen failed
self.assertIsNotNone(
repro_script,
"Expected repro script to be printed when Triton codegen fails",
repro_script = self._extract_repro_script(log_capture)

# Normalize range_warp_specializes=[None] to [] for comparison
normalized_script = repro_script.replace(
"range_warp_specializes=[None]", "range_warp_specializes=[]"
)
self.assertIn("# === HELION KERNEL REPRO ===", repro_script)
self.assertIn("# === END HELION KERNEL REPRO ===", repro_script)
self.assertIn("kernel_with_triton_error", repro_script)
self.assertIn("helion_repro_caller()", repro_script)

self.assertExpectedJournal(normalized_script)


if __name__ == "__main__":
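For reference, a self-contained sketch of the marker extraction and normalization pattern the tests share (simplified from _extract_repro_script and the journal comparisons above; not the exact test code):

START = "# === HELION KERNEL REPRO ==="
END = "# === END HELION KERNEL REPRO ==="

def extract_repro(text: str) -> str:
    # Return the block between the markers, keeping both markers,
    # so it can be compared against the expected journal entries.
    start = text.find(START)
    end = text.find(END)
    if start == -1 or end == -1:
        raise ValueError("repro markers not found in captured output")
    return text[start : end + len(END)].strip()

captured = "some log noise\n" + START + "\nimport torch\n" + END + "\nmore noise"
script = extract_repro(captured)
# Normalize range_warp_specializes=[None] to [] before comparing,
# mirroring the normalization done in each test above.
script = script.replace("range_warp_specializes=[None]", "range_warp_specializes=[]")
assert script.startswith(START) and script.endswith(END)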