diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py
index ee719f9bc579..936e5ccd601d 100644
--- a/test/dynamo/test_dynamo.py
+++ b/test/dynamo/test_dynamo.py
@@ -231,21 +231,27 @@ def test_simple_model(self):
     res_xla_dynamo = self.fn_simple_dynamo(xla_input)
     self.assertIn('xla::nll_loss_backward', met.counter_names())
     self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))
-    self.assertTrue(torch.allclose(input.grad, xla_input.grad.cpu()))
+    self.assertTrue(
+        torch.allclose(
+            input.grad, xla_input.grad.cpu(), rtol=1e-05, atol=1e-04))
     # verifiy that tracing is skipped in following runs
     xla_input.grad = None
     met.clear_counters()
     res_xla_dynamo_2 = self.fn_simple_dynamo(xla_input)
     self.assertNotIn('xla::nll_loss_backward', met.counter_names())
     self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo_2.cpu()))
-    self.assertTrue(torch.allclose(input.grad, xla_input.grad.cpu()))
+    self.assertTrue(
+        torch.allclose(
+            input.grad, xla_input.grad.cpu(), rtol=1e-05, atol=1e-04))
     # verify that dynamo can handle different inputs
     input.grad = None
     xla_input.grad = None
     res_xla_dynamo_3 = self.fn_simple_dynamo(xla_input * 2)
     res_cpu_3 = self.fn_simple(input * 2)
     self.assertTrue(torch.allclose(res_cpu_3, res_xla_dynamo_3.cpu()))
-    self.assertTrue(torch.allclose(input.grad, xla_input.grad.cpu()))
+    self.assertTrue(
+        torch.allclose(
+            input.grad, xla_input.grad.cpu(), rtol=1e-05, atol=1e-04))
 
   def test_resnet18(self):
     torch._dynamo.reset()
@@ -343,7 +349,8 @@ def test_simple_model(self):
     res_cpu = self.fn_simple(input, optimizer)
     res_xla_dynamo = self.fn_simple_dynamo(xla_input, xla_optimizer)
     assert torch.allclose(res_cpu, res_xla_dynamo.cpu())
-    assert torch.allclose(input.grad, xla_input.grad.cpu())
+    assert torch.allclose(
+        input.grad, xla_input.grad.cpu(), rtol=1e-04, atol=1e-04)
     assert torch.allclose(input, xla_input.cpu())
 
   def test_resnet18(self):
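
For context (not part of the patch), a minimal sketch of how `torch.allclose` applies the relaxed tolerances used in these assertions: elementwise it checks `|a - b| <= atol + rtol * |b|`, so raising `atol` to 1e-04 tolerates the small numeric drift between CPU and XLA gradients. The tensor values below are made up purely for illustration.

```python
import torch

# Hypothetical gradients with a small CPU-vs-XLA numeric drift of ~4e-5.
cpu_grad = torch.tensor([1.00000, 2.00000, 3.00000])
xla_grad = torch.tensor([1.00004, 2.00004, 3.00004])

# Default tolerances (rtol=1e-05, atol=1e-08) reject a ~4e-5 difference.
print(torch.allclose(cpu_grad, xla_grad))                          # False
# Loosening atol to 1e-04, as the test changes do, accepts it.
print(torch.allclose(cpu_grad, xla_grad, rtol=1e-05, atol=1e-04))  # True
```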