[ONNX] Update repeat_interleave symbolic (#54312)
Add implementation for cases where:
- interleaving happens along a dim that consists of dynamic axes
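
For reference, this is the eager-mode behavior the symbolic must reproduce; when the interleaved dim is a dynamic axis, its size is unknown at export time, which is the case added here:

import torch

x = torch.tensor([[1, 2, 4],
                  [3, 4, 7]])
# Each element along dim 1 is repeated 4 times.
print(torch.repeat_interleave(x, torch.tensor(4), dim=1))
# tensor([[1, 1, 1, 1, 2, 2, 2, 2, 4, 4, 4, 4],
#         [3, 3, 3, 3, 4, 4, 4, 4, 7, 7, 7, 7]])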
shubhambhokare1 committed Apr 8, 2021
1 parent 91ea787 commit 7cd2fe4
Showing 3 changed files with 150 additions and 1 deletion.
42 changes: 42 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -3823,6 +3823,48 @@ def forward(self, x):
        x = torch.tensor([[1, 2], [3, 4]])
        self.run_test(RepeatsDimsModel2(), (x,))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_dynamic_repeat_interleave(self):
        class SingleDynamicModel(torch.nn.Module):
            def forward(self, x):
                repeats = torch.tensor(4)
                return torch.repeat_interleave(x, repeats, dim=1)

        x = torch.tensor([[1, 2, 4], [3, 4, 7]])
        another_x = torch.tensor([[7, 8], [5, 6]])
        self.run_test(SingleDynamicModel(), x, test_with_inputs=[another_x],
                      input_names=['input_1'], dynamic_axes={'input_1' : {1 : 'w'}})

        class NegDynamicModel(torch.nn.Module):
            def forward(self, x):
                repeats = torch.tensor(4)
                return torch.repeat_interleave(x, repeats, dim=-1)

        x = torch.tensor([[1, 2, 4], [3, 4, 7]])
        another_x = torch.tensor([[7, 8], [5, 6]])
        self.run_test(NegDynamicModel(), x, test_with_inputs=[another_x],
                      input_names=['input_1'], dynamic_axes={'input_1' : {1 : 'w'}})

        class SingleDynamicModel2(torch.nn.Module):
            def forward(self, x):
                repeats = torch.tensor([4])
                return torch.repeat_interleave(x, repeats, dim=0)

        x = torch.tensor([[1, 2], [3, 4]])
        another_x = torch.tensor([[7, 8], [5, 6]])
        self.run_test(SingleDynamicModel2(), x, test_with_inputs=[another_x],
                      input_names=['input_1'], dynamic_axes={'input_1' : {0 : 'h'}})

        class AllDynamicModel(torch.nn.Module):
            def forward(self, x):
                repeats = torch.tensor([4])
                return torch.repeat_interleave(x, repeats, dim=0)

        x = torch.tensor([[1, 2, 4, 16], [3, 9, 27, 81], [2, 3, 5, 7]])
        another_x = torch.tensor([[7, 8], [5, 6]])
        self.run_test(AllDynamicModel(), x, test_with_inputs=[another_x],
                      input_names=['input_1'], dynamic_axes={'input_1' : {0 : 'h', 1 : 'w'}})

    def test_view(self):
        class ViewModel(torch.nn.Module):
            def forward(self, input):
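
The tests above go through run_test, which exports the module and compares ONNX Runtime output against eager PyTorch on a second input whose size along the dynamic axis differs. A minimal standalone sketch of the same round-trip check (assumes onnxruntime is installed; "ri.onnx" and the output axis name 'w2' are illustrative choices, not part of the test suite):

import torch
import onnxruntime as ort

class SingleDynamicModel(torch.nn.Module):
    def forward(self, x):
        return torch.repeat_interleave(x, torch.tensor(4), dim=1)

model = SingleDynamicModel()
x = torch.tensor([[1, 2, 4], [3, 4, 7]])
torch.onnx.export(model, x, "ri.onnx", opset_version=11,
                  input_names=['input_1'], output_names=['out'],
                  dynamic_axes={'input_1': {1: 'w'}, 'out': {1: 'w2'}})

sess = ort.InferenceSession("ri.onnx", providers=['CPUExecutionProvider'])
another_x = torch.tensor([[7, 8], [5, 6]])  # different width than x
(out,) = sess.run(None, {'input_1': another_x.numpy()})
assert (out == model(another_x).numpy()).all()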
103 changes: 103 additions & 0 deletions torch/onnx/symbolic_opset11.py
@@ -837,3 +837,106 @@ def prim_ConstantChunk(g, self, chunks, dim):
        res.append(g.op("Slice", self, start, end, axis))
        start = end
    return res

def repeat_interleave(g, self, repeats, dim=None):
    from torch.onnx.symbolic_opset9 import reshape
    input = self
    final_dim = dim
    # If dim is None, flatten the input and return a flat output array
    if sym_help._is_none(dim):
        input = reshape(g, self, g.op("Constant", value_t=torch.tensor([-1])))
        dim = 0
    else:
        dim = sym_help._maybe_get_scalar(dim)
    repeats_dim = sym_help._get_tensor_rank(repeats)
    repeats_sizes = sym_help._get_tensor_sizes(repeats)
    input_sizes = sym_help._get_tensor_sizes(input)
    if repeats_dim is None:
        raise RuntimeError('Unsupported: ONNX export of repeat_interleave for unknown '
                           'repeats rank.')
    if repeats_sizes is None:
        raise RuntimeError('Unsupported: ONNX export of repeat_interleave for unknown '
                           'repeats size.')
    if input_sizes is None:
        raise RuntimeError('Unsupported: ONNX export of repeat_interleave for unknown '
                           'input size.')
    # Handle cases where dim is negative
    if dim < 0:
        dim += len(input_sizes)

    output_sizes = input_sizes.copy()
    perm_i = [0]
    for idx, input_size in enumerate(input_sizes):
        perm_i.append(idx + 1)
        if input_size is None:
            # Mark dynamic axes: 0 in output_sizes and -1 in input_sizes
            output_sizes[idx], input_sizes[idx] = 0, -1
    perm_i[0], perm_i[dim] = perm_i[dim], perm_i[0]

    # Cases when repeats is a single-value tensor and dim has an unknown input size
    if (repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1)) and output_sizes[dim] == 0:
        if not sym_help._is_tensor(repeats):
            repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
        reps = sym_help._size_helper(g, input, dim)
        reps = unsqueeze(g, reps, 0)
        repeats = g.op("Expand", repeats, reps)
    # There are also cases where repeats is a 1-D tensor with multiple values and dim
    # falls along one of the dynamic axes. A simple example is
    # input.shape -> [1, 1, *], where * denotes the dynamic axis, and dim = 2.
    # PyTorch can perform the interleaving when the value of * matches the number
    # of elements in repeats; for example, if * -> 2, repeats must have 2 elements.
    else:
        return torch.onnx.symbolic_opset9.repeat_interleave(g, self, repeats, final_dim)

    reps_like = g.op("ConstantOfShape", g.op("Shape", repeats),
                     value_t=torch.tensor([1], dtype=torch.long))
    r_splits = split(g, repeats, reps_like, 0)
    i_splits = split(g, input, reps_like, dim)

    output_sizes[dim], input_sizes[dim] = -1, 1

    # Create a loop to iterate over each value along the dimension
    # and perform individual interleaving using the repeats tensor.
    # The Loop op follows this pattern:
    #   input (trip_count, cond)
    #       int trip_count = ...;
    #       bool cond = ...;
    #       for (int i = 0; i < trip_count && cond; ++i) {
    #           cond = ...;
    #       }

    # Loop conditions
    loop_condition = g.op("Constant", value_t=torch.tensor(1))
    loop_condition = g.op("Cast", loop_condition, to_i=9)  # 9 == ONNX TensorProto.BOOL
    loop_len = reps
    loop = g.op("Loop", loop_len, loop_condition)

    # Loop inputs
    loop_block = _add_block(loop.node())
    block_input_iter = _add_input_to_block(loop_block)
    cond = _add_input_to_block(loop_block)

    r_split = loop_block.op("SequenceAt", r_splits, block_input_iter)
    i_split = loop_block.op("SequenceAt", i_splits, block_input_iter)

    i_split = unsqueeze(loop_block, i_split, dim + 1)
    r_concat = [loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[:dim + 1])),
                r_split,
                loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1:]))]
    r_concat = loop_block.op("Concat", *r_concat, axis_i=0)
    i_split = expand(loop_block, i_split, r_concat, None)
    i_split = reshape(loop_block, i_split, g.op("Constant", value_t=torch.LongTensor(output_sizes)))

    # Loop outputs
    cond_out = loop_block.op("Cast", loop_condition, to_i=9)
    _add_output_to_block(loop_block, cond_out)
    _add_output_to_block(loop_block, i_split)
    loop_out = loop.node().output()

    # In this loop, the outputs are scan outputs and are concatenated along
    # the zeroth dimension (by default). To concatenate along the dimension
    # provided instead, some post-processing is required.
    loop_out = g.op("Transpose", loop_out, perm_i=perm_i)
    return reshape(g, loop_out, g.op("Constant", value_t=torch.LongTensor(output_sizes)))
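
For intuition, here is a hedged eager-mode sketch of what the generated Split/Loop/Concat graph computes, assuming a 1-D repeats whose length matches the input's size along a non-negative dim (repeat_interleave_via_loop is a hypothetical name used only for illustration, not part of the exporter):

import torch

def repeat_interleave_via_loop(x, repeats, dim):
    # Mirror one Loop iteration per slice: SequenceAt -> Unsqueeze ->
    # Expand -> Reshape, then concatenate the pieces along `dim`
    # (the exporter's final Transpose + Reshape achieve the same effect).
    pieces = []
    for i, r in enumerate(repeats.tolist()):
        piece = x.narrow(dim, i, 1)               # one size-1 slice along dim
        shape = list(piece.shape)
        shape.insert(dim + 1, r)                  # target shape after Expand
        piece = piece.unsqueeze(dim + 1).expand(*shape)
        out_shape = list(x.shape)
        out_shape[dim] = r
        pieces.append(piece.reshape(out_shape))   # fold repeats back into dim
    return torch.cat(pieces, dim)

x = torch.tensor([[1, 2], [3, 4]])
assert torch.equal(repeat_interleave_via_loop(x, torch.tensor([2, 3]), 1),
                   torch.repeat_interleave(x, torch.tensor([2, 3]), dim=1))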
6 changes: 5 additions & 1 deletion torch/onnx/symbolic_opset9.py
@@ -1996,13 +1996,17 @@ def repeat_interleave(g, self, repeats, dim=None):
        if not sym_help._is_tensor(repeats):
            repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
        if input_sizes[dim] == 0:
-           raise NotImplementedError("Unsupported repeat_interleave along dimension with unknown input size")
+           return sym_help._onnx_opset_unsupported_detailed('repeat_interleave', 9, 11,
+                                                            'Unsupported along dimension with unknown input size')
        else:
            reps = input_sizes[dim]
            repeats = expand(g, repeats, g.op("Constant", value_t=torch.tensor([reps])), None)

    # Cases where repeats is a 1 dim Tensor
    elif repeats_dim == 1:
+       if input_sizes[dim] == 0:
+           return sym_help._onnx_opset_unsupported_detailed('repeat_interleave', 9, 11,
+                                                            'Unsupported along dimension with unknown input size')
        assert repeats_sizes[0] == input_sizes[dim], "repeats must have the same size as input along dim"
        reps = repeats_sizes[0]
    else:
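
On the user side, sym_help._onnx_opset_unsupported_detailed raises a RuntimeError that names the op, the failing opset, and the opset to try, so the remedy for this path is to re-export at opset 11 or later, where the Loop-based symbolic above applies. A minimal sketch (model, x, and the file name are placeholders):

torch.onnx.export(model, x, "model.onnx",
                  opset_version=11,  # >= 11 picks up the dynamic-axis symbolic
                  input_names=['input_1'],
                  dynamic_axes={'input_1': {1: 'w'}})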
