diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py index 4f2fa3d20eeb..409eb612dc88 100644 --- a/test/test_core_aten_ops.py +++ b/test/test_core_aten_ops.py @@ -1696,7 +1696,6 @@ def test_aten_glu_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.glu, args, kwargs) - @unittest.skip def test_aten_grid_sampler_2d_0(self): args = ( torch.randn((1, 3, 2, 10)).to(torch.float32), @@ -3080,7 +3079,6 @@ def test_aten_reciprocal_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.reciprocal, args, kwargs) - @unittest.skip def test_aten_reflection_pad1d_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -3092,7 +3090,6 @@ def test_aten_reflection_pad1d_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.reflection_pad1d, args, kwargs) - @unittest.skip def test_aten_reflection_pad1d_1(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -3130,10 +3127,9 @@ def test_aten_reflection_pad2d_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.reflection_pad2d, args, kwargs) - @unittest.skip def test_aten_reflection_pad3d_0(self): args = ( - torch.randn((3, 3, 3, 3, 3, 3)).to(torch.float32), + torch.randn((3, 3, 3, 3, 3)).to(torch.float32), [ 1, 2, @@ -3146,10 +3142,9 @@ def test_aten_reflection_pad3d_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.reflection_pad3d, args, kwargs) - @unittest.skip def test_aten_reflection_pad3d_1(self): args = ( - torch.randn((3, 3, 3, 3, 3, 3)).to(torch.float16), + torch.randn((3, 3, 3, 3, 3)).to(torch.float16), [ 1, 2, @@ -3162,10 +3157,9 @@ def test_aten_reflection_pad3d_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.reflection_pad3d, args, kwargs) - @unittest.skip def test_aten_reflection_pad3d_2(self): args = ( - torch.randint(0, 10, (3, 3, 3, 3, 3, 3)).to(torch.int32), + torch.randint(0, 10, (3, 3, 3, 3, 3)).to(torch.int32), [ 1, 2, diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py index ba37dad6af3b..7227c9b6a586 100644 --- a/torch_xla/stablehlo.py +++ b/torch_xla/stablehlo.py @@ -20,6 +20,7 @@ import torch_xla.experimental.quantized import torch._dynamo as torchdynamo from torch.utils import _pytree as pytree +from torch._decomp import get_decompositions from typing import Tuple, Type, Callable @@ -288,11 +289,15 @@ def _extract_input_args(exported_model, options): return copy.deepcopy(args) +_extra_decompositions = get_decompositions([torch.ops.aten.grid_sampler_2d]) + + def _exported_program_to_stablehlo_bundle(exported_model, options) -> StableHLOModelBundle: if options is None: options = StableHLOExportOptions() exported_model = exported_model.run_decompositions() + exported_model = exported_model.run_decompositions(_extra_decompositions) input_args = _extract_input_args(exported_model, options) device = xm.xla_device()