From 7ca0b9a5883cdb2975ad1b8e8ab1f1103eb2c079 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Fri, 14 Nov 2025 16:27:29 -0800 Subject: [PATCH 1/4] Revert "[CI] Fix fbcode test_breakpoint error (#1132)" This reverts commit 6a98dc9087f0da6a6d408f39fc8be8eb170da46c. --- test/test_breakpoint.py | 80 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 74 insertions(+), 6 deletions(-) diff --git a/test/test_breakpoint.py b/test/test_breakpoint.py index ebd956676..c0e4e5a65 100644 --- a/test/test_breakpoint.py +++ b/test/test_breakpoint.py @@ -3,7 +3,9 @@ import builtins from contextlib import contextmanager import os +import subprocess import sys +import textwrap from typing import TYPE_CHECKING import unittest from unittest import mock @@ -64,6 +66,42 @@ def kernel(x: torch.Tensor) -> torch.Tensor: return kernel + def _run_breakpoint_in_subprocess( + self, + *, + test_name: str, + runner_method: str, + triton_interpret: int, + helion_interpret: int, + ) -> None: + """Run a breakpoint test in a subprocess to isolate interpreter state.""" + script = textwrap.dedent( + f""" + from test.test_breakpoint import TestBreakpoint + + case = TestBreakpoint({test_name!r}) + case.setUp() + try: + getattr(case, {runner_method!r})(triton_interpret={triton_interpret}, helion_interpret={helion_interpret}) + finally: + case.tearDown() + """ + ) + + env = os.environ.copy() + result = subprocess.run( + [sys.executable, "-c", script], + env=env, + capture_output=True, + ) + if result.returncode != 0: + raise AssertionError( + f"{test_name} subprocess failed", + result.returncode, + result.stdout.decode(), + result.stderr.decode(), + ) + def _run_device_breakpoint_test( self, triton_interpret: int, helion_interpret: int ) -> None: @@ -91,13 +129,28 @@ def _run_device_breakpoint_test( torch.testing.assert_close(out, x) def test_device_breakpoint_no_interpret(self) -> None: - self._run_device_breakpoint_test(triton_interpret=0, helion_interpret=0) + self._run_breakpoint_in_subprocess( + test_name=self._testMethodName, + runner_method="_run_device_breakpoint_test", + triton_interpret=0, + helion_interpret=0, + ) def test_device_breakpoint_triton_interpret(self) -> None: - self._run_device_breakpoint_test(triton_interpret=1, helion_interpret=0) + self._run_breakpoint_in_subprocess( + test_name=self._testMethodName, + runner_method="_run_device_breakpoint_test", + triton_interpret=1, + helion_interpret=0, + ) def test_device_breakpoint_helion_interpret(self) -> None: - self._run_device_breakpoint_test(triton_interpret=0, helion_interpret=1) + self._run_breakpoint_in_subprocess( + test_name=self._testMethodName, + runner_method="_run_device_breakpoint_test", + triton_interpret=0, + helion_interpret=1, + ) def _run_host_breakpoint_test( self, triton_interpret: int, helion_interpret: int @@ -117,13 +170,28 @@ def _run_host_breakpoint_test( torch.testing.assert_close(out, x) def test_host_breakpoint_no_interpret(self) -> None: - self._run_host_breakpoint_test(triton_interpret=0, helion_interpret=0) + self._run_breakpoint_in_subprocess( + test_name=self._testMethodName, + runner_method="_run_host_breakpoint_test", + triton_interpret=0, + helion_interpret=0, + ) def test_host_breakpoint_triton_interpret(self) -> None: - self._run_host_breakpoint_test(triton_interpret=1, helion_interpret=0) + self._run_breakpoint_in_subprocess( + test_name=self._testMethodName, + runner_method="_run_host_breakpoint_test", + triton_interpret=1, + helion_interpret=0, + ) def test_host_breakpoint_helion_interpret(self) -> None: - self._run_host_breakpoint_test(triton_interpret=0, helion_interpret=1) + self._run_breakpoint_in_subprocess( + test_name=self._testMethodName, + runner_method="_run_host_breakpoint_test", + triton_interpret=0, + helion_interpret=1, + ) if __name__ == "__main__": From c17becfe5251128ffdcc8100e541736d98e8afee Mon Sep 17 00:00:00 2001 From: Will Feng Date: Fri, 14 Nov 2025 16:47:36 -0800 Subject: [PATCH 2/4] skip breakpoint tests on fbcode --- test/test_breakpoint.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/test_breakpoint.py b/test/test_breakpoint.py index c0e4e5a65..f754f4bad 100644 --- a/test/test_breakpoint.py +++ b/test/test_breakpoint.py @@ -11,6 +11,7 @@ from unittest import mock import torch +from torch._environment import is_fbcode import helion from helion import exc @@ -128,6 +129,7 @@ def _run_device_breakpoint_test( out = bound(x) torch.testing.assert_close(out, x) + @unittest.skipIf(is_fbcode(), "subprocess test doesn't work in internal CI") def test_device_breakpoint_no_interpret(self) -> None: self._run_breakpoint_in_subprocess( test_name=self._testMethodName, @@ -136,6 +138,7 @@ def test_device_breakpoint_no_interpret(self) -> None: helion_interpret=0, ) + @unittest.skipIf(is_fbcode(), "subprocess test doesn't work in internal CI") def test_device_breakpoint_triton_interpret(self) -> None: self._run_breakpoint_in_subprocess( test_name=self._testMethodName, @@ -144,6 +147,7 @@ def test_device_breakpoint_triton_interpret(self) -> None: helion_interpret=0, ) + @unittest.skipIf(is_fbcode(), "subprocess test doesn't work in internal CI") def test_device_breakpoint_helion_interpret(self) -> None: self._run_breakpoint_in_subprocess( test_name=self._testMethodName, @@ -169,6 +173,7 @@ def _run_host_breakpoint_test( out = bound(x) torch.testing.assert_close(out, x) + @unittest.skipIf(is_fbcode(), "subprocess test doesn't work in internal CI") def test_host_breakpoint_no_interpret(self) -> None: self._run_breakpoint_in_subprocess( test_name=self._testMethodName, @@ -177,6 +182,7 @@ def test_host_breakpoint_no_interpret(self) -> None: helion_interpret=0, ) + @unittest.skipIf(is_fbcode(), "subprocess test doesn't work in internal CI") def test_host_breakpoint_triton_interpret(self) -> None: self._run_breakpoint_in_subprocess( test_name=self._testMethodName, @@ -185,6 +191,7 @@ def test_host_breakpoint_triton_interpret(self) -> None: helion_interpret=0, ) + @unittest.skipIf(is_fbcode(), "subprocess test doesn't work in internal CI") def test_host_breakpoint_helion_interpret(self) -> None: self._run_breakpoint_in_subprocess( test_name=self._testMethodName, From 4a724d6d5fff3a6170d8e222f46f064cfa30ef07 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Fri, 14 Nov 2025 16:48:07 -0800 Subject: [PATCH 3/4] Fail CI job on any test error --- .github/workflows/test.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 51e196fcf..f4be7527d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -151,6 +151,7 @@ jobs: - name: Run Tests run: | + set -o pipefail source .venv/bin/activate # Conditionally enable ref-eager and golden-accept/dtype-assert test modes if [[ "${{ matrix.dtype-asserts }}" == "true" ]]; then export HELION_DEBUG_DTYPE_ASSERTS=1; fi From f49571a23d5f7ac5299934684a0e14cc0b532b89 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Fri, 14 Nov 2025 17:23:46 -0800 Subject: [PATCH 4/4] skip specific unit tests in ref eager mode --- test/test_indexing.py | 1 + test/test_unroll_tuples.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/test/test_indexing.py b/test/test_indexing.py index f06dac6f6..245b5c3ef 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -505,6 +505,7 @@ def run_case( expect_error=None, ) + @skipIfRefEager("specialization_key is not used in ref eager mode") def test_dynamic_shape_specialization_key_tracks_large_tensors(self) -> None: @helion.kernel(static_shapes=False) def passthrough(x: torch.Tensor) -> torch.Tensor: diff --git a/test/test_unroll_tuples.py b/test/test_unroll_tuples.py index c100917b6..9aef11fbf 100644 --- a/test/test_unroll_tuples.py +++ b/test/test_unroll_tuples.py @@ -10,6 +10,7 @@ from helion._testing import RefEagerTestBase from helion._testing import TestCase from helion._testing import code_and_output +from helion._testing import skipIfRefEager import helion.language as hl @@ -520,6 +521,7 @@ def kernel_static_range_tuple_indexing( expected = sum(tensors) torch.testing.assert_close(result, expected) + @skipIfRefEager("Type inference errors are not raised in ref eager mode") def test_static_range_tuple_indexing_requires_uniform_types(self): @helion.kernel(autotune_effort="none") def kernel_static_range_tuple_mismatch(x: torch.Tensor) -> torch.Tensor: