diff --git a/helion/exc.py b/helion/exc.py
index 71bd0d266..7647dbfbf 100644
--- a/helion/exc.py
+++ b/helion/exc.py
@@ -382,6 +382,10 @@ class GraphModuleUnsupportedOps(BaseError):
     message = "GraphModule contains unsupported operations: {0}. Only pure computation graphs are supported (no load_attr or call_module ops)."
 
 
+class RefEagerModeCodePrintError(BaseError):
+    message = "No generated code to print when ref eager mode is enabled."
+
+
 class NoDeviceLoopsInKernel(BaseError):
     message = (
         "Kernel contains no device loops. Add an hl.tile(...) or hl.grid(...) loop "
diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py
index b8bc238ac..75a09c075 100644
--- a/helion/runtime/settings.py
+++ b/helion/runtime/settings.py
@@ -128,11 +128,14 @@ def __init__(self, **settings: object) -> None:
         Args:
             settings: Keyword arguments representing various settings.
         """
+
         if defaults := getattr(_tls, "default_settings", None):
            settings = {**defaults.to_dict(), **settings}
         super().__init__(**settings)  # pyright: ignore[reportArgumentType]
 
+        self._check_ref_eager_mode_before_print_output_code()
+
     def to_dict(self) -> dict[str, object]:
         """
         Convert the Settings object to a dictionary.
@@ -162,6 +165,13 @@ def check_autotuning_disabled(self) -> None:
         if msg:
             raise exc.AutotuningDisallowedInEnvironment(msg)
 
+    def _check_ref_eager_mode_before_print_output_code(self) -> None:
+        """
+        Raise an error if print_output_code is requested while ref eager mode is enabled, since no output code is generated in that mode.
+        """
+        if self.ref_mode == RefMode.EAGER and self.print_output_code:
+            raise exc.RefEagerModeCodePrintError
+
     @staticmethod
     def default() -> Settings:
         """
diff --git a/test/test_print_ref_eager_mode.py b/test/test_print_ref_eager_mode.py
new file mode 100644
index 000000000..11bc5b8b4
--- /dev/null
+++ b/test/test_print_ref_eager_mode.py
@@ -0,0 +1,76 @@
+from __future__ import annotations
+
+import contextlib
+import io
+import unittest
+
+import pytest
+import torch
+
+import helion
+from helion import exc
+from helion._testing import TestCase
+import helion.language as hl
+
+
+class TestPrintOutputCode(TestCase):
+    def test_ref_eager_mode_code_print_error(self):
+        """Test that RefEagerModeCodePrintError is raised when @helion.kernel enables both print_output_code and ref eager mode."""
+
+        with pytest.raises(exc.RefEagerModeCodePrintError):
+
+            @helion.kernel(
+                use_default_config=True,
+                print_output_code=True,
+                ref_mode=helion.RefMode.EAGER,
+            )
+            def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+                x, y = torch.broadcast_tensors(x, y)
+                out = torch.empty(
+                    x.shape,
+                    dtype=torch.promote_types(x.dtype, y.dtype),
+                    device=x.device,
+                )
+                for tile in hl.tile(out.size()):
+                    out[tile] = x[tile] + y[tile]
+                return out
+
+            x = torch.randn([512, 512], device="cuda", dtype=torch.float16)
+            y = torch.randn([512, 512], device="cuda", dtype=torch.float16)
+            torch.testing.assert_close(add(x, y), torch.add(x, y))
+
+    def test_normal_mode_code_print(self):
+        """Test that output code is printed to stderr when using @helion.kernel in normal mode."""
+
+        f = io.StringIO()
+        with contextlib.redirect_stderr(f):
+
+            @helion.kernel(
+                use_default_config=True,
+                print_output_code=True,
+                ref_mode=helion.RefMode.OFF,
+            )
+            def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+                x, y = torch.broadcast_tensors(x, y)
+                out = torch.empty(
+                    x.shape,
+                    dtype=torch.promote_types(x.dtype, y.dtype),
+                    device=x.device,
+                )
+                for tile in hl.tile(out.size()):
+                    out[tile] = x[tile] + y[tile]
+                return out
+
+            x = torch.randn([512, 512], device="cuda", dtype=torch.float16)
+            y = torch.randn([512, 512], device="cuda", dtype=torch.float16)
+            torch.testing.assert_close(add(x, y), torch.add(x, y))
+
+        self.assertNotEqual(
+            f.getvalue(),
+            "",
+            "Output code in stderr should not be empty in normal mode.",
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
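
A minimal usage sketch of the new guard, outside the decorator path exercised by the tests: the check runs in Settings.__init__, so it fires as soon as a Settings object is constructed with both flags, before any kernel executes. This assumes Settings can be imported from helion.runtime.settings (the module patched above) and accepts these two keyword settings directly; treat it as an illustration rather than part of the patch.

import helion
from helion import exc
from helion.runtime.settings import Settings

try:
    # Both flags together should trip _check_ref_eager_mode_before_print_output_code().
    Settings(print_output_code=True, ref_mode=helion.RefMode.EAGER)
except exc.RefEagerModeCodePrintError as e:
    print(e)  # the RefEagerModeCodePrintError message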