Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Fix opset 11 ConstantChunk with negative dim #51396

Merged
merged 4 commits into from
Feb 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
27 changes: 25 additions & 2 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -3794,14 +3794,17 @@ def forward(self, input):
@disableScriptTest()
def test_chunk(self):
class ChunkModel(torch.nn.Module):
def __init__(self):
def __init__(self, dim=1):
super(ChunkModel, self).__init__()
self.dim = dim

def forward(self, x):
return torch.chunk(x, 3, dim=1)
return torch.chunk(x, 3, dim=self.dim)

model = ChunkModel()
model.eval()
model_neg_dim = ChunkModel(-1)
model_neg_dim.eval()
x = torch.randn(1, 18)

for dim_size_ in range(13, 16):
Expand All @@ -3810,6 +3813,10 @@ def forward(self, x):
input_names=['x'],
dynamic_axes={'x': {0: 'batch_size', 1: 'dims'}})

self.run_test(model_neg_dim, x, test_with_inputs=[y],
input_names=['x'],
dynamic_axes={'x': {0: 'batch_size', 1: 'dims'}})

def test_concat(self):
class ConcatModel(torch.nn.Module):
def forward(self, x, y, z):
Expand Down Expand Up @@ -5823,6 +5830,22 @@ def make_input(batch_size):
other_input = make_input(RNN_BATCH_SIZE + 1)
self.run_test(model, other_input, batch_size=RNN_BATCH_SIZE + 1)

@disableScriptTest() # TODO: RuntimeError: Exporting the operator __is_ to ONNX is not supported
def test_transformer_encoder(self):
from torch.nn import TransformerEncoderLayer, TransformerEncoder

class MyModule(torch.nn.Module):
def __init__(self, ninp, nhead, nhid, dropout, nlayers):
super(MyModule, self).__init__()
encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)

def forward(self, input):
return self.transformer_encoder(input)

x = torch.rand(10, 32, 512)
self.run_test(MyModule(512, 8, 2048 , 0., 3), (x,), atol=1e-6)

@skipIfUnsupportedMinOpsetVersion(10)
def test_fake_quantize_per_tensor(self):
class FakeQuantizePerTensorModel(torch.nn.Module):
Expand Down
3 changes: 1 addition & 2 deletions torch/onnx/symbolic_opset11.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,8 +863,7 @@ def embedding_bag(g,
def prim_ConstantChunk(g, self, chunks, dim):
input_shape = g.op("Shape", self)
axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
axis_next = g.op("Constant", value_t=torch.tensor([dim + 1], dtype=torch.long))
input_shape_dim = g.op("Slice", input_shape, axis, axis_next)
input_shape_dim = g.op("Gather", input_shape, axis, axis_i=0)
start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
chunk_size = g.op("Constant", value_t=torch.tensor([chunks], dtype=torch.long))
chunk_size_minus_1 = g.op("Constant", value_t=torch.tensor([chunks - 1], dtype=torch.long))
Expand Down