Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Update Reducesum operator for opset 13 #50532

Merged
merged 20 commits into from
Jan 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 0 additions & 4 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -2754,7 +2754,6 @@ def forward(self, input):
x = torch.randn(4, 5, dtype=torch.float)
self.run_test(ReducedOpModule(), x)

@skipIfUnsupportedOpsetVersion([13])
def test_reduced_sum(self):
return self._test_reduced_ops(op=torch.sum)

Expand Down Expand Up @@ -4314,7 +4313,6 @@ def forward(self, input):

@disableScriptTest() # error in propagate as assign input shape
@skipIfUnsupportedMinOpsetVersion(10)
@skipIfUnsupportedOpsetVersion([12, 13]) # Due to ONNX Loop shape inference issue
def test_embedding_bag(self):
model = torch.nn.EmbeddingBag(10, 5, mode='sum', scale_grad_by_freq=True)
input = torch.randint(10, (7,))
Expand All @@ -4331,7 +4329,6 @@ def test_embedding_bag(self):
self.run_test(model, (input))

@skipIfUnsupportedMinOpsetVersion(11)
@skipIfUnsupportedOpsetVersion([12, 13]) # Due to ONNX Loop shape inference issue
def test_embedding_bag_1d_per_sample_weights(self):
class EmbeddingModel(torch.nn.Module):
def forward(self, embedding_matrix, input, offset, weights):
Expand All @@ -4346,7 +4343,6 @@ def forward(self, embedding_matrix, input, offset, weights):
self.run_test(model, (embedding_matrix, x, offset, w))

@skipIfUnsupportedMinOpsetVersion(11)
@skipIfUnsupportedOpsetVersion([12, 13]) # Due to ONNX Loop shape inference issue
def test_embedding_bag_2d_per_sample_weights(self):
class EmbeddingModel(torch.nn.Module):
def forward(self, embedding_matrix, input, weights):
Expand Down
11 changes: 11 additions & 0 deletions torch/onnx/symbolic_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,17 @@ def _squeeze_helper(g, input, axes_i):
else:
return g.op("Squeeze", input, axes_i=axes_i)

def _reducesum_helper(g, input, axes_i=None, keepdims_i=1, noop_with_empty_axes_i=0):
keepdims_i = _maybe_get_const(keepdims_i, 'i')
if _export_onnx_opset_version >= 13:
if axes_i:
if not _is_value(axes_i):
axes_i = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long))
return g.op("ReduceSum", input, axes_i, keepdims_i=keepdims_i, noop_with_empty_axes_i=noop_with_empty_axes_i)
return g.op("ReduceSum", input, keepdims_i=keepdims_i, noop_with_empty_axes_i=noop_with_empty_axes_i)
else:
return g.op("ReduceSum", input, axes_i=axes_i, keepdims_i=keepdims_i)

