From ec1a37efa20d52e87ad673934e0b40d6b15f10d0 Mon Sep 17 00:00:00 2001
From: BowenBao
Date: Fri, 29 Jan 2021 17:44:19 -0800
Subject: [PATCH 1/4] fix opset 11 ConstantChunk with negative dim

---
 test/onnx/test_pytorch_onnx_onnxruntime.py | 26 ++++++++++++++++++++--
 torch/onnx/symbolic_opset11.py             |  3 +--
 2 files changed, 25 insertions(+), 4 deletions(-)

diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 68bff71381937..e321829e85732 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -3794,14 +3794,17 @@ def forward(self, input):
     @disableScriptTest()
     def test_chunk(self):
         class ChunkModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self, dim=1):
                 super(ChunkModel, self).__init__()
+                self.dim = dim
 
             def forward(self, x):
-                return torch.chunk(x, 3, dim=1)
+                return torch.chunk(x, 3, dim=self.dim)
 
         model = ChunkModel()
         model.eval()
+        model_neg_dim = ChunkModel(-1)
+        model_neg_dim.eval()
         x = torch.randn(1, 18)
 
         for dim_size_ in range(13, 16):
@@ -3810,6 +3813,10 @@ def forward(self, x):
                           input_names=['x'],
                           dynamic_axes={'x': {0: 'batch_size', 1: 'dims'}})
 
+            self.run_test(model_neg_dim, x, test_with_inputs=[y],
+                          input_names=['x'],
+                          dynamic_axes={'x': {0: 'batch_size', 1: 'dims'}})
+
     def test_concat(self):
         class ConcatModel(torch.nn.Module):
             def forward(self, x, y, z):
@@ -5823,6 +5830,21 @@ def make_input(batch_size):
         other_input = make_input(RNN_BATCH_SIZE + 1)
         self.run_test(model, other_input, batch_size=RNN_BATCH_SIZE + 1)
 
+    @disableScriptTest()
+    def test_transformer_encoder(self):
+        from torch.nn import TransformerEncoderLayer, TransformerEncoder
+        class MyModule(torch.nn.Module):
+            def __init__(self, ninp, nhead, nhid, dropout, nlayers):
+                super(MyModule, self).__init__()
+                encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
+                self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
+
+            def forward(self, input):
+                return self.transformer_encoder(input)
+
+        x = torch.rand(10, 32, 512)
+        self.run_test(MyModule(512, 8, 2048 , 0., 3), (x,))
+
     @skipIfUnsupportedMinOpsetVersion(10)
     def test_fake_quantize_per_tensor(self):
         class FakeQuantizePerTensorModel(torch.nn.Module):
diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py
index 3792f77ae3772..e595e2119ef41 100644
--- a/torch/onnx/symbolic_opset11.py
+++ b/torch/onnx/symbolic_opset11.py
@@ -863,8 +863,7 @@ def embedding_bag(g,
 def prim_ConstantChunk(g, self, chunks, dim):
     input_shape = g.op("Shape", self)
     axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
-    axis_next = g.op("Constant", value_t=torch.tensor([dim + 1], dtype=torch.long))
-    input_shape_dim = g.op("Slice", input_shape, axis, axis_next)
+    input_shape_dim = g.op("Gather", input_shape, axis, axis_i=0)
     start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
     chunk_size = g.op("Constant", value_t=torch.tensor([chunks], dtype=torch.long))
     chunk_size_minus_1 = g.op("Constant", value_t=torch.tensor([chunks - 1], dtype=torch.long))

From 0475871774c8c3c656425bff38f0402978de054c Mon Sep 17 00:00:00 2001
From: BowenBao
Date: Fri, 29 Jan 2021 18:09:47 -0800
Subject: [PATCH 2/4] fix flake8 test

---
 test/onnx/test_pytorch_onnx_onnxruntime.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index e321829e85732..5144b02348182 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -5833,6 +5833,7 @@ def make_input(batch_size):
     @disableScriptTest()
     def test_transformer_encoder(self):
         from torch.nn import TransformerEncoderLayer, TransformerEncoder
+
         class MyModule(torch.nn.Module):
             def __init__(self, ninp, nhead, nhid, dropout, nlayers):
                 super(MyModule, self).__init__()

From c41bf171f6a32010f5192cf7486ef5b11d8a3725 Mon Sep 17 00:00:00 2001
From: BowenBao
Date: Mon, 1 Feb 2021 11:50:51 -0800
Subject: [PATCH 3/4] relax precision by a little

---
 test/onnx/test_pytorch_onnx_onnxruntime.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 5144b02348182..14b3ac5ba44eb 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -5844,7 +5844,7 @@ def forward(self, input):
                 return self.transformer_encoder(input)
 
         x = torch.rand(10, 32, 512)
-        self.run_test(MyModule(512, 8, 2048 , 0., 3), (x,))
+        self.run_test(MyModule(512, 8, 2048 , 0., 3), (x,), atol=1e-6)
 
     @skipIfUnsupportedMinOpsetVersion(10)
     def test_fake_quantize_per_tensor(self):
         class FakeQuantizePerTensorModel(torch.nn.Module):

From 6be1604279734966dbd8689f09f220209e103ceb Mon Sep 17 00:00:00 2001
From: BowenBao
Date: Mon, 1 Feb 2021 14:36:22 -0800
Subject: [PATCH 4/4] add comment explaining skipped test

---
 test/onnx/test_pytorch_onnx_onnxruntime.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 14b3ac5ba44eb..29f563cf727ee 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -5830,7 +5830,7 @@ def make_input(batch_size):
         other_input = make_input(RNN_BATCH_SIZE + 1)
         self.run_test(model, other_input, batch_size=RNN_BATCH_SIZE + 1)
 
-    @disableScriptTest()
+    @disableScriptTest()  # TODO: RuntimeError: Exporting the operator __is_ to ONNX is not supported
     def test_transformer_encoder(self):
         from torch.nn import TransformerEncoderLayer, TransformerEncoder
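
Background on the fix, separate from the patches above: the old opset 11 prim_ConstantChunk symbolic read the size of the split dimension with Slice(input_shape, [dim], [dim + 1]). For a negative dim that range is empty (dim = -1 gives starts = [-1], ends = [0]), so the exported graph could not compute the chunk size; Gather(input_shape, [dim]) accepts negative indices directly, which is what the new symbolic relies on. The sketch below is a minimal repro of what the negative-dim test_chunk case exercises, assuming torch, numpy and onnxruntime are installed; the ChunkModel name and the tolerances are illustrative, not taken from the patch.

# Standalone repro sketch (not part of the patches): export torch.chunk along a
# negative dim at opset 11 and compare ONNX Runtime's output with eager PyTorch.
import io

import numpy as np
import onnxruntime as ort
import torch


class ChunkModel(torch.nn.Module):
    """Splits the input into 3 chunks along a configurable dim (here -1)."""

    def __init__(self, dim=-1):
        super(ChunkModel, self).__init__()
        self.dim = dim

    def forward(self, x):
        return torch.chunk(x, 3, dim=self.dim)


model = ChunkModel(dim=-1).eval()
x = torch.randn(1, 18)

# Export with opset 11 so prim::ConstantChunk goes through the patched symbolic.
f = io.BytesIO()
torch.onnx.export(model, (x,), f, opset_version=11, input_names=["x"])

sess = ort.InferenceSession(f.getvalue(), providers=["CPUExecutionProvider"])
ort_outs = sess.run(None, {"x": x.numpy()})

# Each of the 3 chunks should match the eager result.
for expected, actual in zip(model(x), ort_outs):
    np.testing.assert_allclose(expected.numpy(), actual, rtol=1e-3, atol=1e-7)

The same check with dim=1 passes with or without the patch; only the negative-dim export depends on the Gather-based symbolic.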