From 4ae22d6554742d96d9dc633cb36064c9ad39e411 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Mon, 22 Jan 2024 18:27:18 +0000 Subject: [PATCH 1/2] Export no-longer accepts callable; wrap in nn.Module. --- test/test_core_aten_ops.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py index 7b7c7e1eae5..5e32cb5c881 100644 --- a/test/test_core_aten_ops.py +++ b/test/test_core_aten_ops.py @@ -42,6 +42,16 @@ def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True): testcase.assertEqual(output1, output2) +class NNModWrapper(torch.nn.Module): + + def __init__(self, op): + super().__init__() + self._op = op + + def forward(self, *args, **kwargs): + return self._op(*args, **kwargs) + + def run_export_and_compare(testcase, func, args, @@ -62,7 +72,7 @@ def run_export_and_compare(testcase, diff_output( testcase, res, res_xla, atol=atol, rtol=rtol, equal_nan=equal_nan) with testcase.subTest('can_export'): - exported = torch.export.export(func, args, kwargs) + exported = torch.export.export(NNModWrapper(func), args, kwargs) with testcase.subTest('can_convert_to_stablehlo'): shlo = exported_program_to_stablehlo(exported) with testcase.subTest('stablehlo_can_run'): @@ -4596,9 +4606,6 @@ def test_aten_upsample_nearest2d_0(self): run_export_and_compare(self, torch.ops.aten.upsample_nearest2d, args, kwargs) - def correction_wrapper(self, input, correction): - return torch.ops.aten.var.correction(input, correction=correction) - def test_aten_var_correction_0(self): args = (torch.randn((10, 10)).to(torch.float32),) kwargs = dict() @@ -4612,12 +4619,12 @@ def test_aten_var_correction_1(self): def test_aten_var_correction_2(self): args = (torch.randn((10, 10)).to(torch.float32), 0) kwargs = dict() - run_export_and_compare(self, self.correction_wrapper, args, kwargs) + run_export_and_compare(self, torch.ops.aten.var.correction, args, kwargs) def test_aten_var_correction_3(self): args = (torch.randn((10, 10)).to(torch.float16), 0) kwargs = dict() - run_export_and_compare(self, self.correction_wrapper, args, kwargs) + run_export_and_compare(self, torch.ops.aten.var.correction, args, kwargs) def test_aten_view_0(self): args = ( From 4691231fb0340a719715bb9cecae412aa8ef0f96 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Mon, 22 Jan 2024 23:02:59 +0000 Subject: [PATCH 2/2] Enable test --- test/run_tests.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/run_tests.sh b/test/run_tests.sh index 3c5fbb1809b..1553d53e409 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -214,8 +214,7 @@ function run_xla_op_tests3 { run_torchrun "$CDIR/pjrt/test_torchrun.py" run_test "$CDIR/test_persistent_cache.py" # NOTE: this line below is testing export and don't care about GPU - # TODO(qihqi): Enable after fix - # PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$CDIR/test_core_aten_ops.py" + PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$CDIR/test_core_aten_ops.py" } #######################################################################################