Skip to content

Commit

Permalink
[ONNX] Enable Constant Folding for ONNX Opset 13 (#51096) (#51523)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #51523

* Enable Constant Folding for ONNX Opset 13

* fix CI clang-diagnostic

* fix integers type

* fix comments: sort axes and support negative numbers

* update squeeze op constant folding

* fix format warning

* fix clang-format issue

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D26203111

Pulled By: SplitInfinity

fbshipit-source-id: c33637ab39db614207bd442c6ab464bd09339b4a

Co-authored-by: hwangdeyu <deyhuang@qq.com>
  • Loading branch information
2 people authored and facebook-github-bot committed Feb 4, 2021
1 parent 1c7d966 commit 8ae6b0c
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 30 deletions.
88 changes: 79 additions & 9 deletions test/onnx/test_utility_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def test_constant_fold_unsqueeze(self):
class UnsqueezeModule(torch.nn.Module):
def forward(self, x):
a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
b = torch.unsqueeze(a, 0)
b = torch.unsqueeze(a, -2)
return b + x

_set_opset_version(self.opset_version)
Expand All @@ -225,7 +225,62 @@ def forward(self, x):
graph, _, __ = self._model_to_graph(UnsqueezeModule(), (x, ))

for node in graph.nodes():
assert node.kind() != "onnx::Unsqueeeze"
assert node.kind() != "onnx::Unsqueeze"
assert node.kind() != "onnx::Cast"
assert node.kind() != "onnx::Constant"
assert len(list(graph.nodes())) == 1

def test_constant_fold_unsqueeze_multi_axies(self):
    # Adding a rank-6 constant to the PReLU output forces the exporter to
    # unsqueeze the PReLU weight across several axes for broadcasting; all
    # of those Unsqueeze/Cast/Constant nodes must be constant-folded away.
    class PReluWithConstModel(torch.nn.Module):
        def __init__(self):
            super(PReluWithConstModel, self).__init__()
            self.prelu = torch.nn.PReLU()

        def forward(self, x):
            a = torch.randn(2, 3, 4, 5, 8, 7)
            return self.prelu(x) + a

    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    inp = torch.randn(2, 3, 4, 5, 8, 7)
    graph, _, __ = self._model_to_graph(PReluWithConstModel(), inp)

    folded_kinds = ("onnx::Unsqueeze", "onnx::Cast", "onnx::Constant")
    for node in graph.nodes():
        assert node.kind() not in folded_kinds
    assert len(list(graph.nodes())) == 4

def test_constant_fold_squeeze_without_axes(self):
    """Squeeze of a constant tensor with no 'axes' given must be constant-folded.

    Uses the constant twice so folding also has to handle a shared
    (multi-use) constant value.
    """
    class SqueezeModule(torch.nn.Module):
        def forward(self, x):
            a = torch.tensor([[[1., 2., 3.], [4., 5., 6.]]])
            return torch.squeeze(a) + x + torch.squeeze(a)

    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    x = torch.ones(2, 3)
    graph, _, __ = self._model_to_graph(SqueezeModule(), (x, ))
    # Removed leftover debug `print(graph)`; tests should not write to stdout.
    # All Squeeze/Cast/Constant nodes must have been folded away, leaving
    # only the two Add nodes.
    for node in graph.nodes():
        assert node.kind() != "onnx::Squeeze"
        assert node.kind() != "onnx::Cast"
        assert node.kind() != "onnx::Constant"
    assert len(list(graph.nodes())) == 2

def test_constant_fold_squeeze_with_axes(self):
    # A squeeze of a constant tensor along an explicit (negative) axis
    # should be folded at export time into a single remaining node.
    class NegativeAxisSqueezeModule(torch.nn.Module):
        def forward(self, x):
            a = torch.tensor([[[1., 2., 3.], [4., 5., 6.]]])
            return torch.squeeze(a, dim=-3) + x

    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    inp = torch.ones(2, 3)
    graph, _, __ = self._model_to_graph(NegativeAxisSqueezeModule(), (inp, ))

    for node in graph.nodes():
        assert node.kind() not in ("onnx::Squeeze", "onnx::Cast", "onnx::Constant")
    assert len(list(graph.nodes())) == 1
Expand Down Expand Up @@ -284,7 +339,12 @@ def forward(self, input, initial_state):
assert node.kind() != "onnx::Slice"
assert node.kind() != "onnx::Concat"
assert node.kind() != "onnx::Unsqueeze"
assert len(list(graph.nodes())) == 3

if self.opset_version <= 12:
assert len(list(graph.nodes())) == 3
else:
# From opset 13 onward, the Unsqueeze op takes 'axes' as an input instead of an attribute
assert len(list(graph.nodes())) == 4

def test_constant_fold_transpose_matmul(self):
class MatMulNet(torch.nn.Module):
Expand Down Expand Up @@ -619,7 +679,7 @@ def forward(self, x):
assert next(iter).kind() == "aten::dequantize"

# prim::ListConstruct is exported as onnx::SequenceConstruct for opset >= 11
@skipIfUnsupportedOpsetVersion([11, 12])
@skipIfUnsupportedOpsetVersion([11, 12, 13])
def test_prim_fallthrough(self):
# Test prim op
class PrimModule(torch.jit.ScriptModule):
Expand Down Expand Up @@ -781,7 +841,6 @@ def forward(self, x, y):
(TestCase,),
dict(TestUtilityFuns.__dict__, opset_version=10))


