Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions test/test_core_aten_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1696,7 +1696,6 @@ def test_aten_glu_1(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.glu, args, kwargs)

@unittest.skip
def test_aten_grid_sampler_2d_0(self):
args = (
torch.randn((1, 3, 2, 10)).to(torch.float32),
Expand Down Expand Up @@ -3080,7 +3079,6 @@ def test_aten_reciprocal_2(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.reciprocal, args, kwargs)

@unittest.skip
def test_aten_reflection_pad1d_0(self):
args = (
torch.randn((10, 10)).to(torch.float32),
Expand All @@ -3092,7 +3090,6 @@ def test_aten_reflection_pad1d_0(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.reflection_pad1d, args, kwargs)

@unittest.skip
def test_aten_reflection_pad1d_1(self):
args = (
torch.randint(0, 10, (10, 10)).to(torch.int32),
Expand Down Expand Up @@ -3130,10 +3127,9 @@ def test_aten_reflection_pad2d_1(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.reflection_pad2d, args, kwargs)

@unittest.skip
def test_aten_reflection_pad3d_0(self):
args = (
torch.randn((3, 3, 3, 3, 3, 3)).to(torch.float32),
torch.randn((3, 3, 3, 3, 3)).to(torch.float32),
[
1,
2,
Expand All @@ -3146,10 +3142,9 @@ def test_aten_reflection_pad3d_0(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.reflection_pad3d, args, kwargs)

@unittest.skip
def test_aten_reflection_pad3d_1(self):
args = (
torch.randn((3, 3, 3, 3, 3, 3)).to(torch.float16),
torch.randn((3, 3, 3, 3, 3)).to(torch.float16),
[
1,
2,
Expand All @@ -3162,10 +3157,9 @@ def test_aten_reflection_pad3d_1(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.reflection_pad3d, args, kwargs)

@unittest.skip
def test_aten_reflection_pad3d_2(self):
args = (
torch.randint(0, 10, (3, 3, 3, 3, 3, 3)).to(torch.int32),
torch.randint(0, 10, (3, 3, 3, 3, 3)).to(torch.int32),
[
1,
2,
Expand Down
5 changes: 5 additions & 0 deletions torch_xla/stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import torch_xla.experimental.quantized
import torch._dynamo as torchdynamo
from torch.utils import _pytree as pytree
from torch._decomp import get_decompositions

from typing import Tuple, Type, Callable

Expand Down Expand Up @@ -288,11 +289,15 @@ def _extract_input_args(exported_model, options):
return copy.deepcopy(args)


_extra_decompositions = get_decompositions([torch.ops.aten.grid_sampler_2d])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, what is this needed for?



def _exported_program_to_stablehlo_bundle(exported_model,
options) -> StableHLOModelBundle:
if options is None:
options = StableHLOExportOptions()
exported_model = exported_model.run_decompositions()
exported_model = exported_model.run_decompositions(_extra_decompositions)
input_args = _extract_input_args(exported_model, options)

device = xm.xla_device()
Expand Down