[ONNX] Update Reducesum operator for opset 13 (#50532)
* update symbolic for squeeze/unsqueeze

* update c++ unsqueeze/squeeze creation

* clang format

* enable tests

* clang format

* remove prints

* remove magic number

* add helper function

* fix build issue

* update opset9 symbolic with helper function

* fix utility test

* fix prim_fallthrough opset skip

* enable reducesum opset 13

* enable embedding_bag, which contains a reducesum op

* add ReduceSum helper

* remove block_listed_operators

* remove local test code

* remove embedding_bag() in opset13 file

* remove unused import

Co-authored-by: BowenBao <bowbao@microsoft.com>
Co-authored-by: hwangdeyu <deyhuang@qq.com>

[ghstack-poisoned]
BowenBao committed Jan 21, 2021
1 parent a1f2867 commit 057a126
Showing 6 changed files with 57 additions and 16 deletions.
4 changes: 0 additions & 4 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -2754,7 +2754,6 @@ def forward(self, input):
         x = torch.randn(4, 5, dtype=torch.float)
         self.run_test(ReducedOpModule(), x)
 
-    @skipIfUnsupportedOpsetVersion([13])
     def test_reduced_sum(self):
         return self._test_reduced_ops(op=torch.sum)
 
@@ -4314,7 +4313,6 @@ def forward(self, input):
 
     @disableScriptTest()  # error in propagate as assign input shape
     @skipIfUnsupportedMinOpsetVersion(10)
-    @skipIfUnsupportedOpsetVersion([12, 13])  # Due to ONNX Loop shape inference issue
     def test_embedding_bag(self):
         model = torch.nn.EmbeddingBag(10, 5, mode='sum', scale_grad_by_freq=True)
         input = torch.randint(10, (7,))
@@ -4331,7 +4329,6 @@ def test_embedding_bag(self):
         self.run_test(model, (input))
 
     @skipIfUnsupportedMinOpsetVersion(11)
-    @skipIfUnsupportedOpsetVersion([12, 13])  # Due to ONNX Loop shape inference issue
     def test_embedding_bag_1d_per_sample_weights(self):
         class EmbeddingModel(torch.nn.Module):
             def forward(self, embedding_matrix, input, offset, weights):
@@ -4346,7 +4343,6 @@ def forward(self, embedding_matrix, input, offset, weights):
         self.run_test(model, (embedding_matrix, x, offset, w))
 
     @skipIfUnsupportedMinOpsetVersion(11)
-    @skipIfUnsupportedOpsetVersion([12, 13])  # Due to ONNX Loop shape inference issue
     def test_embedding_bag_2d_per_sample_weights(self):
         class EmbeddingModel(torch.nn.Module):
             def forward(self, embedding_matrix, input, weights):
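Note on the test changes above: test_reduced_sum was skipped at opset 13, and the three embedding_bag tests at opsets 12 and 13, because their graphs emit ReduceSum (for embedding_bag at opset 11, inside an ONNX Loop body where shape inference tripped on the pre-13 form); with the version-aware helper below, the skips can be dropped. As a rough sketch of what run_test verifies, assuming onnxruntime is installed; the real harness handles scripting, multiple inputs, and tolerance configuration, and check_export is a made-up name:

import io

import numpy as np
import onnxruntime
import torch

def check_export(model, x, opset):
    # Export the eager-mode model to an in-memory ONNX file at the given opset.
    f = io.BytesIO()
    torch.onnx.export(model, x, f, opset_version=opset)
    # Run the exported graph under ONNX Runtime.
    sess = onnxruntime.InferenceSession(f.getvalue())
    ort_out = sess.run(None, {sess.get_inputs()[0].name: x.numpy()})[0]
    # The exported graph should match PyTorch's own output.
    np.testing.assert_allclose(model(x).numpy(), ort_out, rtol=1e-3, atol=1e-5)

class SumModel(torch.nn.Module):
    def forward(self, x):
        return torch.sum(x, dim=1)

check_export(SumModel(), torch.randn(3, 4), opset=13)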
11 changes: 11 additions & 0 deletions torch/onnx/symbolic_helper.py
@@ -335,6 +335,17 @@ def _squeeze_helper(g, input, axes_i):
     else:
         return g.op("Squeeze", input, axes_i=axes_i)
 
+def _reducesum_helper(g, input, axes_i=None, keepdims_i=1, noop_with_empty_axes_i=0):
+    keepdims_i = _maybe_get_const(keepdims_i, 'i')
+    if _export_onnx_opset_version >= 13:
+        if axes_i:
+            if not _is_value(axes_i):
+                axes_i = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long))
+            return g.op("ReduceSum", input, axes_i, keepdims_i=keepdims_i, noop_with_empty_axes_i=noop_with_empty_axes_i)
+        return g.op("ReduceSum", input, keepdims_i=keepdims_i, noop_with_empty_axes_i=noop_with_empty_axes_i)
+    else:
+        return g.op("ReduceSum", input, axes_i=axes_i, keepdims_i=keepdims_i)
+
 def _interpolate_size_to_scales(g, input, output_size, dim):
     output_size = _maybe_get_const(output_size, 'is')
     if _is_value(output_size):
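Note: the opset branch in _reducesum_helper mirrors a schema change in ONNX itself. Through opset 12, ReduceSum takes its reduction axes as a node attribute; from opset 13 the axes become an optional second input (so they can be computed dynamically), and the new noop_with_empty_axes attribute defines what an empty axes list means. A minimal sketch of the two spellings using the standard onnx.helper API; the tensor names "x", "y", and "axes" are made up:

import onnx
from onnx import helper

# Opset <= 12: axes is an attribute baked into the node.
node_v12 = helper.make_node(
    "ReduceSum", inputs=["x"], outputs=["y"],
    axes=[0], keepdims=0)

# Opset >= 13: axes arrives as a second input, here produced by a
# Constant node, which is the same shape _reducesum_helper emits for static axes.
axes_node = helper.make_node(
    "Constant", inputs=[], outputs=["axes"],
    value=helper.make_tensor("axes", onnx.TensorProto.INT64, [1], [0]))
node_v13 = helper.make_node(
    "ReduceSum", inputs=["x", "axes"], outputs=["y"],
    keepdims=0, noop_with_empty_axes=0)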
2 changes: 1 addition & 1 deletion torch/onnx/symbolic_opset10.py
@@ -231,7 +231,7 @@ def embedding_bag(g,
         per_sample_weights_row = sym_help._unsqueeze_helper(g, per_sample_weights_row, [1])
         embeddings = g.op("Mul", embeddings, per_sample_weights_row)
     if mode == 0:
-        embeddings = g.op("ReduceSum", embeddings, axes_i=[0], keepdims_i=0)
+        embeddings = sym_help._reducesum_helper(g, embeddings, axes_i=[0], keepdims_i=0)
     elif mode == 1:
         embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
     else:
2 changes: 1 addition & 1 deletion torch/onnx/symbolic_opset11.py
@@ -935,7 +935,7 @@ def embedding_bag(g,
             per_sample_weights_row = sym_help._unsqueeze_helper(loop_block, per_sample_weights_row, [1])
             embeddings = loop_block.op("Mul", embeddings, per_sample_weights_row)
         if mode == 0:
-            embeddings = loop_block.op("ReduceSum", embeddings, axes_i=[0], keepdims_i=0)
+            embeddings = sym_help._reducesum_helper(loop_block, embeddings, axes_i=[0], keepdims_i=0)
         elif mode == 1:
             embeddings = loop_block.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
         else:
46 changes: 40 additions & 6 deletions torch/onnx/symbolic_opset13.py
@@ -2,15 +2,16 @@
 # see Note [Edit Symbolic Files] in symbolic_helper.py
 
 # This file exports ONNX ops for opset 13
-from torch.onnx.symbolic_helper import _block_list_in_opset
 import torch
 import torch.onnx.symbolic_helper as sym_help
-from torch.onnx.symbolic_helper import parse_args
+from torch.onnx.symbolic_helper import parse_args, _unimplemented
+from torch.onnx.symbolic_opset9 import overload_by_arg_count, _maybe_cast_reduce_op_input
 
-block_listed_operators = ['embedding_bag']
 
-for block_listed_op in block_listed_operators:
-    vars()[block_listed_op] = _block_list_in_opset(block_listed_op)
+# EDITING THIS FILE? READ THIS FIRST!
+# see Note [Edit Symbolic Files] in symbolic_helper.py
+
+# This file exports ONNX ops for opset 13
 
 
 @parse_args('v', 'i', 'none')
@@ -38,7 +39,7 @@ def frobenius_norm(g, self, dim=None, keepdim=False):
     if not sym_help._is_value(dim_val) and len(dim_val) == 0:
         return g.op("ReduceL2", self, keepdims_i=0)
     sqr = g.op('Mul', self, self)
-    sumsqr = g.op('ReduceSum', sqr, dim, keepdims_i=keepdim)
+    sumsqr = sym_help._reducesum_helper(g, sqr, dim, keepdims_i=keepdim)
     return g.op('Sqrt', sumsqr)
 
 
@@ -108,3 +109,36 @@ def unbind(g, self, dim=0, _outputs=None):
 def glu(g, input, dim):
     first, second = g.op('Split', input, dim, outputs=2)
     return g.op('Mul', first, g.op('Sigmoid', second))
+
+
+def _reduce_op_symbolic(onnx_op_name):
+    def symbolic(g, self, dim=None, keepdim=None):
+        self = _maybe_cast_reduce_op_input(g, self)
+        if dim is None:
+            # all-reduce path
+            return g.op(onnx_op_name, self, keepdims_i=0)
+        else:
+            keepdim = sym_help._get_const(keepdim, 'i', 'keepdim')
+            return g.op(onnx_op_name, self, dim, keepdims_i=keepdim)
+    return symbolic
+
+def _reduce_with_dtype(onnx_op, name):
+    symbolic = _reduce_op_symbolic(onnx_op)
+
+    @overload_by_arg_count
+    def reduce(g, *args, **kwargs):
+        @parse_args('v', 'none')
+        def reduce_nodim(g, self, dtype):
+            if dtype.node().kind() != 'prim::Constant':
+                return _unimplemented(name, "dtype")
+            return symbolic(g, self)
+
+        @parse_args('v', 'v', 'i', 'none')
+        def reduce_dim(g, self, dim, keepdim, dtype):
+            if dtype.node().kind() != 'prim::Constant':
+                return _unimplemented(name, "dtype")
+            return symbolic(g, self, dim, keepdim)
+        return reduce_nodim, reduce_dim
+    return reduce
+
+sum = _reduce_with_dtype('ReduceSum', 'sum')
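Note: overload_by_arg_count dispatches between the two inner symbolics by arity, matching the two aten overloads of sum. The no-dim overload becomes a full reduction (reduce_nodim, no axes at all), while the dim overload passes dim through as a graph value, which opset 13 accepts directly as the axes input. A hedged usage sketch; the module names and output paths are illustrative:

import torch

class SumAll(torch.nn.Module):
    def forward(self, x):
        return x.sum()  # no dim: hits the reduce_nodim path

class SumDim(torch.nn.Module):
    def forward(self, x):
        return x.sum(dim=1, keepdim=True)  # dim given: hits reduce_dim

x = torch.randn(3, 4)
torch.onnx.export(SumAll(), x, "sum_all.onnx", opset_version=13)
torch.onnx.export(SumDim(), x, "sum_dim.onnx", opset_version=13)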
8 changes: 4 additions & 4 deletions torch/onnx/symbolic_opset9.py
@@ -753,7 +753,7 @@ def softmax(g, input, dim, dtype=None):
     input = g.op('Sub', input, g.op('ReduceMax', input, axes_i=[dim], keepdims_i=1))
 
     exp = g.op('Exp', input)
-    sum = g.op('ReduceSum', exp, axes_i=[dim])
+    sum = sym_help._reducesum_helper(g, exp, axes_i=[dim])
     softmax = g.op('Div', exp, sum)
     if dtype and dtype.node().kind() != 'prim::Constant':
         parsed_dtype = sym_help._get_const(dtype, 'i', 'dtype')
@@ -2383,7 +2383,7 @@ def gather(g, self, dim, index, sparse_grad=False):
     depth = size(g, self, g.op("Constant", value_t=torch.LongTensor([dim])))
     index = g.op("Cast", g.op("OneHot", index, depth, values, axis_i=dim), to_i=sym_help.cast_pytorch_to_onnx[dtype])
     mul = g.op("Mul", sym_help._unsqueeze_helper(g, self, [dim + 1]), index)
-    return g.op("ReduceSum", mul, axes_i=[dim], keepdims_i=0)
+    return sym_help._reducesum_helper(g, mul, axes_i=[dim], keepdims_i=0)
 
 
 @parse_args('v', 'is', 'b', 'i')
Expand Down Expand Up @@ -2639,7 +2639,7 @@ def try_mask_to_index(index):
@parse_args('v', 'is', 'i')
def frobenius_norm(g, self, dim=None, keepdim=False):
sqr = g.op('Mul', self, self)
sumsqr = g.op('ReduceSum', sqr, axes_i=dim, keepdims_i=keepdim)
sumsqr = sym_help._reducesum_helper(g, sqr, axes_i=dim, keepdims_i=keepdim)
return g.op('Sqrt', sumsqr)


@@ -2805,7 +2805,7 @@ def kl_div(g, input, target, reduction, log_target):
     elif reduction == 1:
         return g.op("ReduceMean", output, keepdims_i=0)
     elif reduction == 2:
-        return g.op("ReduceSum", output, keepdims_i=0)
+        return sym_help._reducesum_helper(g, output, keepdims_i=0)
     else:
         return sym_help._onnx_unsupported("kl_div with reduction other than none, mean, or sum. Please open a bug to "
                                           "request ONNX export support for the missing reduction type.")