# opset 11 tests
TestUtilityFuns_opset11 = type(str("TestUtilityFuns_opset11"),
(TestCase,),
Expand All @@ -792,18 +851,29 @@ def forward(self, x, y):
(TestCase,),
dict(TestUtilityFuns.__dict__, opset_version=12))

# opset 13 tests
TestUtilityFuns_opset13 = type(str("TestUtilityFuns_opset13"),
(TestCase,),
dict(TestUtilityFuns.__dict__, opset_version=13))

# opset 11 tests
TestUtilityFuns_opset9_new_jit_API = type(str("TestUtilityFuns_opset9_new_jit_API"),
(TestCase,),
dict(TestUtilityFuns.__dict__, opset_version=9,
use_new_jit_passes=True))
TestUtilityFuns_opset11_new_jit_API = type(str("TestUtilityFuns_opset11_new_jit_API"),
(TestCase,),
dict(TestUtilityFuns.__dict__, opset_version=11,
use_new_jit_passes=True))

# opset 12 tests
TestUtilityFuns_opset12_new_jit_API = type(str("TestUtilityFuns_opset12_new_jit_API"),
(TestCase,),
dict(TestUtilityFuns.__dict__, opset_version=12,
use_new_jit_passes=True))

# opset 13 tests
TestUtilityFuns_opset13_new_jit_API = type(str("TestUtilityFuns_opset13_new_jit_API"),
(TestCase,),
dict(TestUtilityFuns.__dict__, opset_version=13,
use_new_jit_passes=True))


