From a8d4c29764af7f2effd2391b42af420784d9d46d Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 22 Nov 2021 11:45:11 +0100 Subject: [PATCH] remove custom code for model output comparison --- test/test_models.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index 150b813b0cb..ef0a5d0260f 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -133,14 +133,7 @@ def get_export_import_copy(m): with freeze_rng_state(): results_from_imported = m_import(*args) tol = 3e-4 - try: - torch.testing.assert_close(results, results_from_imported, atol=tol, rtol=tol) - except ValueError: - # custom check for the models that return named tuples: - # we compare field by field while ignoring None as assert_close can't handle None - for a, b in zip(results, results_from_imported): - if a is not None: - torch.testing.assert_close(a, b, atol=tol, rtol=tol) + torch.testing.assert_close(results, results_from_imported, atol=tol, rtol=tol) TEST_WITH_SLOW = os.getenv("PYTORCH_TEST_WITH_SLOW", "0") == "1" if not TEST_WITH_SLOW or skip: