From 6444d623eb7ea440c44f0419cc59f18e0281bb82 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Thu, 27 Jul 2023 11:25:52 -0700 Subject: [PATCH 1/3] tweak `atol` and `rtol` for `DynamoTrainingBasicTest.test_simple_model` --- test/dynamo/test_dynamo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py index ee719f9bc579..1a113d235c58 100644 --- a/test/dynamo/test_dynamo.py +++ b/test/dynamo/test_dynamo.py @@ -343,7 +343,7 @@ def test_simple_model(self): res_cpu = self.fn_simple(input, optimizer) res_xla_dynamo = self.fn_simple_dynamo(xla_input, xla_optimizer) assert torch.allclose(res_cpu, res_xla_dynamo.cpu()) - assert torch.allclose(input.grad, xla_input.grad.cpu()) + assert torch.allclose(input.grad, xla_input.grad.cpu(), rtol=1e-04, atol=1e-04) assert torch.allclose(input, xla_input.cpu()) def test_resnet18(self): From 2077794b8ab12514914174f63528fa4a3f02c936 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Thu, 27 Jul 2023 11:27:39 -0700 Subject: [PATCH 2/3] tweak the atol and rtol for `DynamoTrainingOptimizerTest.test_simple_model` --- test/dynamo/test_dynamo.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py index 1a113d235c58..e025299801d9 100644 --- a/test/dynamo/test_dynamo.py +++ b/test/dynamo/test_dynamo.py @@ -231,21 +231,21 @@ def test_simple_model(self): res_xla_dynamo = self.fn_simple_dynamo(xla_input) self.assertIn('xla::nll_loss_backward', met.counter_names()) self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu())) - self.assertTrue(torch.allclose(input.grad, xla_input.grad.cpu())) + self.assertTrue(torch.allclose(input.grad, xla_input.grad.cpu(), rtol=1e-05, atol=1e-04)) # verifiy that tracing is skipped in following runs xla_input.grad = None met.clear_counters() res_xla_dynamo_2 = self.fn_simple_dynamo(xla_input) self.assertNotIn('xla::nll_loss_backward', met.counter_names()) self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo_2.cpu())) - self.assertTrue(torch.allclose(input.grad, xla_input.grad.cpu())) + self.assertTrue(torch.allclose(input.grad, xla_input.grad.cpu(), rtol=1e-05, atol=1e-04)) # verify that dynamo can handle different inputs input.grad = None xla_input.grad = None res_xla_dynamo_3 = self.fn_simple_dynamo(xla_input * 2) res_cpu_3 = self.fn_simple(input * 2) self.assertTrue(torch.allclose(res_cpu_3, res_xla_dynamo_3.cpu())) - self.assertTrue(torch.allclose(input.grad, xla_input.grad.cpu())) + self.assertTrue(torch.allclose(input.grad, xla_input.grad.cpu(), rtol=1e-05, atol=1e-04)) def test_resnet18(self): torch._dynamo.reset() From 72fa607d310fce67545089118eeddb80e537994a Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Thu, 27 Jul 2023 11:55:27 -0700 Subject: [PATCH 3/3] format --- test/dynamo/test_dynamo.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py index e025299801d9..936e5ccd601d 100644 --- a/test/dynamo/test_dynamo.py +++ b/test/dynamo/test_dynamo.py @@ -231,21 +231,27 @@ def test_simple_model(self): res_xla_dynamo = self.fn_simple_dynamo(xla_input) self.assertIn('xla::nll_loss_backward', met.counter_names()) self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu())) - self.assertTrue(torch.allclose(input.grad, xla_input.grad.cpu(), rtol=1e-05, atol=1e-04)) + self.assertTrue( + torch.allclose( + input.grad, xla_input.grad.cpu(), rtol=1e-05, atol=1e-04)) # verifiy that tracing is skipped in following runs xla_input.grad = None met.clear_counters() res_xla_dynamo_2 = self.fn_simple_dynamo(xla_input) self.assertNotIn('xla::nll_loss_backward', met.counter_names()) self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo_2.cpu())) - self.assertTrue(torch.allclose(input.grad, xla_input.grad.cpu(), rtol=1e-05, atol=1e-04)) + self.assertTrue( + torch.allclose( + input.grad, xla_input.grad.cpu(), rtol=1e-05, atol=1e-04)) # verify that dynamo can handle different inputs input.grad = None xla_input.grad = None res_xla_dynamo_3 = self.fn_simple_dynamo(xla_input * 2) res_cpu_3 = self.fn_simple(input * 2) self.assertTrue(torch.allclose(res_cpu_3, res_xla_dynamo_3.cpu())) - self.assertTrue(torch.allclose(input.grad, xla_input.grad.cpu(), rtol=1e-05, atol=1e-04)) + self.assertTrue( + torch.allclose( + input.grad, xla_input.grad.cpu(), rtol=1e-05, atol=1e-04)) def test_resnet18(self): torch._dynamo.reset() @@ -343,7 +349,8 @@ def test_simple_model(self): res_cpu = self.fn_simple(input, optimizer) res_xla_dynamo = self.fn_simple_dynamo(xla_input, xla_optimizer) assert torch.allclose(res_cpu, res_xla_dynamo.cpu()) - assert torch.allclose(input.grad, xla_input.grad.cpu(), rtol=1e-04, atol=1e-04) + assert torch.allclose( + input.grad, xla_input.grad.cpu(), rtol=1e-04, atol=1e-04) assert torch.allclose(input, xla_input.cpu()) def test_resnet18(self):