
[ONNX] Added support for constant folding onnx::Add and onnx::Sub #35869

Closed · wants to merge 1 commit from the gty-dev/onnx-constfold branch

Conversation

@fortianyou fortianyou commented Apr 2, 2020

Added support for constant folding onnx::Add and onnx::Sub
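What the pass does can be illustrated with a small standalone sketch (plain Python, not the actual PyTorch implementation): a binary node whose inputs are all known constants is evaluated at export time, its result is recorded as a new constant, and the node itself is dropped from the graph. The toy graph representation below is hypothetical.

```python
def fold_constants(nodes, constants):
    """Fold onnx::Add / onnx::Sub nodes whose inputs are all known
    constants: evaluate them eagerly, record the result as a new
    constant, and drop the folded node from the graph."""
    ops = {"onnx::Add": lambda a, b: a + b,
           "onnx::Sub": lambda a, b: a - b}
    remaining = []
    for op, inputs, output in nodes:
        if op in ops and all(i in constants for i in inputs):
            a, b = (constants[i] for i in inputs)
            constants[output] = ops[op](a, b)   # folded at export time
        else:
            remaining.append((op, inputs, output))
    return remaining, constants

# Toy graph: w1 + w2 is constant, so the Add node disappears.
nodes = [("onnx::Add", ["w1", "w2"], "s"),
         ("onnx::MatMul", ["x", "s"], "y")]
constants = {"w1": 2.0, "w2": 3.0}
folded, constants = fold_constants(nodes, constants)
```

After folding, only the MatMul node remains and `s` is a precomputed constant, which mirrors the `node.kind() != "onnx::Add"` assertions in the tests below.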


@dr-ci dr-ci bot commented Apr 2, 2020

💊 CircleCI build failures summary and remediations

As of commit 8afbc09 (more details on the Dr. CI page):


None of the build failures appear to be your fault 💚


  • 2/2 broken upstream at merge base e56ba84 since Apr 05

    Please rebase on the viable/strict branch:

    If your commit is newer than viable/strict, you can try basing on an older, stable commit:

    git fetch https://github.com/pytorch/pytorch viable/strict
    git rebase --onto FETCH_HEAD $(git merge-base origin/master HEAD)
    

    If your commit is older than viable/strict:

    git fetch https://github.com/pytorch/pytorch viable/strict
    git rebase FETCH_HEAD
    

    Check out the recency history of this "viable master" tracking branch.


🚧 2 upstream failures:

These were probably caused by upstream breakages:


This comment was automatically generated by Dr. CI.

bddppq approved these changes Apr 5, 2020

@bddppq (Contributor) left a comment:

LGTM

test/onnx/test_utility_funs.py (review thread, resolved)
@bddppq bddppq requested review from houseroad and spandantiwari Apr 5, 2020

@bddppq bddppq commented Apr 5, 2020

CI failures are unrelated

@bddppq bddppq commented Apr 5, 2020

@pytorchbot retest this please


@yangjunpro yangjunpro left a comment

Just some small comments


operator_export_type=OperatorExportTypes.ONNX)
for node in graph.nodes():
assert node.kind() != "onnx::Add"
assert len(list(graph.nodes())) == 1


Should we also add a value check to verify the correctness of the constant folding behavior?
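The value check the reviewer asks for could look roughly like this (a hedged sketch: the function name, the stand-in for the folded initializer, and the tolerance are all hypothetical, not the actual test code): compute the sum eagerly and compare it element-wise against the value the folding pass produced.

```python
import math

def check_folded_add(a, b, folded_values, tol=1e-6):
    """Verify that each folded value equals the eager element-wise sum.

    `folded_values` stands in for the constant the exporter emits
    after folding onnx::Add."""
    expected = [x + y for x, y in zip(a, b)]
    return all(math.isclose(e, f, abs_tol=tol)
               for e, f in zip(expected, folded_values))

ok = check_folded_add([1.0, 2.0], [0.5, 0.5], [1.5, 2.5])
```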


@fortianyou (author) replied Apr 5, 2020:


Done


operator_export_type=OperatorExportTypes.ONNX)
for node in graph.nodes():
assert node.kind() != "onnx::Sub"
assert len(list(graph.nodes())) == 1


ditto


@fortianyou (author) replied Apr 5, 2020:


Done


@@ -233,6 +233,12 @@ c10::optional<at::Tensor> runTorchBackendForOnnx(
} else if (node->kind() == onnx::Mul) {
updated_val = at::mul(inputTensorValues[0], inputTensorValues[1]);
return c10::optional<at::Tensor>(updated_val);
} else if (node->kind() == onnx::Sub) {


A potentially more generic approach is to abstract Mul/Div/Sub/Add as a single binary op. The pseudocode might look like the following:

if (node->kind() == onnx::Div || node->kind() == onnx::Mul ...) {
  kind2func k2f = { {onnx::Mul, at::mul}, ... };
  auto f = k2f[node->kind()];
  return c10::optional<at::Tensor>(f(inputTensorValues[0], inputTensorValues[1]));
}
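For illustration, here is the same table-driven idea in runnable Python (the real code is C++; the function and table names below are hypothetical): each ONNX kind maps to a callable, so one lookup replaces the chain of else-if branches.

```python
import operator

# Hypothetical dispatch table: ONNX node kind -> binary callable.
KIND2FUNC = {
    "onnx::Add": operator.add,
    "onnx::Sub": operator.sub,
    "onnx::Mul": operator.mul,
    "onnx::Div": operator.truediv,
}

def run_binary_op(kind, a, b):
    """Look up the folding function for `kind`; None means the kind
    is not supported for constant folding (like an empty optional)."""
    func = KIND2FUNC.get(kind)
    if func is None:
        return None
    return func(a, b)
```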

@yangjunpro replied:

After discussing with Tianyou, I take back my comment: aten::add/sub take a different number of arguments than aten::div/mul (3 versus 2). Although we could use bind tricks to unify the interface, it would be a little overkill. I think Tianyou's current implementation is good enough.
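The arity mismatch can be sketched as follows (a Python analogy, not the ATen C++ API itself): aten::add computes a + alpha * b and so takes a third `alpha` scaling argument, while aten::mul is strictly binary, so a shared two-argument dispatch table would have to pin `alpha` via partial application first.

```python
from functools import partial

def aten_add(a, b, alpha=1):
    # Mirrors aten::add's three-argument form: a + alpha * b.
    return a + alpha * b

def aten_mul(a, b):
    # Strictly binary, like aten::mul.
    return a * b

# A shared two-argument table would need alpha bound up front:
binary_add = partial(aten_add, alpha=1)
```

Binding works, but as the reviewer notes it buys little here, since only four ops are involved.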


test/onnx/test_utility_funs.py (review thread, resolved)
@houseroad (Member) left a comment:

Overall looks good. Could you rebase to master and let the ONNX related CI run?


test/onnx/test_utility_funs.py (review thread, outdated, resolved)
@fortianyou fortianyou force-pushed the gty-dev/onnx-constfold branch from 2a2363e to 8afbc09 Apr 6, 2020
@facebook-github-bot left a comment:

@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


@fortianyou fortianyou commented Apr 6, 2020

> Overall looks good. Could you rebase to master and let the ONNX related CI run?

Done


@facebook-github-bot facebook-github-bot commented Apr 6, 2020

@houseroad merged this pull request in 8dba98d.


ashishfarmer pushed a commit to ashishfarmer/pytorch that referenced this issue Apr 13, 2020
…torch#35869)

Summary:
Added support for constant folding onnx::Add and onnx::Sub
Pull Request resolved: pytorch#35869

Reviewed By: hl475

Differential Revision: D20865640

Pulled By: houseroad

fbshipit-source-id: 2b8c1cc196959b5b5b9ce018dbdcb74d59a92d9f