[ONNX] Added support for constant folding onnx::Add and onnx::Sub #35869

Closed · wants to merge 1 commit

54 changes: 53 additions & 1 deletion test/onnx/test_utility_funs.py
@@ -260,7 +260,7 @@ def forward(self, x):
assert node.kind() != "onnx::Concat"
assert node.kind() != "onnx::Cast"
assert node.kind() != "onnx::Constant"
assert len(list(graph.nodes())) == 2
assert len(list(graph.nodes())) == 1

    # TODO: enable when constant folding is enabled for opset 12
    @skipIfUnsupportedOpsetVersion([12])
@@ -370,6 +370,58 @@ def forward(self, x):
assert node.kind() != "onnx::Mul"
assert len(list(graph.nodes())) == 1

    # TODO: enable when constant folding is enabled for opset 12
    @skipIfUnsupportedOpsetVersion([12])
    def test_constant_fold_add(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super(Module, self).__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                add = self.weight + torch.tensor([1, 2, 3, 4, 5])
                return add - x

        x = torch.randn(2, 5)
        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        graph, params_dict, __ = utils._model_to_graph(
            Module(), (x, ), do_constant_folding=True,
            operator_export_type=OperatorExportTypes.ONNX)
        for node in graph.nodes():
            self.assertTrue(node.kind() != "onnx::Add")
        self.assertEqual(len(list(graph.nodes())), 1)
        params = list(params_dict.values())
        self.assertEqual(len(params), 1)
        weight = params[0]
        self.assertEqual(weight, torch.tensor([2, 3, 4, 5, 6]))

    # TODO: enable when constant folding is enabled for opset 12
    @skipIfUnsupportedOpsetVersion([12])
    def test_constant_fold_sub(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super(Module, self).__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                sub = self.weight - torch.tensor([1, 2, 3, 4, 5])
                return sub + x

        x = torch.randn(2, 5)
        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        graph, params_dict, __ = utils._model_to_graph(
            Module(), (x, ), do_constant_folding=True,
            operator_export_type=OperatorExportTypes.ONNX)
        for node in graph.nodes():
            assert node.kind() != "onnx::Sub"
        self.assertEqual(len(list(graph.nodes())), 1)
        params = list(params_dict.values())
        self.assertEqual(len(params), 1)
        weight = params[0]
        self.assertEqual(weight, torch.tensor([0, -1, -2, -3, -4]))

    # TODO: enable when constant folding is enabled for opset 12
    @skipIfUnsupportedOpsetVersion([12])
    def test_constant_fold_sqrt(self):
6 changes: 6 additions & 0 deletions torch/csrc/jit/passes/onnx/constant_fold.cpp
@@ -233,6 +233,12 @@ c10::optional<at::Tensor> runTorchBackendForOnnx(
  } else if (node->kind() == onnx::Mul) {
    updated_val = at::mul(inputTensorValues[0], inputTensorValues[1]);
    return c10::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Sub) {

A potentially more generic approach would be to abstract Mul/Div/Sub/Add into a single binary-op case; the pseudocode might look like the following:

if (node->kind() == onnx::Div || node->kind() == onnx::Mul ...) {
  kind2func k2f = { {onnx::Mul, at::mul}, ... };
  function f = k2f(node->kind());
  return c10::optional<at::Tensor>(f(inputTensorValues[0], inputTensorValues[1]));
}

After discussing with Tianyou, I take back my comment: aten::add/sub take a different number of arguments than aten::div/mul (3 vs. 2). Although we could use bind tricks to unify the interface, that would be a bit of an overkill. I think Tianyou's current implementation is good enough.
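
(Illustrative sketch only, not part of this PR: the "bind trick" above could also be done with lambdas, which normalize at::add/at::sub — they take an extra alpha argument — to the same two-tensor signature as at::mul/at::div. The kind2func table below is hypothetical.)

// Hypothetical dispatch table; requires <functional> and <unordered_map>.
// NodeKind is torch::jit's alias for c10::Symbol.
using BinaryFunc = std::function<at::Tensor(const at::Tensor&, const at::Tensor&)>;
static const std::unordered_map<NodeKind, BinaryFunc> kind2func = {
    {onnx::Mul, [](const at::Tensor& a, const at::Tensor& b) { return at::mul(a, b); }},
    {onnx::Div, [](const at::Tensor& a, const at::Tensor& b) { return at::div(a, b); }},
    {onnx::Add, [](const at::Tensor& a, const at::Tensor& b) { return at::add(a, b); }},
    {onnx::Sub, [](const at::Tensor& a, const at::Tensor& b) { return at::sub(a, b); }}};
auto it = kind2func.find(node->kind());
if (it != kind2func.end()) {
  // All four binary ops share one code path once the arity is unified.
  updated_val = it->second(inputTensorValues[0], inputTensorValues[1]);
  return c10::optional<at::Tensor>(updated_val);
}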

    updated_val = at::sub(inputTensorValues[0], inputTensorValues[1]);
    return c10::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Add) {
    updated_val = at::add(inputTensorValues[0], inputTensorValues[1]);
    return c10::optional<at::Tensor>(updated_val);
  } else if (node->kind() == onnx::Unsqueeze) {
    assert(inputTensorValues.size() == 1);
    if (!node->hasAttributeS("axes")) {