Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Handle dynamic input axes for prim_ConstantChunk #48176

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
f168d8a
Reimplement var_mean
jiafatom Nov 2, 2020
57a633c
Merge branch 'master' into var_mean
jiafatom Nov 2, 2020
8b00db5
Merge branch 'var_mean' of https://github.com/jiafatom/pytorch
jiafatom Nov 3, 2020
44a3477
Merge branch 'master' of https://github.com/pytorch/pytorch
jiafatom Nov 3, 2020
6296050
Merge branch 'master' of https://github.com/pytorch/pytorch
jiafatom Nov 3, 2020
95e0231
Merge branch 'master' of https://github.com/pytorch/pytorch
jiafatom Nov 6, 2020
e26ff89
Merge branch 'master' of https://github.com/pytorch/pytorch
jiafatom Nov 9, 2020
c8b9bee
Merge branch 'master' of https://github.com/pytorch/pytorch
jiafatom Nov 18, 2020
3752fd0
Merge branch 'master' of https://github.com/pytorch/pytorch
jiafatom Nov 18, 2020
8d8a83d
Handle dynamic input axes for prim_ConstantChunk
jiafatom Nov 18, 2020
c883103
Handle dynamic input axes for prim_ConstantChunk
jiafatom Nov 18, 2020
48233b2
Merge branch 'split_un' of https://github.com/jiafatom/pytorch into s…
jiafatom Nov 18, 2020
ac0da47
Merge branch 'split_un' of https://github.com/jiafatom/pytorch into s…
jiafatom Nov 18, 2020
909f4d9
Merge branch 'split_un' of https://github.com/jiafatom/pytorch into s…
jiafatom Nov 18, 2020
ddf6038
Merge branch 'split_un' of https://github.com/jiafatom/pytorch into s…
jiafatom Nov 18, 2020
73a3f58
Merge branch 'split_un' of https://github.com/jiafatom/pytorch into s…
jiafatom Nov 23, 2020
df628ab
Merge branch 'split_un' of https://github.com/jiafatom/pytorch into s…
jiafatom Nov 23, 2020
51e0bb8
Merge branch 'split_un' of https://github.com/jiafatom/pytorch into s…
jiafatom Nov 23, 2020
adcf6bb
Merge branch 'master' of https://github.com/pytorch/pytorch
jiafatom Nov 30, 2020
112e2ab
Merge branch 'master' of https://github.com/pytorch/pytorch
jiafatom Nov 30, 2020
1929925
Merge branch 'split_un' of https://github.com/jiafatom/pytorch into s…
jiafatom Nov 30, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 19 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -3240,6 +3240,25 @@ def forward(self, input):
x = torch.randn(5, 4, 3)
self.run_test(SplitModel2(), x)

@skipIfUnsupportedMinOpsetVersion(11)
def test_chunk(self):
class ChunkModel(torch.nn.Module):
def __init__(self):
super(ChunkModel, self).__init__()

def forward(self, x):
return torch.chunk(x, 3, dim=1)

model = ChunkModel()
model.eval()
x = torch.randn(1, 18)

for dim_size_ in range(13, 16):
y = torch.randn(1, dim_size_)
jiafatom marked this conversation as resolved.
Show resolved Hide resolved
self.run_test(model, x, test_with_inputs=[y],
input_names=['x'],
dynamic_axes={'x': {0: 'batch_size', 1: 'dims'}})

def test_concat(self):
class ConcatModel(torch.nn.Module):
def forward(self, x, y, z):
Expand Down
19 changes: 19 additions & 0 deletions torch/onnx/symbolic_opset11.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,3 +944,22 @@ def embedding_bag(g,
# aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
# But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
return loop.node().output(), None, None, None


def prim_ConstantChunk(g, self, chunks, dim):
jiafatom marked this conversation as resolved.
Show resolved Hide resolved
input_shape = g.op("Shape", self)
axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
axis_next = g.op("Constant", value_t=torch.tensor([dim + 1], dtype=torch.long))
input_shape_dim = g.op("Slice", input_shape, axis, axis_next)
start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
chunk_size = g.op("Constant", value_t=torch.tensor([chunks], dtype=torch.long))
chunk_size_minus_1 = g.op("Constant", value_t=torch.tensor([chunks - 1], dtype=torch.long))
input_shape_dim_shift = g.op("Add", input_shape_dim, chunk_size_minus_1)
chunk_dim = g.op("Div", input_shape_dim_shift, chunk_size)
res = []
for i in range(chunks):
index = g.op("Constant", value_t=torch.tensor([i + 1], dtype=torch.long))
end = g.op("Mul", chunk_dim, index)
res.append(g.op("Slice", self, start, end, axis))
jiafatom marked this conversation as resolved.
Show resolved Hide resolved
start = end
return res