Skip to content

Commit

Permalink
early terminate when CUDA assert were thrown (#49527)
Browse files Browse the repository at this point in the history
Summary:
Fixes #49019

I marked the test_testing function as slow since it took ~1 minute to finish the subprocess test suite.

Pull Request resolved: #49527

Reviewed By: malfet

Differential Revision: D25623219

Pulled By: walterddr

fbshipit-source-id: 1b414623ecce14aace5e0996d5e4768a40e12e06
  • Loading branch information
Rong Rong (AI Infra) authored and facebook-github-bot committed Dec 22, 2020
1 parent 9b6fb85 commit be09160
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 2 deletions.
52 changes: 50 additions & 2 deletions test/test_testing.py
Expand Up @@ -3,9 +3,9 @@
import math

from torch.testing._internal.common_utils import \
(TestCase, run_tests, make_tensor)
(TestCase, make_tensor, run_tests, slowTest)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, onlyOnCPUAndCUDA, dtypes)
(instantiate_device_type_tests, onlyCUDA, onlyOnCPUAndCUDA, dtypes)

# For testing TestCase methods and torch.testing functions
class TestTesting(TestCase):
Expand Down Expand Up @@ -438,6 +438,54 @@ def test_assert_messages(self, device):
self.assertEqual("no_user_msg", self._get_assert_msg(msg=None, debug_msg="no_user_msg"))
self.assertEqual("debug_msg\nuser_msg", self._get_assert_msg(msg="user_msg", debug_msg="debug_msg"))

@onlyCUDA
@slowTest
def test_cuda_assert_should_stop_test_suite(self, device):
# This test is slow because it spawn another process to run another test suite.
import subprocess
import sys

problematic_test_script = """\
#!/usr/bin/env python
import torch
from torch.testing._internal.common_utils import (TestCase, run_tests)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
# This test is added to ensure that test suite terminates early when
# CUDA assert was thrown since all subsequent test will fail.
# See: https://github.com/pytorch/pytorch/issues/49019
# This test file should be invoked from test_testing.py
class TestThatContainsCUDAAssertFailure(TestCase):
def test_throw_unrecoverable_cuda_exception(self, device):
x = torch.rand(10, device=device)
# cause unrecoverable CUDA exception, recoverable on CPU
y = x[torch.tensor([25])].cpu()
def test_trivial_passing_test_case_on_cpu_cuda(self, device):
x1 = torch.tensor([0., 1.], device=device)
x2 = torch.tensor([0., 1.], device='cpu')
self.assertEqual(x1, x2)
instantiate_device_type_tests(
TestThatContainsCUDAAssertFailure,
globals(),
except_for=None
)
if __name__ == '__main__':
run_tests()
"""

# Test running of cuda assert test suite should early terminate.
p = subprocess.run([sys.executable, '-c', problematic_test_script], capture_output=True, timeout=120)
# should capture CUDA error
self.assertIn('CUDA error: device-side assert triggered', p.stderr.decode('ascii'))
# should run only 3 tests - 2 CPUs and 1 CUDA (remaining CUDA test should skip)
self.assertIn('Ran 3 tests', p.stderr.decode('ascii'))

instantiate_device_type_tests(TestTesting, globals())

if __name__ == '__main__':
Expand Down
14 changes: 14 additions & 0 deletions torch/testing/_internal/common_device_type.py
Expand Up @@ -187,6 +187,9 @@ def _construct_test_name(test_name, op, device_type, dtype):
class DeviceTypeTestBase(TestCase):
device_type: str = 'generic_device_type'

# Flag to disable test suite early due to unrecoverable error such as CUDA error.
_stop_test_suite = False

# Precision is a thread-local setting since it may be overridden per test
_tls = threading.local()
_tls.precision = TestCase._precision
Expand Down Expand Up @@ -271,6 +274,11 @@ def instantiated_test(self, name=name, test=test_fn, dtype=dtype, op=op):
self.precision = self._get_precision_override(test_fn, dtype)
args = (arg for arg in (device_arg, dtype, op) if arg is not None)
result = test_fn(self, *args)
except RuntimeError as rte:
if 'CUDA error: device-side assert triggered' in rte.__repr__():
self._stop_test_suite = True
# raise the runtime error as is.
raise rte
finally:
self.precision = guard_precision

Expand Down Expand Up @@ -313,6 +321,12 @@ def instantiated_test(self, name=name, test=test_fn, dtype=dtype, op=op):
for dtype in dtypes:
instantiate_test_helper(cls, name, test=test, dtype=dtype, op=None)

def run(self, result=None):
super().run(result=result)
# Early terminate test if _stop_test_suite is set.
if self._stop_test_suite:
result.stop()


class CPUTestBase(DeviceTypeTestBase):
device_type = 'cpu'
Expand Down

0 comments on commit be09160

Please sign in to comment.