diff --git a/test/test_testing.py b/test/test_testing.py index b87345186cb3..9285166cb15e 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -3,9 +3,9 @@ import math from torch.testing._internal.common_utils import \ - (TestCase, run_tests, make_tensor) + (TestCase, make_tensor, run_tests, slowTest) from torch.testing._internal.common_device_type import \ - (instantiate_device_type_tests, onlyOnCPUAndCUDA, dtypes) + (instantiate_device_type_tests, onlyCUDA, onlyOnCPUAndCUDA, dtypes) # For testing TestCase methods and torch.testing functions class TestTesting(TestCase): @@ -438,6 +438,54 @@ def test_assert_messages(self, device): self.assertEqual("no_user_msg", self._get_assert_msg(msg=None, debug_msg="no_user_msg")) self.assertEqual("debug_msg\nuser_msg", self._get_assert_msg(msg="user_msg", debug_msg="debug_msg")) + @onlyCUDA + @slowTest + def test_cuda_assert_should_stop_test_suite(self, device): + # This test is slow because it spawn another process to run another test suite. + import subprocess + import sys + + problematic_test_script = """\ +#!/usr/bin/env python + +import torch + +from torch.testing._internal.common_utils import (TestCase, run_tests) +from torch.testing._internal.common_device_type import instantiate_device_type_tests + +# This test is added to ensure that test suite terminates early when +# CUDA assert was thrown since all subsequent test will fail. +# See: https://github.com/pytorch/pytorch/issues/49019 +# This test file should be invoked from test_testing.py +class TestThatContainsCUDAAssertFailure(TestCase): + + def test_throw_unrecoverable_cuda_exception(self, device): + x = torch.rand(10, device=device) + # cause unrecoverable CUDA exception, recoverable on CPU + y = x[torch.tensor([25])].cpu() + + def test_trivial_passing_test_case_on_cpu_cuda(self, device): + x1 = torch.tensor([0., 1.], device=device) + x2 = torch.tensor([0., 1.], device='cpu') + self.assertEqual(x1, x2) + +instantiate_device_type_tests( + TestThatContainsCUDAAssertFailure, + globals(), + except_for=None +) + +if __name__ == '__main__': + run_tests() +""" + + # Test running of cuda assert test suite should early terminate. + p = subprocess.run([sys.executable, '-c', problematic_test_script], capture_output=True, timeout=120) + # should capture CUDA error + self.assertIn('CUDA error: device-side assert triggered', p.stderr.decode('ascii')) + # should run only 3 tests - 2 CPUs and 1 CUDA (remaining CUDA test should skip) + self.assertIn('Ran 3 tests', p.stderr.decode('ascii')) + instantiate_device_type_tests(TestTesting, globals()) if __name__ == '__main__': diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 36f02eff0c0f..73185116a4f5 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -187,6 +187,9 @@ def _construct_test_name(test_name, op, device_type, dtype): class DeviceTypeTestBase(TestCase): device_type: str = 'generic_device_type' + # Flag to disable test suite early due to unrecoverable error such as CUDA error. + _stop_test_suite = False + # Precision is a thread-local setting since it may be overridden per test _tls = threading.local() _tls.precision = TestCase._precision @@ -271,6 +274,11 @@ def instantiated_test(self, name=name, test=test_fn, dtype=dtype, op=op): self.precision = self._get_precision_override(test_fn, dtype) args = (arg for arg in (device_arg, dtype, op) if arg is not None) result = test_fn(self, *args) + except RuntimeError as rte: + if 'CUDA error: device-side assert triggered' in rte.__repr__(): + self._stop_test_suite = True + # raise the runtime error as is. + raise rte finally: self.precision = guard_precision @@ -313,6 +321,12 @@ def instantiated_test(self, name=name, test=test_fn, dtype=dtype, op=op): for dtype in dtypes: instantiate_test_helper(cls, name, test=test, dtype=dtype, op=None) + def run(self, result=None): + super().run(result=result) + # Early terminate test if _stop_test_suite is set. + if self._stop_test_suite: + result.stop() + class CPUTestBase(DeviceTypeTestBase): device_type = 'cpu'