diff --git a/test/test_models.py b/test/test_models.py index 150b813b0cb..ef0a5d0260f 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -133,14 +133,7 @@ def get_export_import_copy(m): with freeze_rng_state(): results_from_imported = m_import(*args) tol = 3e-4 - try: - torch.testing.assert_close(results, results_from_imported, atol=tol, rtol=tol) - except ValueError: - # custom check for the models that return named tuples: - # we compare field by field while ignoring None as assert_close can't handle None - for a, b in zip(results, results_from_imported): - if a is not None: - torch.testing.assert_close(a, b, atol=tol, rtol=tol) + torch.testing.assert_close(results, results_from_imported, atol=tol, rtol=tol) TEST_WITH_SLOW = os.getenv("PYTORCH_TEST_WITH_SLOW", "0") == "1" if not TEST_WITH_SLOW or skip: