diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 42d44dfdbd9..cd083076820 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -24,7 +24,7 @@ class Tester(TransformsTester): def setUp(self): self.device = "cpu" - def _test_fn_on_batch(self, batch_tensors, fn, **fn_kwargs): + def _test_fn_on_batch(self, batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs): transformed_batch = fn(batch_tensors, **fn_kwargs) for i in range(len(batch_tensors)): img_tensor = batch_tensors[i, ...] @@ -34,7 +34,7 @@ def _test_fn_on_batch(self, batch_tensors, fn, **fn_kwargs): scripted_fn = torch.jit.script(fn) # scriptable function test s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs) - self.assertTrue(transformed_batch.allclose(s_transformed_batch)) + self.assertTrue(transformed_batch.allclose(s_transformed_batch, atol=scripted_fn_atol)) def test_assert_image_tensor(self): shape = (100,) @@ -348,7 +348,7 @@ def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method atol = 1.0 self.assertTrue(adjusted_tensor.allclose(scripted_result, atol=atol), msg=msg) - self._test_fn_on_batch(batch_tensors, fn, **config) + self._test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=atol, **config) def test_adjust_brightness(self): self._test_adjust_fn(