diff --git a/test/test_jit_autocast.py b/test/test_jit_autocast.py index d7ad7e5eeeec0..301e85f4a23e6 100644 --- a/test/test_jit_autocast.py +++ b/test/test_jit_autocast.py @@ -36,19 +36,19 @@ def tearDown(self): @unittest.skipIf(not TEST_CUDA, "No cuda") def test_generic_jit_autocast(self): @torch.jit.script - def fn_cuda(a, b): + def fn_cuda_autocast(a, b): with autocast(): x = torch.mm(a, b) y = torch.sum(x) return x, y @torch.jit.script - def fn_generic(a, b): + def fn_generic_autocast(a, b): with torch.amp.autocast(device_type='cpu'): x = torch.mm(a, b) y = torch.sum(x) return x, y - self.assertEqual(fn_cuda(self.a_fp32, self.b_fp32), fn_generic(self.a_fp32, self.b_fp32)) + self.assertEqual(fn_cuda_autocast(self.a_fp32, self.b_fp32), fn_generic_autocast(self.a_fp32, self.b_fp32)) @unittest.skipIf(not TEST_CUDA, "No cuda") def test_minimal(self):