From d52984e59d1911b1b3bd74d0b82ed85fa4aee535 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Wed, 1 Oct 2025 20:40:17 -0700 Subject: [PATCH] Print Helion kernel source line in symbolic shape debugging With `TORCH_LOGS="+dynamic"`, now it prints: ``` I1001 20:49:16.995000 1208064 torch/fx/experimental/symbolic_shapes.py:4800] create_unbacked_symint u0 [-int_oo, int_oo] I1001 20:49:16.995000 1208064 torch/fx/experimental/symbolic_shapes.py:4800] Helion kernel stack: I1001 20:49:16.995000 1208064 torch/fx/experimental/symbolic_shapes.py:4800] File "/home/willfeng/local/helion2/minimal_repro.py", line 10, in logging_reduce_rows I1001 20:49:16.995000 1208064 torch/fx/experimental/symbolic_shapes.py:4800] m_block = hl.register_block_size(m) I1001 20:49:16.995000 1208064 torch/fx/experimental/symbolic_shapes.py:4800] ^^^^^^^^^^^^^^^^^^^^^^^^^ ``` which can be further traced with `TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"` stack-info: PR: https://github.com/pytorch/helion/pull/771, branch: yf225/stack/61 --- helion/_compiler/compile_environment.py | 30 ++++++- test/test_logging.py | 109 ++++++++++++++++++++++++ 2 files changed, 137 insertions(+), 2 deletions(-) diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py index bf2a45b3b..e23dfb12d 100644 --- a/helion/_compiler/compile_environment.py +++ b/helion/_compiler/compile_environment.py @@ -12,6 +12,7 @@ import sympy import torch +from torch._dynamo.source import EphemeralSource from torch._dynamo.source import LocalSource from torch._inductor.runtime.runtime_utils import next_power_of_2 from torch._inductor.utils import triton_type @@ -21,6 +22,8 @@ from .. import exc from ..language.constexpr import ConstExpr from .loop_dependency_checker import LoopDependencyChecker +from .source_location import SourceLocation +from .source_location import current_location from .variable_origin import BlockSizeOrigin from .variable_origin import Origin @@ -41,6 +44,27 @@ class _TLS(Protocol): tls: _TLS = typing.cast("_TLS", threading.local()) +class HelionKernelSource(EphemeralSource): + """Ephemeral source that formats as a kernel file location.""" + + def __init__(self, location: SourceLocation) -> None: + super().__init__(desc=None) + self.location = location + + def name(self) -> str: # type: ignore[override] + formatted = self.location.format().rstrip("\n") + if not formatted: + return "" + return "\nHelion kernel stack:\n" + formatted + + +def _current_symbol_source() -> EphemeralSource | None: + location = current_location() + if not location: + return None + return HelionKernelSource(location) + + class CompileEnvironment: """ Global state for the duration of a compilation. @@ -154,8 +178,9 @@ def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInf return self.block_sizes[rdim_idx] def create_block_var(self, debug_name: str, hint: int = 64) -> torch.SymInt: + source = _current_symbol_source() with self.shape_env.ignore_fresh_unbacked_symbols(): - sym = self.shape_env.create_unbacked_symint() + sym = self.shape_env.create_unbacked_symint(source=source) # self.shape_env.guards.append( # ShapeGuard( # sympy.Ne(sym._sympy_(), 0), @@ -172,8 +197,9 @@ def create_block_var(self, debug_name: str, hint: int = 64) -> torch.SymInt: return sym def create_unbacked_symint(self, hint: int = 8192) -> torch.SymInt: + source = _current_symbol_source() with self.shape_env.ignore_fresh_unbacked_symbols(): - sym = self.shape_env.create_unbacked_symint() + sym = self.shape_env.create_unbacked_symint(source=source) # TODO(jansel): this is a hack to get us past some == 1 checks # we should probably have a better way to handle this self.shape_env.var_to_val[sym._sympy_()] = sympy.sympify(hint) diff --git a/test/test_logging.py b/test/test_logging.py index bedbdc681..f41d9bb06 100644 --- a/test/test_logging.py +++ b/test/test_logging.py @@ -1,8 +1,13 @@ from __future__ import annotations +import inspect +import logging +import os +from typing import Callable import unittest import torch +import torch.fx.experimental._config as fx_config import helion from helion._testing import DEVICE @@ -11,6 +16,48 @@ import helion.language as hl +@helion.kernel(use_default_config=True) +def logging_reduce_rows(x: torch.Tensor) -> torch.Tensor: + m, n = x.shape + n = hl.specialize(n) + + m_block = hl.register_block_size(m) + + result = torch.zeros(n, dtype=torch.float32, device=x.device) + + for outer in hl.tile(m, block_size=m_block): + for inner in hl.tile(outer.begin, outer.end): + zero_idx = inner.begin - inner.begin + result[zero_idx] = result[zero_idx] + return result + + +def _run_symbol_logging_example() -> None: + x = torch.randn((128, 5632), device=DEVICE, dtype=torch.float16) + logging_reduce_rows(x) + + +def _run_with_symbol_logs(fn: Callable[[], None]) -> str: + logger = logging.getLogger("torch.fx.experimental.symbolic_shapes") + records: list[str] = [] + + class _Capture(logging.Handler): + def emit(self, record: logging.LogRecord) -> None: # type: ignore[override] + records.append(record.getMessage()) + + handler = _Capture() + previous_level = logger.level + logger.addHandler(handler) + logger.setLevel(logging.INFO) + try: + fn() + finally: + logger.removeHandler(handler) + logger.setLevel(previous_level) + + return "\n".join(records) + + class TestLogging(RefEagerTestDisabled, TestCase): def test_log_set(self): import logging @@ -51,6 +98,68 @@ def add(x, y): any("DEBUG:helion.runtime.kernel:Debug string:" in msg for msg in cm.output) ) + def test_symbolic_shape_log_includes_kernel_source(self): + symbol_filter = ",".join(f"u{i}" for i in range(5)) + file_path = os.path.abspath(__file__) + source_lines, start_line = inspect.getsourcelines(logging_reduce_rows.fn) + + def get_line_no(snippet: str) -> int: + return start_line + next( + idx for idx, line in enumerate(source_lines) if snippet in line + ) + + line_for_block = get_line_no("m_block = hl.register_block_size(m)") + line_for_inner = get_line_no("for inner in hl.tile(outer.begin, outer.end):") + line_for_zero = get_line_no("zero_idx = inner.begin - inner.begin") + + with fx_config.patch(extended_debug_create_symbol=symbol_filter): + output = _run_with_symbol_logs(_run_symbol_logging_example) + + lines = output.splitlines() + + self.assertTrue(output, msg="no logs captured for symbolic shapes") + self.assertIn("create_unbacked_symint", output) + self.assertIn("Helion kernel stack:", output) + self.assertTrue( + any( + f' File "{file_path}", line {line_for_block}, in logging_reduce_rows' + in line + for line in lines + ), + msg="register_block_size location missing", + ) + self.assertIn(" m_block = hl.register_block_size(m)", output) + + def assert_symbol(symbol: str, lineno: int, snippet: str) -> None: + marker = f"create_unbacked_symint {symbol}" + for idx, line in enumerate(lines): + if marker in line: + window = lines[idx : idx + 6] + break + else: + self.fail(f"missing log for {symbol}") + + expected_file_line = ( + f' File "{file_path}", line {lineno}, in logging_reduce_rows' + ) + self.assertTrue( + any(expected_file_line in line for line in window), + msg=f"missing file info for {symbol}", + ) + self.assertTrue( + any(snippet in line for line in window), + msg=f"missing source line for {symbol}", + ) + + for sym, line_no, snippet in [ + ("u0", line_for_block, "m_block = hl.register_block_size(m)"), + ("u1", line_for_inner, "for inner in hl.tile(outer.begin, outer.end):"), + ("u2", line_for_inner, "for inner in hl.tile(outer.begin, outer.end):"), + ("u3", line_for_inner, "for inner in hl.tile(outer.begin, outer.end):"), + ("u4", line_for_zero, "zero_idx = inner.begin - inner.begin"), + ]: + assert_symbol(sym, line_no, snippet) + if __name__ == "__main__": unittest.main()