Skip to content

Commit

Permalink
Enable embedding_bag, which contains a ReduceSum op
Browse files Browse the repository at this point in the history
  • Loading branch information
hwangdeyu committed Jan 14, 2021
1 parent d4a35e5 commit 4061597
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 5 deletions.
3 changes: 0 additions & 3 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4312,7 +4312,6 @@ def forward(self, input):

@disableScriptTest() # error in propagate as assign input shape
@skipIfUnsupportedMinOpsetVersion(10)
@skipIfUnsupportedOpsetVersion([12, 13]) # Due to ONNX Loop shape inference issue
def test_embedding_bag(self):
model = torch.nn.EmbeddingBag(10, 5, mode='sum', scale_grad_by_freq=True)
input = torch.randint(10, (7,))
Expand All @@ -4329,7 +4328,6 @@ def test_embedding_bag(self):
self.run_test(model, (input))

@skipIfUnsupportedMinOpsetVersion(11)
@skipIfUnsupportedOpsetVersion([12, 13]) # Due to ONNX Loop shape inference issue
def test_embedding_bag_1d_per_sample_weights(self):
class EmbeddingModel(torch.nn.Module):
def forward(self, embedding_matrix, input, offset, weights):
Expand All @@ -4344,7 +4342,6 @@ def forward(self, embedding_matrix, input, offset, weights):
self.run_test(model, (embedding_matrix, x, offset, w))

@skipIfUnsupportedMinOpsetVersion(11)
@skipIfUnsupportedOpsetVersion([12, 13]) # Due to ONNX Loop shape inference issue
def test_embedding_bag_2d_per_sample_weights(self):
class EmbeddingModel(torch.nn.Module):
def forward(self, embedding_matrix, input, weights):
Expand Down
73 changes: 71 additions & 2 deletions torch/onnx/symbolic_opset13.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
# see Note [Edit Symbolic Files] in symbolic_helper.py

# This file exports ONNX ops for opset 13
from sys import maxsize
from torch.onnx.symbolic_helper import _block_list_in_opset
import torch
import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_helper import parse_args
from torch.onnx.symbolic_opset9 import overload_by_arg_count, _maybe_cast_reduce_op_input
from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_block

block_listed_operators = ['embedding_bag']
# Operators that must be rejected when exporting with this opset. The list is
# intentionally empty: a placeholder entry such as '' would make the loop below
# bind a blocker under the empty-string name in this module's namespace.
block_listed_operators = []

# For each listed operator, bind a module-level name to the blocker returned by
# _block_list_in_opset so that the operator resolves to it in this module.
for block_listed_op in block_listed_operators:
    vars()[block_listed_op] = _block_list_in_opset(block_listed_op)
Expand Down Expand Up @@ -140,4 +142,71 @@ def reduce_dim(g, self, dim, keepdim, dtype):
return reduce_nodim, reduce_dim
return reduce

sum = _reduce_with_dtype('ReduceSum', 'sum')
# Symbolic bound to the name `sum`, produced by _reduce_with_dtype for the ONNX
# 'ReduceSum' op.
# NOTE(review): this module-level name shadows the builtin `sum` — presumably on
# purpose so the symbolic can be found by operator name; confirm before renaming.
sum = _reduce_with_dtype('ReduceSum', 'sum')


