[ONNX] Support aten::scatter_reduce
ghstack-source-id: 80d9fd756f7554b27d76b1ad1f6cede7b218b92f
Pull Request resolved: #102048
titaiwangms committed May 25, 2023
1 parent f3e42f1 commit c2996d4
Showing 7 changed files with 218 additions and 6 deletions.
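For context, `aten::scatter_reduce` scatters `src` into `self` along `dim`, combining colliding writes with the requested reduction; this commit lowers it to ONNX `ScatterElements` (opset 16+). A minimal sketch of the eager semantics, reusing the values from the new `test_scatter_reduce` test below:

import torch

src = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
index = torch.tensor([0, 1, 0, 1, 2, 1])
inp = torch.tensor([1.0, 2.0, 3.0, 8.0])

# Colliding writes to the same index are combined by the reduction;
# include_self=True (the default) also folds in the original inp values.
print(inp.scatter_reduce(0, index, src, reduce="sum"))   # -> [5., 14., 8., 8.]
print(inp.scatter_reduce(0, index, src, reduce="amax"))  # -> [3., 6., 5., 8.]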
2 changes: 1 addition & 1 deletion .ci/docker/common/install_onnx.sh
@@ -24,7 +24,7 @@ pip_install \
transformers==4.25.1

# TODO: change this when onnxscript is on TestPyPI
pip_install "onnxscript@git+https://github.com/microsoft/onnxscript@bf502680231e4b134a71f74e812c84ddd7efffbe"
pip_install "onnxscript@git+https://github.com/microsoft/onnxscript@68adea42fb9b7353148e7ab289b76f9b89890e1c"

# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
# package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
47 changes: 47 additions & 0 deletions test/onnx/test_fx_op_consistency.py
@@ -126,6 +126,8 @@
"nn.functional.embedding",
"nn.functional.nll_loss",
# "nn.functional.scaled_dot_product_attention" non-deterministic
"scatter_add",
"scatter_reduce",
"unflatten",
]
)
@@ -451,6 +453,40 @@
"nn.functional.embedding",
reason=onnx_test_common.reason_onnx_script_does_not_support("aten.embedding_renorm.default"),
),
xfail(
"scatter_add",
dtypes=(torch.float16,),
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"),
),
xfail(
"scatter_reduce",
variant_name="sum",
dtypes=(torch.float16,),
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"),
),
xfail(
"scatter_reduce",
variant_name="prod",
dtypes=(torch.float16,),
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=prod", "float16"),
),
xfail(
"scatter_reduce",
variant_name="amin",
dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,),
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amin", "float16"),
),
xfail(
"scatter_reduce",
variant_name="amax",
dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,),
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amax", "float16"),
),
xfail(
"scatter_reduce",
variant_name="mean",
reason="ONNX doesn't support reduce='mean' option",
),
xfail(
"unflatten", dtypes=onnx_test_common.BOOL_TYPES,
reason=onnx_test_common.reason_onnx_does_not_support("Unflatten")
@@ -522,6 +558,17 @@
"string in reduction kwarg: https://github.com/microsoft/onnxscript/issues/726"
),
),
xfail(
"scatter_add",
matcher=lambda sample: len(sample.input.shape) == 0,
reason="fixme: Rank(0) input will lead ORT failed due to different rank(result) in if-else branch",
),
skip(
"scatter_reduce",
# ONNX has no include_self parameter; it always behaves as include_self=True
matcher=lambda sample: sample.kwargs.get("include_self") is False,
reason="ONNX doesn't support the include_self=False option",
),
xfail(
"unflatten",
reason="Logic not implemented for size 0 inputs in op.Reshape",
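The `include_self` skips above exist because ONNX `ScatterElements` has no such attribute and always folds the destination's own values into the reduction, so it can only express PyTorch's default `include_self=True`. A small illustrative sketch of the gap:

import torch

inp = torch.tensor([10.0])
index = torch.tensor([0, 0])
src = torch.tensor([2.0, 3.0])

# include_self=True: the reduction also sees inp[0] -> 10 + 2 + 3
print(inp.scatter_reduce(0, index, src, reduce="sum", include_self=True))   # -> [15.]
# include_self=False: only src participates -> 2 + 3; not expressible in ONNX
print(inp.scatter_reduce(0, index, src, reduce="sum", include_self=False))  # -> [5.]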
47 changes: 47 additions & 0 deletions test/onnx/test_op_consistency.py
@@ -63,6 +63,8 @@
"logical_not",
"nn.functional.scaled_dot_product_attention",
"repeat",
# "scatter_add", # TODO: enable after fixing https://github.com/pytorch/pytorch/issues/102211
# "scatter_reduce", # TODO: enable after fixing https://github.com/pytorch/pytorch/issues/102211
"sqrt",
"stft",
"t",
@@ -98,6 +100,45 @@
),
skip("nn.functional.scaled_dot_product_attention", opsets=[onnx_test_common.opsets_before(14)], reason="Need Trilu."),
skip("nn.functional.scaled_dot_product_attention", reason="fixme: ORT crashes on Windows, segfaults randomly on Linux"),
skip("scatter_reduce", variant_name="amin", opsets=[onnx_test_common.opsets_before(16)],
reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
skip("scatter_reduce", variant_name="amax", opsets=[onnx_test_common.opsets_before(16)],
reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
skip("scatter_reduce", variant_name="prod", opsets=[onnx_test_common.opsets_before(16)],
reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
xfail("scatter_reduce", variant_name="mean",
reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction=mean")),
skip("scatter_reduce", variant_name="sum", opsets=[onnx_test_common.opsets_before(16)],
reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
xfail(
"scatter_reduce",
variant_name="sum",
dtypes=(torch.float16,),
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"),
),
xfail(
"scatter_reduce",
variant_name="prod",
dtypes=(torch.float16,),
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=prod", "float16"),
),
xfail(
"scatter_reduce",
variant_name="amin",
dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,),
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amin", "float16"),
),
xfail(
"scatter_reduce",
variant_name="amax",
dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,),
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amax", "float16"),
),
xfail(
"scatter_reduce",
variant_name="mean",
reason="ONNX doesn't support reduce='mean' option",
),
skip("sqrt", dtypes=onnx_test_common.BOOL_TYPES, reason=onnx_test_common.reason_onnx_does_not_support("Sqrt")),
skip("stft", opsets=[onnx_test_common.opsets_before(17)], reason=onnx_test_common.reason_onnx_does_not_support("STFT")),
skip("tile", opsets=[onnx_test_common.opsets_before(13)], reason=onnx_test_common.reason_onnx_does_not_support("Tile")),
@@ -116,6 +157,12 @@
reason="Empty repeats value leads to an invalid graph",
matcher=lambda sample: not sample.args[0],
),
skip(
"scatter_reduce",
# ONNX has no include_self parameter; it always behaves as include_self=True
matcher=lambda sample: sample.kwargs.get("include_self") is False,
reason="ONNX doesn't support the include_self=False option",
),
skip(
"stft",
reason="ONNX STFT does not support complex results",
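Several of the workarounds above and below revolve around rank-0 inputs: eager PyTorch accepts 0-d `self`/`index`/`src` for scatter ops (treating dim 0 of a scalar as size 1), which is what the `len(sample.input.shape) == 0` matcher detects and what the If/Else reshape in `symbolic_opset16.py` below handles. A sketch of that eager behavior, assuming 0-d operands:

import torch

self_scalar = torch.tensor(1.0)   # rank 0: len(self_scalar.shape) == 0
index_scalar = torch.tensor(0)
src_scalar = torch.tensor(2.0)

# sum with the default include_self=True: 1.0 (self) + 2.0 (src)
print(self_scalar.scatter_reduce(0, index_scalar, src_scalar, reduce="sum"))  # -> tensor(3.)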
43 changes: 43 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -4005,6 +4005,49 @@ def forward(self, input, indices, src):
dynamic_axes={"indices": {0: "a", 1: "b"}, "src": {0: "c", 1: "d"}},
)

@skipIfUnsupportedMinOpsetVersion(16)
def test_scatter_reduce(self):
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, index, input):
y_max = input.scatter_reduce(0, index, x, reduce="amax")
y_sum = input.scatter_reduce(0, index, x, reduce="sum")
y_min = input.scatter_reduce(0, index, x, reduce="amin")
y_mul = input.scatter_reduce(0, index, x, reduce="prod")
return y_max, y_sum, y_min, y_mul

model = Model()
model.eval()

src = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
index = torch.tensor([0, 1, 0, 1, 2, 1])
input = torch.tensor([1.0, 2.0, 3.0, 8.0])

self.run_test(model, (src, index, input))

@skipIfUnsupportedMinOpsetVersion(16)
def test_scatter_reduce_self_rank_zero(self):
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, index, input):
y_max = input.scatter_reduce(0, index, x, reduce="amax")
y_sum = input.scatter_reduce(0, index, x, reduce="sum")
y_min = input.scatter_reduce(0, index, x, reduce="amin")
y_mul = input.scatter_reduce(0, index, x, reduce="prod")
return y_max, y_sum, y_min, y_mul

model = Model()
model.eval()

empty_tensor = torch.tensor([])
empty_idx = torch.tensor([], dtype=torch.int64)

self.run_test(model, (empty_tensor, empty_idx, empty_tensor))

@skipIfUnsupportedMinOpsetVersion(9)
def test_bucketize(self):
class BucketModel(torch.nn.Module):
3 changes: 2 additions & 1 deletion torch/csrc/jit/passes/onnx/shape_type_inference.cpp
@@ -488,7 +488,8 @@ c10::optional<::c10::SymbolicShape> ComputeShapeFromReshape(
uint64_t shape_ratio = 1;
std::unordered_map<int64_t, int64_t> sym_map;
for (const c10::ShapeSymbol& input_shape : input_shape_vector) {
if (input_shape.is_static()) {
// input_shape.static_size() could be zero when torch.tensor([]) is used.
if (input_shape.is_static() && input_shape.static_size() != 0) {
if (shape_ratio >=
std::numeric_limits<uint64_t>::max() / input_shape.static_size()) {
TORCH_WARN(
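The guard added above matters because the overflow check divides by `static_size()`, and an empty tensor contributes a static dimension of size zero; the new `test_scatter_reduce_self_rank_zero` pushes exactly such shapes through `Reshape`. A quick illustration:

import torch

empty = torch.tensor([])
print(empty.shape)               # torch.Size([0]) -- static, but size 0
print(empty.reshape(-1).shape)   # torch.Size([0]); shape inference must not divide by this 0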
10 changes: 7 additions & 3 deletions torch/onnx/_internal/io_adapter.py
@@ -60,7 +60,7 @@ def append_step(self, step: InputAdaptStep) -> None:
@_beartype.beartype
def apply(
self, *model_args, **model_kwargs
) -> Sequence[Union[int, float, bool, "torch.Tensor", None]]:
) -> Sequence[Union[int, float, bool, str, "torch.Tensor", None]]:
"""Converts the PyTorch model inputs to exported ONNX model inputs format.
Args:
@@ -113,7 +113,7 @@ def append_step(self, step: OutputAdaptStep) -> None:
@_beartype.beartype
def apply(
self, model_outputs: Any
) -> Sequence[Union["torch.Tensor", int, float, bool]]:
) -> Sequence[Union["torch.Tensor", int, float, bool, str]]:
"""Converts the PyTorch model outputs to exported ONNX model outputs format.
Args:
@@ -325,7 +325,11 @@ def apply(
"""
assert not model_kwargs
return (
tuple(arg for arg in model_args if not isinstance(arg, (int, float, bool))),
tuple(
arg
for arg in model_args
if not isinstance(arg, (int, float, bool, str))
),
{},
)

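The widened type unions and the filter above reflect that Python primitives (now including `str`, e.g. a string argument such as a reduce mode) are baked into the exported graph as constants rather than passed as runtime inputs. A minimal sketch of that filtering step; the helper name here is hypothetical:

from typing import Any, Tuple

import torch

def drop_constant_primitives(*model_args: Any) -> Tuple[Any, ...]:
    # int/float/bool/str become graph constants during export, so they are
    # removed from the runtime inputs the exported ONNX model expects.
    return tuple(
        arg for arg in model_args if not isinstance(arg, (int, float, bool, str))
    )

print(drop_constant_primitives(torch.ones(2), 3, "mean", True))  # -> (tensor([1., 1.]),)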
72 changes: 71 additions & 1 deletion torch/onnx/symbolic_opset16.py
@@ -32,7 +32,7 @@
GRID_SAMPLE_INTERPOLATION_MODES,
GRID_SAMPLE_PADDING_MODES,
)
from torch.onnx import _type_utils, symbolic_helper
from torch.onnx import _type_utils, errors, symbolic_helper, utils
from torch.onnx._internal import _beartype, jit_utils, registration

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16)
@@ -115,3 +115,73 @@ def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
axis_i=dim,
reduction_s="add",
)


@_onnx_symbolic("aten::scatter_reduce")
@symbolic_helper.parse_args("v", "i", "v", "v", "s", "b")
@_beartype.beartype
def scatter_reduce(
g: jit_utils.GraphContext,
self: torch._C.Value,
dim: int,
index: torch._C.Value,
src: torch._C.Value,
reduce: str,
include_self: bool,
):
if reduce == "mean":
raise errors.OnnxExporterError(
"ONNX does not support mean reduction for scatter_reduce"
)
if not include_self:
raise errors.OnnxExporterError(
"ONNX does not support include_self=False for scatter_reduce"
)

reduce_mode = {  # map the PyTorch reduce name to the ONNX ScatterElements reduction name
"mean": "none",  # 'mean' is not supported in the ONNX 1.14 definition
"sum": "add",
"prod": "mul",
"amin": "min",
"amax": "max",
}
onnx_reduce = reduce_mode[reduce]

self_rank = g.op("Size", g.op("Shape", self))

# If self is rank 0, index and src are rank 0 as well; reshape all three to rank 1.
self_rank_is_zero = g.op(
"Equal", self_rank, g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
)
if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks(
g, "If", self_rank_is_zero, n_blocks=2, outputs=3
)
neg_1 = if_context.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))

self_reshape = if_context.op("Reshape", self, neg_1)
utils._add_output_to_block(if_context.block, self_reshape)
index_reshape = if_context.op("Reshape", index, neg_1)
utils._add_output_to_block(if_context.block, index_reshape)
src_reshape = if_context.op("Reshape", src, neg_1)
utils._add_output_to_block(if_context.block, src_reshape)

self_identity = else_context.op("Identity", self)
utils._add_output_to_block(else_context.block, self_identity)
index_identity = else_context.op("Identity", index)
utils._add_output_to_block(else_context.block, index_identity)
src_identity = else_context.op("Identity", src)
utils._add_output_to_block(else_context.block, src_identity)

result = g.op("ScatterElements", *if_op, axis_i=dim, reduction_s=onnx_reduce)

# If self was rank 0, squeeze the rank-1 result back to a scalar.
if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks(
g, "If", self_rank_is_zero, n_blocks=2, outputs=1
)
result_squeezed = if_context.op("Squeeze", result)
utils._add_output_to_block(if_context.block, result_squeezed)
result_identity = else_context.op("Identity", result)
utils._add_output_to_block(else_context.block, result_identity)
result_final = if_op.node().output()

return result_final

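A hedged end-to-end sketch of the new symbolic in use: export at opset 16 and compare against eager PyTorch with onnxruntime (file name illustrative, standard ORT `InferenceSession` API assumed):

import numpy as np
import onnxruntime
import torch

class ScatterReduceModel(torch.nn.Module):
    def forward(self, src, index, inp):
        return inp.scatter_reduce(0, index, src, reduce="sum")

src = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
index = torch.tensor([0, 1, 0, 1, 2, 1])
inp = torch.tensor([1.0, 2.0, 3.0, 8.0])

# scatter_reduce needs opset >= 16, where ScatterElements supports the
# add/mul/min/max reduction attribute values this symbolic emits.
torch.onnx.export(ScatterReduceModel(), (src, index, inp), "scatter_reduce.onnx",
                  opset_version=16, input_names=["src", "index", "inp"])

sess = onnxruntime.InferenceSession("scatter_reduce.onnx",
                                    providers=["CPUExecutionProvider"])
(out,) = sess.run(None, {"src": src.numpy(), "index": index.numpy(),
                         "inp": inp.numpy()})
np.testing.assert_allclose(out, ScatterReduceModel()(src, index, inp).numpy())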