Skip to content

Commit

Permalink
Reset grad state across unittests (#126345)
Browse files Browse the repository at this point in the history
Pull Request resolved: #126345
Approved by: https://github.com/ezyang
  • Loading branch information
williamwen42 authored and pytorchmergebot committed May 23, 2024
1 parent a31a60d commit d11e44c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
8 changes: 1 addition & 7 deletions test/functorch/test_aotdispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,7 @@


class AOTTestCase(TestCase):
def setUp(self):
self.prev_grad_state = torch.is_grad_enabled()
super().setUp()

def tearDown(self):
torch.set_grad_enabled(self.prev_grad_state)
super().tearDown()
pass


class TestPythonKey(AOTTestCase):
Expand Down
7 changes: 7 additions & 0 deletions torch/testing/_internal/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2903,6 +2903,9 @@ def setUp(self):
if self._default_dtype_check_enabled:
assert torch.get_default_dtype() == torch.float

# attempt to reset some global state at the end of the test
self._prev_grad_state = torch.is_grad_enabled()

def tearDown(self):
# There exists test cases that override TestCase.setUp
# definition, so we cannot assume that _check_invariants
Expand All @@ -2917,6 +2920,10 @@ def tearDown(self):
if self._default_dtype_check_enabled:
assert torch.get_default_dtype() == torch.float

# attribute may not be defined, per above
if hasattr(self, '_prev_grad_state'):
torch.set_grad_enabled(self._prev_grad_state)

@staticmethod
def _make_crow_indices(n_rows, n_cols, nnz,
*, device, dtype, random=True):
Expand Down

0 comments on commit d11e44c

Please sign in to comment.