if __name__ == '__main__':
run_tests()
116 changes: 96 additions & 20 deletions torch/csrc/jit/passes/onnx/constant_fold.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,15 @@ c10::optional<at::Tensor> runTorchSlice_opset10(
if (inputTensorValues.size() < minSliceInputCount ||
inputTensorValues.size() > maxSliceInputCount) {
std::cerr
<< "Warning: Constant folding - Invalid number of inputs found for opset 10 or 11 onnx::Slice op. "
<< "Warning: Constant folding - Invalid number of inputs found for opset >= 10 onnx::Slice op. "
<< "Constant folding not applied." << std::endl;
return c10::nullopt;
}
// Checking validity of 'starts' and 'ends' input
if (inputTensorValues[1].sizes().size() != 1 ||
inputTensorValues[2].sizes().size() != 1) {
std::cerr
<< "Warning: Constant folding - Invalid 'starts' or 'ends' inputs found for opset 10 or 11 onnx::Slice op. "
<< "Warning: Constant folding - Invalid 'starts' or 'ends' inputs found for opset >= 10 onnx::Slice op. "
<< "Constant folding not applied." << std::endl;
return c10::nullopt;
}
Expand All @@ -131,15 +131,15 @@ c10::optional<at::Tensor> runTorchSlice_opset10(
if (inputTensorValues.size() > 3) {
if (inputTensorValues[3].sizes().size() != 1) {
std::cerr
<< "Warning: Constant folding - Invalid 'axes' input found for opset 10 onnx::Slice op. "
<< "Warning: Constant folding - Invalid 'axes' input found for opset >= 10 onnx::Slice op. "
<< "Constant folding not applied." << std::endl;
return c10::nullopt;
}
if (inputTensorValues[3].sizes()[0] != inputTensorValues[1].sizes()[0]) {
// Number of elements of 'axes' and 'ends' 1-D input tensors should be the
// same
std::cerr
<< "Warning: Constant folding - Invalid 'axes' or 'ends' inputs found for opset 10 onnx::Slice op. "
<< "Warning: Constant folding - Invalid 'axes' or 'ends' inputs found for opset >= 10 onnx::Slice op. "
<< "Constant folding not applied." << std::endl;
return c10::nullopt;
}
Expand All @@ -157,15 +157,15 @@ c10::optional<at::Tensor> runTorchSlice_opset10(
if (inputTensorValues.size() > 4) {
if (inputTensorValues[4].sizes().size() != 1) {
std::cerr
<< "Warning: Constant folding - Invalid 'steps' input found for opset 10 onnx::Slice op. "
<< "Warning: Constant folding - Invalid 'steps' input found for opset >= 10 onnx::Slice op. "
<< "Constant folding not applied." << std::endl;
return c10::nullopt;
}
if (inputTensorValues[4].sizes()[0] != inputTensorValues[1].sizes()[0]) {
// Number of elements of 'steps' and 'ends' 1-D input tensors should be
// the same
std::cerr
<< "Warning: Constant folding - Invalid 'steps' or 'ends' inputs found for opset 10 onnx::Slice op. "
<< "Warning: Constant folding - Invalid 'steps' or 'ends' inputs found for opset >= 10 onnx::Slice op. "
<< "Constant folding not applied." << std::endl;
return c10::nullopt;
}
Expand All @@ -174,7 +174,7 @@ c10::optional<at::Tensor> runTorchSlice_opset10(
// Only steps == 1 are supported for constant-folding.
if (steps_a[i] != 1) {
std::cerr
<< "Warning: Constant folding - Only steps=1 can be constant folded for opset 10 onnx::Slice op. "
<< "Warning: Constant folding - Only steps=1 can be constant folded for opset >= 10 onnx::Slice op. "
<< "Constant folding not applied." << std::endl;
return c10::nullopt;
}
Expand Down Expand Up @@ -205,7 +205,7 @@ c10::optional<at::Tensor> runTorchBackendForOnnx(
return runTorchSlice_opset9(node, inputTensorValues);
} else if (
opset_version == ONNX_OPSET_10 || opset_version == ONNX_OPSET_11 ||
opset_version == ONNX_OPSET_12) {
opset_version == ONNX_OPSET_12 || opset_version == ONNX_OPSET_13) {
return runTorchSlice_opset10(node, inputTensorValues);
} else {
std::cerr << "Warning: Constant folding - unsupported opset version. "
Expand Down Expand Up @@ -235,15 +235,92 @@ c10::optional<at::Tensor> runTorchBackendForOnnx(
updated_val = at::add(inputTensorValues[0], inputTensorValues[1]);
return c10::optional<at::Tensor>(updated_val);
} else if (node->kind() == onnx::Unsqueeze) {
assert(inputTensorValues.size() == 1);
if (!node->hasAttributeS("axes")) {
if (opset_version >= ONNX_OPSET_13) {
assert(inputTensorValues.size() == 2);
// Checking validity of 'axes' input
if (inputTensorValues[1].sizes().size() != 1) {
std::cerr
<< "Warning: Constant folding - Invalid 'axes' inputs found for opset 13 onnx::Unsqueeze op. "
<< "Constant folding not applied." << std::endl;
return c10::nullopt;
}
auto axes_a = inputTensorValues[1].accessor<int64_t, 1>();
std::vector<int64_t> axes;
for (int64_t i = 0; i < inputTensorValues[1].sizes()[0]; ++i) {
// ONNX unsqueeze accepts negative axes
axes_a[i] += axes_a[i] < 0 ? inputTensorValues[0].sizes().size() : 0;
axes.push_back(axes_a[i]);
}
std::sort(axes.begin(), axes.end());
updated_val = inputTensorValues[0];
for (int64_t i = 0; i < inputTensorValues[1].sizes()[0]; ++i) {
updated_val = at::unsqueeze(updated_val, axes[i]);
}
return c10::optional<at::Tensor>(updated_val);
} else if (
opset_version == ONNX_OPSET_9 || opset_version == ONNX_OPSET_10 ||
opset_version == ONNX_OPSET_11 || opset_version == ONNX_OPSET_12) {
assert(inputTensorValues.size() == 1);
if (!node->hasAttributeS("axes")) {
return c10::nullopt;
}
updated_val = inputTensorValues[0];
std::vector<int64_t> axesAttr = node->is(attr::axes);
std::sort(axesAttr.begin(), axesAttr.end());
for (auto axis : axesAttr) {
updated_val = at::unsqueeze(updated_val, axis);
}
return c10::optional<at::Tensor>(updated_val);
} else {
std::cerr << "Warning: Constant folding - unsupported opset version. "
<< "Constant folding not applied." << std::endl;
return c10::nullopt;
}
updated_val = inputTensorValues[0];
for (auto axis : node->is(attr::axes)) {
updated_val = at::unsqueeze(updated_val, axis);
} else if (node->kind() == onnx::Squeeze) {
assert(inputTensorValues.size() == 2 or inputTensorValues.size() == 1);
if (opset_version == ONNX_OPSET_13) {
// In Squeeze opset 13 the 'axes' input is optional;
// inputTensorValues.size() == 1 means 'axes' was not provided (None).
updated_val = inputTensorValues[0];
if (inputTensorValues.size() == 2) {
// Checking validity of 'axes' input
if (inputTensorValues[1].sizes().size() != 1) {
std::cerr
<< "Warning: Constant folding - Invalid 'axes' inputs found for opset 13 onnx::Squeeze op. "
<< "Constant folding not applied." << std::endl;
return c10::nullopt;
}
auto axes_a = inputTensorValues[1].accessor<int64_t, 1>();
std::vector<int64_t> axes;
for (int64_t i = 0; i < inputTensorValues[1].sizes()[0]; ++i) {
// ONNX Squeeze accepts negative axes
axes_a[i] += axes_a[i] < 0 ? inputTensorValues[0].sizes().size() : 0;
axes.push_back(axes_a[i]);
}
std::sort(axes.begin(), axes.end());
for (int64_t i = 0; i < inputTensorValues[1].sizes()[0]; ++i) {
updated_val = at::squeeze(updated_val, axes[i]);
}
}
return c10::optional<at::Tensor>(updated_val);
} else if (
opset_version == ONNX_OPSET_9 || opset_version == ONNX_OPSET_10 ||
opset_version == ONNX_OPSET_11 || opset_version == ONNX_OPSET_12) {
assert(inputTensorValues.size() == 1);
updated_val = inputTensorValues[0];
if (node->hasAttributeS("axes")) {
std::vector<int64_t> axesAttr = node->is(attr::axes);
std::sort(axesAttr.begin(), axesAttr.end());
for (auto axis : axesAttr) {
updated_val = at::squeeze(updated_val, axis);
}
}
return c10::optional<at::Tensor>(updated_val);
} else {
std::cerr << "Warning: Constant folding - unsupported opset version. "
<< "Constant folding not applied." << std::endl;
return c10::nullopt;
}
return c10::optional<at::Tensor>(updated_val);
} else if (node->kind() == onnx::Transpose) {
assert(inputTensorValues.size() == 1);
if (!node->hasAttributeS("perm")) {
Expand All @@ -264,7 +341,8 @@ c10::optional<at::Tensor> runTorchBackendForOnnx(
updated_val = inputTensorValues[0];
std::vector<int64_t> shape(inputTensorValues[1].sizes()[0], 0);
auto shape_a = inputTensorValues[1].accessor<int64_t, 1>();
for (size_t i = 0; i < inputTensorValues[1].sizes()[0]; ++i) {
assert(inputTensorValues[1].sizes()[0] >= 0);
for (size_t i = 0; i < (size_t)(inputTensorValues[1].sizes()[0]); ++i) {
// All shape dim values should be >= -1
// onnx::Reshape supports a shape dim value to be zero, in
// which case the actual dim value remains unchanged. However,
Expand Down Expand Up @@ -398,13 +476,11 @@ std::vector<Node*> getOnnxConstParentsToRemove(Node* node) {
// nodes can be lifted so we run them earlier, before the usual parameters are
// known.
void ConstantFoldONNX(Block* b, ParamMap& paramsDict, int opset_version) {
if (opset_version != ONNX_OPSET_9 && opset_version != ONNX_OPSET_10 &&
opset_version != ONNX_OPSET_11 && opset_version != ONNX_OPSET_12) {
if (opset_version < ONNX_OPSET_9) {
// Number of elements of 'axes' and 'ends' 1-D input tensors should be the
// same
std::cerr
<< "Warning: Constant folding supported for only opsets 9, 10, and 11. "
<< "Constant folding not applied." << std::endl;
std::cerr << "Warning: Constant folding supported for only opsets >= 9. "
<< "Constant folding not applied." << std::endl;
return;
}
AT_ASSERT(b->param_node());
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/jit/passes/onnx/constant_fold.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ const int ONNX_OPSET_9 = 9;
const int ONNX_OPSET_10 = 10;
const int ONNX_OPSET_11 = 11;
const int ONNX_OPSET_12 = 12;
const int ONNX_OPSET_13 = 13;
void ConstantFoldONNX(
Block* b,
std::map<std::string, IValue>& paramDict,
Expand Down
2 changes: 1 addition & 1 deletion torch/onnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ir_version = _C._onnx.IR_VERSION
producer_name = "pytorch"
producer_version = _C._onnx.PRODUCER_VERSION
constant_folding_opset_versions = [9, 10, 11, 12]
constant_folding_opset_versions = [9, 10, 11, 12, 13]


class ExportTypes:
Expand Down

0 comments on commit 8ae6b0c

Please sign in to comment.