Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions helion/_compiler/compile_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import sympy
import torch
from torch._dynamo.source import EphemeralSource
from torch._dynamo.source import LocalSource
from torch._inductor.runtime.runtime_utils import next_power_of_2
from torch._inductor.utils import triton_type
Expand All @@ -21,6 +22,8 @@
from .. import exc
from ..language.constexpr import ConstExpr
from .loop_dependency_checker import LoopDependencyChecker
from .source_location import SourceLocation
from .source_location import current_location
from .variable_origin import BlockSizeOrigin
from .variable_origin import Origin

Expand All @@ -41,6 +44,27 @@ class _TLS(Protocol):
tls: _TLS = typing.cast("_TLS", threading.local())


class HelionKernelSource(EphemeralSource):
"""Ephemeral source that formats as a kernel file location."""

def __init__(self, location: SourceLocation) -> None:
super().__init__(desc=None)
self.location = location

def name(self) -> str: # type: ignore[override]
formatted = self.location.format().rstrip("\n")
if not formatted:
return ""
return "\nHelion kernel stack:\n" + formatted


def _current_symbol_source() -> EphemeralSource | None:
location = current_location()
if not location:
return None
return HelionKernelSource(location)


class CompileEnvironment:
"""
Global state for the duration of a compilation.
Expand Down Expand Up @@ -154,8 +178,9 @@ def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInf
return self.block_sizes[rdim_idx]

def create_block_var(self, debug_name: str, hint: int = 64) -> torch.SymInt:
source = _current_symbol_source()
with self.shape_env.ignore_fresh_unbacked_symbols():
sym = self.shape_env.create_unbacked_symint()
sym = self.shape_env.create_unbacked_symint(source=source)
# self.shape_env.guards.append(
# ShapeGuard(
# sympy.Ne(sym._sympy_(), 0),
Expand All @@ -172,8 +197,9 @@ def create_block_var(self, debug_name: str, hint: int = 64) -> torch.SymInt:
return sym

def create_unbacked_symint(self, hint: int = 8192) -> torch.SymInt:
source = _current_symbol_source()
with self.shape_env.ignore_fresh_unbacked_symbols():
sym = self.shape_env.create_unbacked_symint()
sym = self.shape_env.create_unbacked_symint(source=source)
# TODO(jansel): this is a hack to get us past some == 1 checks
# we should probably have a better way to handle this
self.shape_env.var_to_val[sym._sympy_()] = sympy.sympify(hint)
Expand Down
109 changes: 109 additions & 0 deletions test/test_logging.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from __future__ import annotations

import inspect
import logging
import os
from typing import Callable
import unittest

import torch
import torch.fx.experimental._config as fx_config

import helion
from helion._testing import DEVICE
Expand All @@ -11,6 +16,48 @@
import helion.language as hl


@helion.kernel(use_default_config=True)
def logging_reduce_rows(x: torch.Tensor) -> torch.Tensor:
m, n = x.shape
n = hl.specialize(n)

m_block = hl.register_block_size(m)

result = torch.zeros(n, dtype=torch.float32, device=x.device)

for outer in hl.tile(m, block_size=m_block):
for inner in hl.tile(outer.begin, outer.end):
zero_idx = inner.begin - inner.begin
result[zero_idx] = result[zero_idx]
return result


def _run_symbol_logging_example() -> None:
x = torch.randn((128, 5632), device=DEVICE, dtype=torch.float16)
logging_reduce_rows(x)


def _run_with_symbol_logs(fn: Callable[[], None]) -> str:
logger = logging.getLogger("torch.fx.experimental.symbolic_shapes")
records: list[str] = []

class _Capture(logging.Handler):
def emit(self, record: logging.LogRecord) -> None: # type: ignore[override]
records.append(record.getMessage())

handler = _Capture()
previous_level = logger.level
logger.addHandler(handler)
logger.setLevel(logging.INFO)
try:
fn()
finally:
logger.removeHandler(handler)
logger.setLevel(previous_level)

return "\n".join(records)


class TestLogging(RefEagerTestDisabled, TestCase):
def test_log_set(self):
import logging
Expand Down Expand Up @@ -51,6 +98,68 @@ def add(x, y):
any("DEBUG:helion.runtime.kernel:Debug string:" in msg for msg in cm.output)
)

def test_symbolic_shape_log_includes_kernel_source(self):
symbol_filter = ",".join(f"u{i}" for i in range(5))
file_path = os.path.abspath(__file__)
source_lines, start_line = inspect.getsourcelines(logging_reduce_rows.fn)

def get_line_no(snippet: str) -> int:
return start_line + next(
idx for idx, line in enumerate(source_lines) if snippet in line
)

line_for_block = get_line_no("m_block = hl.register_block_size(m)")
line_for_inner = get_line_no("for inner in hl.tile(outer.begin, outer.end):")
line_for_zero = get_line_no("zero_idx = inner.begin - inner.begin")

with fx_config.patch(extended_debug_create_symbol=symbol_filter):
output = _run_with_symbol_logs(_run_symbol_logging_example)

lines = output.splitlines()

self.assertTrue(output, msg="no logs captured for symbolic shapes")
self.assertIn("create_unbacked_symint", output)
self.assertIn("Helion kernel stack:", output)
self.assertTrue(
any(
f' File "{file_path}", line {line_for_block}, in logging_reduce_rows'
in line
for line in lines
),
msg="register_block_size location missing",
)
self.assertIn(" m_block = hl.register_block_size(m)", output)

def assert_symbol(symbol: str, lineno: int, snippet: str) -> None:
marker = f"create_unbacked_symint {symbol}"
for idx, line in enumerate(lines):
if marker in line:
window = lines[idx : idx + 6]
break
else:
self.fail(f"missing log for {symbol}")

expected_file_line = (
f' File "{file_path}", line {lineno}, in logging_reduce_rows'
)
self.assertTrue(
any(expected_file_line in line for line in window),
msg=f"missing file info for {symbol}",
)
self.assertTrue(
any(snippet in line for line in window),
msg=f"missing source line for {symbol}",
)

for sym, line_no, snippet in [
("u0", line_for_block, "m_block = hl.register_block_size(m)"),
("u1", line_for_inner, "for inner in hl.tile(outer.begin, outer.end):"),
("u2", line_for_inner, "for inner in hl.tile(outer.begin, outer.end):"),
("u3", line_for_inner, "for inner in hl.tile(outer.begin, outer.end):"),
("u4", line_for_zero, "zero_idx = inner.begin - inner.begin"),
]:
assert_symbol(sym, line_no, snippet)


if __name__ == "__main__":
unittest.main()
Loading