diff --git a/test/test_breakpoint.py b/test/test_breakpoint.py index f754f4bad..0c2fe7325 100644 --- a/test/test_breakpoint.py +++ b/test/test_breakpoint.py @@ -3,15 +3,13 @@ import builtins from contextlib import contextmanager import os -import subprocess import sys -import textwrap from typing import TYPE_CHECKING import unittest from unittest import mock import torch -from torch._environment import is_fbcode +import triton.runtime.interpreter as triton_interpreter import helion from helion import exc @@ -67,42 +65,6 @@ def kernel(x: torch.Tensor) -> torch.Tensor: return kernel - def _run_breakpoint_in_subprocess( - self, - *, - test_name: str, - runner_method: str, - triton_interpret: int, - helion_interpret: int, - ) -> None: - """Run a breakpoint test in a subprocess to isolate interpreter state.""" - script = textwrap.dedent( - f""" - from test.test_breakpoint import TestBreakpoint - - case = TestBreakpoint({test_name!r}) - case.setUp() - try: - getattr(case, {runner_method!r})(triton_interpret={triton_interpret}, helion_interpret={helion_interpret}) - finally: - case.tearDown() - """ - ) - - env = os.environ.copy() - result = subprocess.run( - [sys.executable, "-c", script], - env=env, - capture_output=True, - ) - if result.returncode != 0: - raise AssertionError( - f"{test_name} subprocess failed", - result.returncode, - result.stdout.decode(), - result.stderr.decode(), - ) - def _run_device_breakpoint_test( self, triton_interpret: int, helion_interpret: int ) -> None: @@ -129,32 +91,18 @@ def _run_device_breakpoint_test( out = bound(x) torch.testing.assert_close(out, x) - @unittest.skipIf(is_fbcode(), "subprocess test doesn't work in internal CI") def test_device_breakpoint_no_interpret(self) -> None: - self._run_breakpoint_in_subprocess( - test_name=self._testMethodName, - runner_method="_run_device_breakpoint_test", - triton_interpret=0, - helion_interpret=0, - ) - - @unittest.skipIf(is_fbcode(), "subprocess test doesn't work in internal CI") + self._run_device_breakpoint_test(triton_interpret=0, helion_interpret=0) + + @unittest.skipUnless( + hasattr(triton_interpreter, "_MISSING"), + "https://github.com/triton-lang/triton/pull/8735", + ) def test_device_breakpoint_triton_interpret(self) -> None: - self._run_breakpoint_in_subprocess( - test_name=self._testMethodName, - runner_method="_run_device_breakpoint_test", - triton_interpret=1, - helion_interpret=0, - ) - - @unittest.skipIf(is_fbcode(), "subprocess test doesn't work in internal CI") + self._run_device_breakpoint_test(triton_interpret=1, helion_interpret=0) + def test_device_breakpoint_helion_interpret(self) -> None: - self._run_breakpoint_in_subprocess( - test_name=self._testMethodName, - runner_method="_run_device_breakpoint_test", - triton_interpret=0, - helion_interpret=1, - ) + self._run_device_breakpoint_test(triton_interpret=0, helion_interpret=1) def _run_host_breakpoint_test( self, triton_interpret: int, helion_interpret: int @@ -173,32 +121,18 @@ def _run_host_breakpoint_test( out = bound(x) torch.testing.assert_close(out, x) - @unittest.skipIf(is_fbcode(), "subprocess test doesn't work in internal CI") def test_host_breakpoint_no_interpret(self) -> None: - self._run_breakpoint_in_subprocess( - test_name=self._testMethodName, - runner_method="_run_host_breakpoint_test", - triton_interpret=0, - helion_interpret=0, - ) - - @unittest.skipIf(is_fbcode(), "subprocess test doesn't work in internal CI") + self._run_host_breakpoint_test(triton_interpret=0, helion_interpret=0) + + @unittest.skipUnless( + hasattr(triton_interpreter, "_MISSING"), + "https://github.com/triton-lang/triton/pull/8735", + ) def test_host_breakpoint_triton_interpret(self) -> None: - self._run_breakpoint_in_subprocess( - test_name=self._testMethodName, - runner_method="_run_host_breakpoint_test", - triton_interpret=1, - helion_interpret=0, - ) - - @unittest.skipIf(is_fbcode(), "subprocess test doesn't work in internal CI") + self._run_host_breakpoint_test(triton_interpret=1, helion_interpret=0) + def test_host_breakpoint_helion_interpret(self) -> None: - self._run_breakpoint_in_subprocess( - test_name=self._testMethodName, - runner_method="_run_host_breakpoint_test", - triton_interpret=0, - helion_interpret=1, - ) + self._run_host_breakpoint_test(triton_interpret=0, helion_interpret=1) if __name__ == "__main__": diff --git a/test/test_print.py b/test/test_print.py index e3bb21fa7..52a36cdea 100644 --- a/test/test_print.py +++ b/test/test_print.py @@ -7,6 +7,7 @@ import pytest import torch +import triton.runtime.interpreter as triton_interpreter import helion from helion._testing import DEVICE @@ -99,7 +100,8 @@ def run_test_with_and_without_triton_interpret_envvar(self, test_func): os.environ.pop("TRITON_INTERPRET", None) test_func(interpret_mode=False) - # Then run with TRITON_INTERPRET=1 + if not hasattr(triton_interpreter, "_MISSING"): + return # see https://github.com/triton-lang/triton/pull/8735 os.environ["TRITON_INTERPRET"] = "1" test_func(interpret_mode=True) finally: