From 2d12aa545d1153165b2722c526b6c783a0ff1baf Mon Sep 17 00:00:00 2001 From: Will Feng Date: Thu, 13 Nov 2025 20:13:10 -0800 Subject: [PATCH] [fbcode] Fix test_breakpoint error --- test/test_breakpoint.py | 80 ++++------------------------------------- 1 file changed, 6 insertions(+), 74 deletions(-) diff --git a/test/test_breakpoint.py b/test/test_breakpoint.py index c0e4e5a65..ebd956676 100644 --- a/test/test_breakpoint.py +++ b/test/test_breakpoint.py @@ -3,9 +3,7 @@ import builtins from contextlib import contextmanager import os -import subprocess import sys -import textwrap from typing import TYPE_CHECKING import unittest from unittest import mock @@ -66,42 +64,6 @@ def kernel(x: torch.Tensor) -> torch.Tensor: return kernel - def _run_breakpoint_in_subprocess( - self, - *, - test_name: str, - runner_method: str, - triton_interpret: int, - helion_interpret: int, - ) -> None: - """Run a breakpoint test in a subprocess to isolate interpreter state.""" - script = textwrap.dedent( - f""" - from test.test_breakpoint import TestBreakpoint - - case = TestBreakpoint({test_name!r}) - case.setUp() - try: - getattr(case, {runner_method!r})(triton_interpret={triton_interpret}, helion_interpret={helion_interpret}) - finally: - case.tearDown() - """ - ) - - env = os.environ.copy() - result = subprocess.run( - [sys.executable, "-c", script], - env=env, - capture_output=True, - ) - if result.returncode != 0: - raise AssertionError( - f"{test_name} subprocess failed", - result.returncode, - result.stdout.decode(), - result.stderr.decode(), - ) - def _run_device_breakpoint_test( self, triton_interpret: int, helion_interpret: int ) -> None: @@ -129,28 +91,13 @@ def _run_device_breakpoint_test( torch.testing.assert_close(out, x) def test_device_breakpoint_no_interpret(self) -> None: - self._run_breakpoint_in_subprocess( - test_name=self._testMethodName, - runner_method="_run_device_breakpoint_test", - triton_interpret=0, - helion_interpret=0, - ) + self._run_device_breakpoint_test(triton_interpret=0, helion_interpret=0) def test_device_breakpoint_triton_interpret(self) -> None: - self._run_breakpoint_in_subprocess( - test_name=self._testMethodName, - runner_method="_run_device_breakpoint_test", - triton_interpret=1, - helion_interpret=0, - ) + self._run_device_breakpoint_test(triton_interpret=1, helion_interpret=0) def test_device_breakpoint_helion_interpret(self) -> None: - self._run_breakpoint_in_subprocess( - test_name=self._testMethodName, - runner_method="_run_device_breakpoint_test", - triton_interpret=0, - helion_interpret=1, - ) + self._run_device_breakpoint_test(triton_interpret=0, helion_interpret=1) def _run_host_breakpoint_test( self, triton_interpret: int, helion_interpret: int @@ -170,28 +117,13 @@ def _run_host_breakpoint_test( torch.testing.assert_close(out, x) def test_host_breakpoint_no_interpret(self) -> None: - self._run_breakpoint_in_subprocess( - test_name=self._testMethodName, - runner_method="_run_host_breakpoint_test", - triton_interpret=0, - helion_interpret=0, - ) + self._run_host_breakpoint_test(triton_interpret=0, helion_interpret=0) def test_host_breakpoint_triton_interpret(self) -> None: - self._run_breakpoint_in_subprocess( - test_name=self._testMethodName, - runner_method="_run_host_breakpoint_test", - triton_interpret=1, - helion_interpret=0, - ) + self._run_host_breakpoint_test(triton_interpret=1, helion_interpret=0) def test_host_breakpoint_helion_interpret(self) -> None: - self._run_breakpoint_in_subprocess( - test_name=self._testMethodName, - runner_method="_run_host_breakpoint_test", - triton_interpret=0, - helion_interpret=1, - ) + self._run_host_breakpoint_test(triton_interpret=0, helion_interpret=1) if __name__ == "__main__":