diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 51e196fcf..f4be7527d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -151,6 +151,7 @@ jobs: - name: Run Tests run: | + set -o pipefail source .venv/bin/activate # Conditionally enable ref-eager and golden-accept/dtype-assert test modes if [[ "${{ matrix.dtype-asserts }}" == "true" ]]; then export HELION_DEBUG_DTYPE_ASSERTS=1; fi diff --git a/test/test_breakpoint.py b/test/test_breakpoint.py index ebd956676..f754f4bad 100644 --- a/test/test_breakpoint.py +++ b/test/test_breakpoint.py @@ -3,12 +3,15 @@ import builtins from contextlib import contextmanager import os +import subprocess import sys +import textwrap from typing import TYPE_CHECKING import unittest from unittest import mock import torch +from torch._environment import is_fbcode import helion from helion import exc @@ -64,6 +67,42 @@ def kernel(x: torch.Tensor) -> torch.Tensor: return kernel + def _run_breakpoint_in_subprocess( + self, + *, + test_name: str, + runner_method: str, + triton_interpret: int, + helion_interpret: int, + ) -> None: + """Run a breakpoint test in a subprocess to isolate interpreter state.""" + script = textwrap.dedent( + f""" + from test.test_breakpoint import TestBreakpoint + + case = TestBreakpoint({test_name!r}) + case.setUp() + try: + getattr(case, {runner_method!r})(triton_interpret={triton_interpret}, helion_interpret={helion_interpret}) + finally: + case.tearDown() + """ + ) + + env = os.environ.copy() + result = subprocess.run( + [sys.executable, "-c", script], + env=env, + capture_output=True, + ) + if result.returncode != 0: + raise AssertionError( + f"{test_name} subprocess failed", + result.returncode, + result.stdout.decode(), + result.stderr.decode(), + ) + def _run_device_breakpoint_test( self, triton_interpret: int, helion_interpret: int ) -> None: @@ -90,14 +129,32 @@ def _run_device_breakpoint_test( out = bound(x) torch.testing.assert_close(out, x) + 
@unittest.skipIf(is_fbcode(), "subprocess test doesn't work in internal CI")
def test_device_breakpoint_no_interpret(self) -> None:
    """Run ``_run_device_breakpoint_test`` in a subprocess with both interpret flags set to 0."""
    interpret_flags = {"triton_interpret": 0, "helion_interpret": 0}
    self._run_breakpoint_in_subprocess(
        runner_method="_run_device_breakpoint_test",
        test_name=self._testMethodName,
        **interpret_flags,
    )


@unittest.skipIf(is_fbcode(), "subprocess test doesn't work in internal CI")
def test_device_breakpoint_triton_interpret(self) -> None:
    """Run ``_run_device_breakpoint_test`` in a subprocess with ``triton_interpret=1``."""
    interpret_flags = {"triton_interpret": 1, "helion_interpret": 0}
    self._run_breakpoint_in_subprocess(
        runner_method="_run_device_breakpoint_test",
        test_name=self._testMethodName,
        **interpret_flags,
    )


@unittest.skipIf(is_fbcode(), "subprocess test doesn't work in internal CI")
def test_device_breakpoint_helion_interpret(self) -> None:
    """Run ``_run_device_breakpoint_test`` in a subprocess with ``helion_interpret=1``."""
    interpret_flags = {"triton_interpret": 0, "helion_interpret": 1}
    self._run_breakpoint_in_subprocess(
        runner_method="_run_device_breakpoint_test",
        test_name=self._testMethodName,
        **interpret_flags,
    )


@unittest.skipIf(is_fbcode(), "subprocess test doesn't work in internal CI")
def test_host_breakpoint_no_interpret(self) -> None:
    """Run ``_run_host_breakpoint_test`` in a subprocess with both interpret flags set to 0."""
    interpret_flags = {"triton_interpret": 0, "helion_interpret": 0}
    self._run_breakpoint_in_subprocess(
        runner_method="_run_host_breakpoint_test",
        test_name=self._testMethodName,
        **interpret_flags,
    )


@unittest.skipIf(is_fbcode(), "subprocess test doesn't work in internal CI")
def test_host_breakpoint_triton_interpret(self) -> None:
    """Run ``_run_host_breakpoint_test`` in a subprocess with ``triton_interpret=1``."""
    interpret_flags = {"triton_interpret": 1, "helion_interpret": 0}
    self._run_breakpoint_in_subprocess(
        runner_method="_run_host_breakpoint_test",
        test_name=self._testMethodName,
        **interpret_flags,
    )


@unittest.skipIf(is_fbcode(), "subprocess test doesn't work in internal CI")
def test_host_breakpoint_helion_interpret(self) -> None:
    """Run ``_run_host_breakpoint_test`` in a subprocess with ``helion_interpret=1``."""
    interpret_flags = {"triton_interpret": 0, "helion_interpret": 1}
    self._run_breakpoint_in_subprocess(
        runner_method="_run_host_breakpoint_test",
        test_name=self._testMethodName,
        **interpret_flags,
    )