def _interpolate_size_to_scales(g, input, output_size, dim):
output_size = _maybe_get_const(output_size, 'is')
if _is_value(output_size):
Expand Down
2 changes: 1 addition & 1 deletion torch/onnx/symbolic_opset10.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def embedding_bag(g,
per_sample_weights_row = sym_help._unsqueeze_helper(g, per_sample_weights_row, [1])
embeddings = g.op("Mul", embeddings, per_sample_weights_row)
if mode == 0:
embeddings = g.op("ReduceSum", embeddings, axes_i=[0], keepdims_i=0)
embeddings = sym_help._reducesum_helper(g, embeddings, axes_i=[0], keepdims_i=0)
elif mode == 1:
embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
else:
Expand Down
2 changes: 1 addition & 1 deletion torch/onnx/symbolic_opset11.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,7 @@ def embedding_bag(g,
per_sample_weights_row = sym_help._unsqueeze_helper(loop_block, per_sample_weights_row, [1])
embeddings = loop_block.op("Mul", embeddings, per_sample_weights_row)
if mode == 0:
embeddings = loop_block.op("ReduceSum", embeddings, axes_i=[0], keepdims_i=0)
embeddings = sym_help._reducesum_helper(loop_block, embeddings, axes_i=[0], keepdims_i=0)
elif mode == 1:
embeddings = loop_block.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
else:
Expand Down
46 changes: 40 additions & 6 deletions torch/onnx/symbolic_opset13.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
# see Note [Edit Symbolic Files] in symbolic_helper.py

# This file exports ONNX ops for opset 13
from torch.onnx.symbolic_helper import _block_list_in_opset
import torch
import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_helper import parse_args
from torch.onnx.symbolic_helper import parse_args, _unimplemented
from torch.onnx.symbolic_opset9 import overload_by_arg_count, _maybe_cast_reduce_op_input

block_listed_operators = ['embedding_bag']

for block_listed_op in block_listed_operators:
vars()[block_listed_op] = _block_list_in_opset(block_listed_op)
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py

# This file exports ONNX ops for opset 13


@parse_args('v', 'i', 'none')
Expand Down Expand Up @@ -38,7 +39,7 @@ def frobenius_norm(g, self, dim=None, keepdim=False):
if not sym_help._is_value(dim_val) and len(dim_val) == 0:
return g.op("ReduceL2", self, keepdims_i=0)
sqr = g.op('Mul', self, self)
sumsqr = g.op('ReduceSum', sqr, dim, keepdims_i=keepdim)
sumsqr = sym_help._reducesum_helper(g, sqr, dim, keepdims_i=keepdim)
return g.op('Sqrt', sumsqr)


Expand Down Expand Up @@ -108,3 +109,36 @@ def unbind(g, self, dim=0, _outputs=None):
def glu(g, input, dim):
first, second = g.op('Split', input, dim, outputs=2)
return g.op('Mul', first, g.op('Sigmoid', second))


def _reduce_op_symbolic(onnx_op_name):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changes for sum and embedding_bag look good. Thanks.

def symbolic(g, self, dim=None, keepdim=None):
self = _maybe_cast_reduce_op_input(g, self)
if dim is None:
# all-reduce path
return g.op(onnx_op_name, self, keepdims_i=0)
else:
keepdim = sym_help._get_const(keepdim, 'i', 'keepdim')
return g.op(onnx_op_name, self, dim, keepdims_i=keepdim)
return symbolic

def _reduce_with_dtype(onnx_op, name):
symbolic = _reduce_op_symbolic(onnx_op)

@overload_by_arg_count
def reduce(g, *args, **kwargs):
@parse_args('v', 'none')
def reduce_nodim(g, self, dtype):
if dtype.node().kind() != 'prim::Constant':
return _unimplemented(name, "dtype")
return symbolic(g, self)

@parse_args('v', 'v', 'i', 'none')
def reduce_dim(g, self, dim, keepdim, dtype):
if dtype.node().kind() != 'prim::Constant':
return _unimplemented(name, "dtype")
return symbolic(g, self, dim, keepdim)
return reduce_nodim, reduce_dim
return reduce

sum = _reduce_with_dtype('ReduceSum', 'sum')
8 changes: 4 additions & 4 deletions torch/onnx/symbolic_opset9.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ def softmax(g, input, dim, dtype=None):
input = g.op('Sub', input, g.op('ReduceMax', input, axes_i=[dim], keepdims_i=1))

exp = g.op('Exp', input)
sum = g.op('ReduceSum', exp, axes_i=[dim])
sum = sym_help._reducesum_helper(g, exp, axes_i=[dim])
softmax = g.op('Div', exp, sum)
if dtype and dtype.node().kind() != 'prim::Constant':
parsed_dtype = sym_help._get_const(dtype, 'i', 'dtype')
Expand Down Expand Up @@ -2383,7 +2383,7 @@ def gather(g, self, dim, index, sparse_grad=False):
depth = size(g, self, g.op("Constant", value_t=torch.LongTensor([dim])))
index = g.op("Cast", g.op("OneHot", index, depth, values, axis_i=dim), to_i=sym_help.cast_pytorch_to_onnx[dtype])
mul = g.op("Mul", sym_help._unsqueeze_helper(g, self, [dim + 1]), index)
return g.op("ReduceSum", mul, axes_i=[dim], keepdims_i=0)
return sym_help._reducesum_helper(g, mul, axes_i=[dim], keepdims_i=0)


@parse_args('v', 'is', 'b', 'i')
Expand Down Expand Up @@ -2639,7 +2639,7 @@ def try_mask_to_index(index):
@parse_args('v', 'is', 'i')
def frobenius_norm(g, self, dim=None, keepdim=False):
sqr = g.op('Mul', self, self)
sumsqr = g.op('ReduceSum', sqr, axes_i=dim, keepdims_i=keepdim)
sumsqr = sym_help._reducesum_helper(g, sqr, axes_i=dim, keepdims_i=keepdim)
return g.op('Sqrt', sumsqr)


Expand Down Expand Up @@ -2805,7 +2805,7 @@ def kl_div(g, input, target, reduction, log_target):
elif reduction == 1:
return g.op("ReduceMean", output, keepdims_i=0)
elif reduction == 2:
return g.op("ReduceSum", output, keepdims_i=0)
return sym_help._reducesum_helper(g, output, keepdims_i=0)
else:
return sym_help._onnx_unsupported("kl_div with reduction other than none, mean, or sum. Please open a bug to "
"request ONNX export support for the missing reduction type.")
Expand Down