From dbcbad334c834f9b5eee2eee40782260e548665b Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Mon, 29 Jun 2020 16:09:52 +0200 Subject: [PATCH] Fix torchhub due to numerical changes in torch.sum (#2361) --- test/test_hub.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/test/test_hub.py b/test/test_hub.py index 4ae9e51021b..29ae90014d1 100644 --- a/test/test_hub.py +++ b/test/test_hub.py @@ -13,7 +13,7 @@ def sum_of_model_parameters(model): return s -SUM_OF_PRETRAINED_RESNET18_PARAMS = -12703.99609375 +SUM_OF_PRETRAINED_RESNET18_PARAMS = -12703.9931640625 @unittest.skipIf('torchvision' in sys.modules, @@ -31,8 +31,9 @@ def test_load_from_github(self): 'resnet18', pretrained=True, progress=False) - self.assertEqual(sum_of_model_parameters(hub_model).item(), - SUM_OF_PRETRAINED_RESNET18_PARAMS) + self.assertAlmostEqual(sum_of_model_parameters(hub_model).item(), + SUM_OF_PRETRAINED_RESNET18_PARAMS, + places=2) def test_set_dir(self): temp_dir = tempfile.gettempdir() @@ -42,8 +43,9 @@ def test_set_dir(self): 'resnet18', pretrained=True, progress=False) - self.assertEqual(sum_of_model_parameters(hub_model).item(), - SUM_OF_PRETRAINED_RESNET18_PARAMS) + self.assertAlmostEqual(sum_of_model_parameters(hub_model).item(), + SUM_OF_PRETRAINED_RESNET18_PARAMS, + places=2) self.assertTrue(os.path.exists(temp_dir + '/pytorch_vision_master')) shutil.rmtree(temp_dir + '/pytorch_vision_master')