4 changes: 4 additions & 0 deletions helion/exc.py
@@ -382,6 +382,10 @@ class GraphModuleUnsupportedOps(BaseError):
message = "GraphModule contains unsupported operations: {0}. Only pure computation graphs are supported (no load_attr or call_module ops)."


class RefEagerModeCodePrintError(BaseError):
message = "No generated code to print out if ref eager mode is enabled."


class NoDeviceLoopsInKernel(BaseError):
message = (
"Kernel contains no device loops. Add an hl.tile(...) or hl.grid(...) loop "
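For context, a minimal sketch of how the new exception is used, assuming helion's error classes format their class-level message at raise time (suggested by the {0} placeholder in GraphModuleUnsupportedOps above, not shown in this diff). RefEagerModeCodePrintError takes no placeholders, so it is raised bare, exactly as the settings.py hunk below does:

from helion import exc

# Sketch (assumption): helion error classes carry a class-level `message`
# template; this one has no placeholders, so it is raised without arguments.
try:
    raise exc.RefEagerModeCodePrintError
except exc.RefEagerModeCodePrintError as err:
    # str(err) is expected to surface the message about ref eager mode
    # producing no generated code to print.
    print(err)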
10 changes: 10 additions & 0 deletions helion/runtime/settings.py
@@ -128,11 +128,14 @@ def __init__(self, **settings: object) -> None:
Args:
settings: Keyword arguments representing various settings.
"""

if defaults := getattr(_tls, "default_settings", None):
settings = {**defaults.to_dict(), **settings}

super().__init__(**settings) # pyright: ignore[reportArgumentType]

self._check_ref_eager_mode_before_print_output_code()

def to_dict(self) -> dict[str, object]:
"""
Convert the Settings object to a dictionary.
@@ -162,6 +165,13 @@ def check_autotuning_disabled(self) -> None:
if msg:
raise exc.AutotuningDisallowedInEnvironment(msg)

def _check_ref_eager_mode_before_print_output_code(self) -> None:
"""
Raise an error if ref eager mode is enabled together with print_output_code, since ref eager mode generates no output code to print.
"""
if self.ref_mode == RefMode.EAGER and self.print_output_code:
raise exc.RefEagerModeCodePrintError

@staticmethod
def default() -> Settings:
"""
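Because the new check runs inside Settings.__init__, the conflicting combination is rejected as soon as a settings object is built, before any kernel executes. A minimal sketch of that behavior (an illustration, not part of this diff; it assumes Settings can be constructed directly from helion.runtime.settings with just these two keyword arguments, and that RefMode is exported at the helion top level as in the test below):

import helion
from helion import exc
from helion.runtime.settings import Settings

# Conflicting combination: ref eager mode produces no generated code,
# so asking to print the output code is rejected at construction time.
try:
    Settings(ref_mode=helion.RefMode.EAGER, print_output_code=True)
except exc.RefEagerModeCodePrintError:
    print("rejected: nothing to print in ref eager mode")

# Non-conflicting combinations are unaffected.
Settings(ref_mode=helion.RefMode.EAGER, print_output_code=False)
Settings(ref_mode=helion.RefMode.OFF, print_output_code=True)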
76 changes: 76 additions & 0 deletions test/test_print_ref_eager_mode.py
@@ -0,0 +1,76 @@
from __future__ import annotations

import contextlib
import io
import unittest

import pytest
import torch

import helion
from helion import exc
from helion._testing import TestCase
import helion.language as hl


class TestPrintOutputCode(TestCase):
def test_ref_eager_mode_code_print_error(self):
"""Test that RefEagerModeCodePrintError is raised when using @helion.kernel with both settings"""

with pytest.raises(exc.RefEagerModeCodePrintError):

@helion.kernel(
use_default_config=True,
print_output_code=True,
ref_mode=helion.RefMode.EAGER,
)
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
x, y = torch.broadcast_tensors(x, y)
out = torch.empty(
x.shape,
dtype=torch.promote_types(x.dtype, y.dtype),
device=x.device,
)
for tile in hl.tile(out.size()):
out[tile] = x[tile] + y[tile]
return out

x = torch.randn([512, 512], device="cuda", dtype=torch.float16)
y = torch.randn([512, 512], device="cuda", dtype=torch.float16)
torch.testing.assert_close(add(x, y), torch.add(x, y))

def test_normal_mode_code_print(self):
"""Test that output code is in stderr when using @helion.kernel with normal mode"""

f = io.StringIO()
with contextlib.redirect_stderr(f):

@helion.kernel(
use_default_config=True,
print_output_code=True,
ref_mode=helion.RefMode.OFF,
)
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
x, y = torch.broadcast_tensors(x, y)
out = torch.empty(
x.shape,
dtype=torch.promote_types(x.dtype, y.dtype),
device=x.device,
)
for tile in hl.tile(out.size()):
out[tile] = x[tile] + y[tile]
return out

x = torch.randn([512, 512], device="cuda", dtype=torch.float16)
y = torch.randn([512, 512], device="cuda", dtype=torch.float16)
torch.testing.assert_close(add(x, y), torch.add(x, y))

self.assertNotEqual(
f.getvalue(),
"",
"Output code in stderr should not be empty at normal mode.",
)


if __name__ == "__main__":
unittest.main()