Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX]Update Dropout Export #37641

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion .jenkins/caffe2/test.sh
Expand Up @@ -148,7 +148,7 @@ if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then
# default pip version is too old(9.0.2), unable to support tag `manylinux2010`.
# Fix the pip error: Couldn't find a version that satisfies the requirement
sudo pip install --upgrade pip
pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.2.0.dev202005041
pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.3.0.dev202005121
fi
"$ROOT_DIR/scripts/onnx/test.sh"
fi
46 changes: 46 additions & 0 deletions test/onnx/expect/TestOperators.test_dropout_opset12.expect
@@ -0,0 +1,46 @@
ir_version: 6
producer_name: "pytorch"
producer_version: "1.6"
graph {
node {
input: "x"
output: "1"
name: "ReduceMax_0"
op_type: "ReduceMax"
attribute {
name: "keepdims"
i: 0
type: INT
}
}
name: "torch-jit-export"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
output {
name: "1"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
}
opset_import {
version: 12
}
Expand Up @@ -15,18 +15,32 @@ graph {
type: TENSOR
}
}
node {
output: "2"
name: "Constant_1"
op_type: "Constant"
attribute {
name: "value"
t {
data_type: 9
raw_data: "\001"
}
type: TENSOR
}
}
node {
input: "x"
input: "1"
output: "2"
input: "2"
output: "3"
name: "Dropout_1"
output: "4"
name: "Dropout_2"
op_type: "Dropout"
}
node {
input: "2"
output: "4"
name: "ReduceMax_2"
input: "3"
output: "5"
name: "ReduceMax_3"
op_type: "ReduceMax"
attribute {
name: "keepdims"
Expand All @@ -52,7 +66,7 @@ graph {
}
}
output {
name: "4"
name: "5"
type {
tensor_type {
elem_type: 1
Expand Down
6 changes: 6 additions & 0 deletions test/onnx/test_operators.py
Expand Up @@ -681,6 +681,12 @@ def test_dropout_training(self):
x = torch.randn(3, 4, requires_grad=True)
self.assertONNX(lambda x: torch.max(functional.dropout(x)), x, training=torch.onnx.TrainingMode.TRAINING)

@unittest.skip("disable test until onnx submodule is updated")
def test_dropout_opset12(self):
    """Export torch.max(dropout(x)) at opset 12 and compare against the
    recorded expect file (eval-mode export, so dropout is a no-op)."""
    inp = torch.randn(3, 4, requires_grad=True)

    def model(t):
        return torch.max(functional.dropout(t))

    self.assertONNX(model, inp, opset_version=12)

@unittest.skip("disable test until onnx submodule is updated")
def test_dropout_training_opset12(self):
    """Export torch.max(dropout(x)) at opset 12 in TRAINING mode and
    compare against the recorded expect file."""
    inp = torch.randn(3, 4, requires_grad=True)

    def model(t):
        return torch.max(functional.dropout(t))

    self.assertONNX(model, inp, opset_version=12,
                    training=torch.onnx.TrainingMode.TRAINING)
Expand Down
13 changes: 13 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Expand Up @@ -3146,6 +3146,19 @@ def forward(self, mat1, mat2):
mat2 = torch.randn(3, 3)
self.run_test(M(), input=(mat1, mat2))

def test_dropout(self):
    """Export a module wrapping nn.Dropout and compare its output with
    ONNX Runtime via run_test.

    NOTE(review): presumably run_test exports in eval mode by default,
    making dropout the identity so outputs are comparable — confirm
    against the run_test helper.
    """
    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.dropout = torch.nn.Dropout(0.3)

        def forward(self, x):
            dropout = self.dropout(x)
            return dropout

    x = torch.randn(10, 3, 53)
    # Was `(x)`, which is just `x` — not a tuple. Use `(x,)` so the
    # argument is unambiguously a tuple of model inputs, matching the
    # sibling call style `input=(mat1, mat2)` used elsewhere in this file.
    self.run_test(M(), (x,))

def test_shape_constant_fold(self):
class ShapeModule(torch.nn.Module):
def __init__(self):
Expand Down
47 changes: 44 additions & 3 deletions test/onnx/test_utility_funs.py
Expand Up @@ -5,7 +5,7 @@
import torch.onnx
from torch.onnx import utils, OperatorExportTypes
from torch.onnx.symbolic_helper import _set_opset_version, _set_operator_export_type
from test_pytorch_common import skipIfUnsupportedOpsetVersion
from test_pytorch_common import skipIfUnsupportedOpsetVersion, skipIfUnsupportedMinOpsetVersion

import onnx
import onnxruntime # noqa
Expand All @@ -14,6 +14,8 @@
import copy
import unittest

import numpy as np


skip = unittest.skip

Expand Down Expand Up @@ -524,8 +526,6 @@ def forward(self, x):
# verify that the model state is preserved
assert model.training == old_state

# TODO: Enable test when Dropout is implemented in ORT for opset 12.
@skipIfUnsupportedOpsetVersion([12])
def test_dropout_training(self):
class MyModule(torch.nn.Module):
def __init__(self):
Expand All @@ -549,6 +549,47 @@ def forward(self, x):
ort_outs = ort_sess.run(None, ort_inputs)
assert x != ort_outs[0]

@skipIfUnsupportedMinOpsetVersion(12)
def test_dropout_training_zero(self):
    """Export Dropout(p=0.5) in TRAINING mode and check that ONNX Runtime
    zeroes roughly the same fraction of elements as PyTorch does."""

    class MyModule(torch.nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.dropout = torch.nn.Dropout(0.5)

        def forward(self, x):
            dropout = self.dropout(x)
            return dropout

    torch.manual_seed(0)
    onnxruntime.set_seed(0)

    model = MyModule()

    # Scrub exact zeros out of the input so that any zero observed in an
    # output can only have been produced by dropout itself.
    raw = torch.randn(10, 3, 128, 128).numpy()
    inp = torch.from_numpy(np.where(raw == 0, 1, raw))
    n_total = torch.numel(inp)

    model.train()

    f = io.BytesIO()
    torch.onnx.export(model, (inp,), f,
                      opset_version=self.opset_version,
                      training=torch.onnx.TrainingMode.TRAINING)
    ort_sess = onnxruntime.InferenceSession(f.getvalue())
    ort_inputs = {ort_sess.get_inputs()[0].name: inp.cpu().numpy()}
    ort_outs = ort_sess.run(None, ort_inputs)
    pyt_out = model(inp).cpu().numpy()

    # Survival masks: 1 where an element was kept, 0 where it was dropped.
    ort_kept = np.where(ort_outs[0] != 0, 1, 0)
    pyt_kept = np.where(pyt_out != 0, 1, 0)

    # RNG streams differ between the two backends, so only the overall
    # keep *ratio* (~= 1 - p) is expected to agree, not the exact mask.
    np.testing.assert_allclose(np.sum(pyt_kept) / n_total,
                               np.sum(ort_kept) / n_total,
                               rtol=0.01, atol=0.01)

# opset 10 tests
TestUtilityFuns_opset10 = type(str("TestUtilityFuns_opset10"),
Expand Down
4 changes: 3 additions & 1 deletion torch/onnx/symbolic_opset12.py
Expand Up @@ -26,8 +26,10 @@ def dropout(g, input, p, train):
# in eval mode, dropout is non-op - if the node's train param is set to False, dropout is non-op
if not sym_help._training_mode:
return input

p = g.op("Constant", value_t=torch.tensor(p))
r, _ = g.op("Dropout", input, p, outputs=2)
t = g.op("Constant", value_t=torch.tensor(True))
r, _ = g.op("Dropout", input, p, t, outputs=2)
return r


Expand Down