@parse_args('v', 'v', 'v', 'i', 'i', 'i', 'v', 'i')
def embedding_bag(g,
                  embedding_matrix,
                  indices,
                  offsets,
                  scale_grad_by_freq,
                  mode,
                  sparse,
                  per_sample_weights,
                  include_last_offset):
    """Build the ONNX subgraph for ``embedding_bag`` as a ``Loop`` over bags.

    Each loop iteration handles one bag: it slices that bag's entries out of
    ``indices`` (bounded by two consecutive offsets), gathers the matching rows
    of ``embedding_matrix``, optionally multiplies them by the corresponding
    slice of ``per_sample_weights``, and reduces along axis 0. The reduced row
    is a scan output of the loop, so no sequence has to be built explicitly.

    Args:
        g: graph the nodes are appended to.
        embedding_matrix: value holding the embedding table.
        indices: value holding the flattened bag indices.
        offsets: value holding the start position of each bag in ``indices``.
        scale_grad_by_freq: when truthy and the exporter is in training mode,
            the export is reported as unsupported.
        mode: 0 selects ReduceSum, 1 selects ReduceMean, any other value
            selects ReduceMax.
        sparse: accepted for signature compatibility; not read here.
        per_sample_weights: optional per-index weights (may be a None value).
        include_last_offset: when falsy, the length of ``indices`` is appended
            to ``offsets`` so that the last bag has an end bound.

    Returns:
        A 4-tuple ``(output, None, None, None)``. aten::embedding_bag returns
        output, offset2bag, bag_size and max_indices, but the last three are
        not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
    """
    unsupported = scale_grad_by_freq and sym_help._training_mode
    if unsupported:
        return sym_help._onnx_unsupported('embedding_bag with scale_grad_by_freq for training mode')

    # Loop "keep going" flag: the constant 1 cast with to_i=9 (ONNX BOOL).
    keep_going = g.op("Cast", g.op("Constant", value_t=torch.tensor(1)), to_i=9)
    # 1-element tensor [0]; used as the axes input of Slice and ReduceSum.
    axis_zero = g.op("Constant", value_t=torch.tensor([0]))

    # Length of `indices` along dim 0, unsqueezed so it can be concatenated
    # onto `offsets`.
    dim0 = g.op("Constant", value_t=torch.tensor(0))
    num_indices = sym_help._size_helper(g, indices, dim0)
    num_indices = sym_help._unsqueeze_helper(g, num_indices, [0])

    if include_last_offset:
        bag_bounds = offsets
    else:
        bag_bounds = g.op("Concat", offsets, num_indices, axis_i=0)

    # bag_starts[i] / bag_ends[i] delimit the slice of `indices` owned by bag i.
    bag_starts = sym_help._slice_helper(g, bag_bounds, axes=[0], starts=[0], ends=[maxsize], steps=[1])
    bag_ends = sym_help._slice_helper(g, bag_bounds, axes=[0], starts=[1], ends=[maxsize], steps=[1])

    # One iteration per bag.
    num_bags = sym_help._size_helper(g, bag_ends, g.op("Constant", value_t=torch.tensor(0)))
    loop = g.op("Loop", num_bags, keep_going)

    body = _add_block(loop.node())
    bag_index = _add_input_to_block(body)
    # Second block input is the loop condition; it is declared but its value
    # is not read inside the body.
    _add_input_to_block(body)

    lo = body.op("Gather", bag_starts, bag_index, axis_i=0)
    hi = body.op("Gather", bag_ends, bag_index, axis_i=0)
    lo = sym_help._unsqueeze_helper(body, lo, [0])
    hi = sym_help._unsqueeze_helper(body, hi, [0])

    bag_indices = body.op("Slice", indices, lo, hi, axis_zero)
    bag_rows = body.op("Gather", embedding_matrix, bag_indices, axis_i=0)

    has_weights = not sym_help._is_none(per_sample_weights)
    if has_weights:
        bag_weights = body.op("Slice", per_sample_weights, lo, hi, axis_zero)
        # Add a trailing axis so the weights multiply against each gathered row.
        bag_weights = sym_help._unsqueeze_helper(body, bag_weights, [1])
        bag_rows = body.op("Mul", bag_rows, bag_weights)

    if mode == 0:
        # ReduceSum receives its axes as an input tensor here.
        bag_out = body.op("ReduceSum", bag_rows, axis_zero, keepdims_i=0)
    else:
        # ReduceMean / ReduceMax receive their axes as an attribute.
        reduce_kind = "ReduceMean" if mode == 1 else "ReduceMax"
        bag_out = body.op(reduce_kind, bag_rows, axes_i=[0], keepdims_i=0)

    # Block outputs, in order: the condition for the next iteration, then the
    # per-bag result that becomes the loop's scan output.
    _add_output_to_block(body, body.op("Cast", keep_going, to_i=9))
    _add_output_to_block(body, bag_out)

    return loop.node().output(), None, None, None

0 comments on commit 4061597

Please sign in to comment.