Remove hard ORT dependency, set default implementation to pure #95
Conversation
Previously we relied heavily on our ORT-based constant folding. With the hard ORT dependency removed, this no longer works. onnx-optimizer performs a similar optimization, but it uses the Python ORT version.
We now use the onnxsim package for constant folding, which is helpful when ORT is not installed (and our own constant folding therefore cannot run). The parent_pt_model argument to the ONNX importer is also removed: rather than reading the parameter values from the PyTorch model during ONNX import, the parameters are now expected as inputs to the model.
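To illustrate the idea (this is a toy sketch, not the daceml or onnxsim code; the graph representation and op table are hypothetical), constant folding evaluates any node whose inputs are all known constants and replaces it with the computed value, leaving only nodes that depend on runtime inputs:

```python
# Toy constant-folding sketch (hypothetical types, not the daceml API).
OPS = {"Add": lambda a, b: a + b, "Mul": lambda a, b: a * b}

def fold_constants(nodes, constants):
    """Evaluate nodes whose inputs are all known constants.

    nodes: list of (op, input_names, output_name) tuples in topological order.
    constants: dict mapping tensor name -> value; updated in place.
    Returns the remaining (unfoldable) nodes.
    """
    remaining = []
    for op, inputs, output in nodes:
        if all(name in constants for name in inputs):
            # All inputs are known: compute the output now and drop the node.
            constants[output] = OPS[op](*(constants[n] for n in inputs))
        else:
            remaining.append((op, inputs, output))
    return remaining
```

For example, with nodes `[("Add", ("a", "b"), "c"), ("Mul", ("c", "x"), "y")]` and constants `{"a": 2, "b": 3}`, the Add folds to `c = 5` while the Mul survives because `x` is a runtime input.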
This argument enables computing values based on the node and its connectors, which makes it possible to move a lot of operators over to this decorator (TODO in a later PR).
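A rough sketch of how such a compute argument can work (the decorator name, node class, and attribute names here are illustrative stand-ins, not the actual daceml implementation): each compute entry is evaluated against the node, and the result is injected into the implementation function's namespace before its body runs.

```python
import math

def pure_op_implementation(compute=None):
    """Illustrative sketch of a `compute=` decorator (not the daceml API)."""
    def decorator(func):
        def wrapper(node, **kwargs):
            # Evaluate each compute entry from the node, then pass the
            # results into the implementation as keyword arguments.
            computed = {name: fn(node) for name, fn in (compute or {}).items()}
            return func(node, **computed, **kwargs)
        return wrapper
    return decorator

class FakeNode:
    """Hypothetical stand-in for an ONNX op node."""
    axis = 1
    input_shape = (2, 3, 4)

@pure_op_implementation(compute=dict(
    # Flatten the input shape around node.axis, as a Flatten impl might.
    shape=lambda node: (math.prod(node.input_shape[:node.axis]),
                        math.prod(node.input_shape[node.axis:]))))
def flatten_impl(node, shape):
    # `shape` was computed from the node before the body ran.
    return shape
```

Calling `flatten_impl(FakeNode())` returns `(2, 12)`: the shape was derived from the node's axis without the implementation body having to recompute it.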
Previously, pure implementations removed unused inputs using constant_folding.remove_node_and_computation. This behavior was incorrect when the unused connector's source was also used elsewhere. To fix this, remove_node_and_computation now takes an optional 'connector' argument that removes a single connector. Furthermore, Python-based op implementations now automatically remove input connectors that are unused.
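The guard described above can be sketched like this (toy edge-list graph, not the dace SDFG API; function and tuple layout are hypothetical): the edge feeding the unused connector is removed, but its source node is dropped only when no other edge still reads from that source.

```python
def remove_unused_connector(edges, node, connector):
    """Remove the edge feeding (node, connector); drop its source node
    only when nothing else still consumes that source.

    edges: list of (src, dst, dst_connector) tuples.
    Returns (remaining_edges, removed_sources).
    """
    feeding = [e for e in edges if e[1] == node and e[2] == connector]
    edges = [e for e in edges if e not in feeding]
    removed = []
    for src, _, _ in feeding:
        # Keep the source if it still has other consumers elsewhere --
        # this is the case the old behavior got wrong.
        if not any(e[0] == src for e in edges):
            removed.append(src)
            # Drop edges into the now-dead source (non-recursive sketch).
            edges = [e for e in edges if e[1] != src]
    return edges, removed
```

With `[("W", "reshape", "shape"), ("W", "other", "in")]`, removing the unused `shape` connector of `reshape` leaves `W` alive because `other` still reads it; if that second edge is absent, `W` is removed too.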
Codecov Report
@@ Coverage Diff @@
## master #95 +/- ##
==========================================
+ Coverage 60.07% 67.39% +7.32%
==========================================
Files 58 58
Lines 6870 6955 +85
==========================================
+ Hits 4127 4687 +560
+ Misses 2743 2268 -475
Continue to review full report at Codecov.
LGTM with some comments
return PureReshape.forward(node, state, sdfg)

@python_pure_op_implementation(compute=dict(
    shape=lambda input, node:
        [prod(input.shape[:node.axis]),
cool
wait till you see the next PR :)
@op_implementation(op="Shape", name="pure")
class PureShape(ONNXForward):
hopefully we can get rid of all of those :)
daceml/onnx/onnx_importer.py
Outdated
name = clean_onnx_name(unclean_name)
if unclean_name in self.inputs:
    # remove the tensor from inputs since this is a consant
- # remove the tensor from inputs since this is a consant
+ # remove the tensor from inputs since this is a constant
function will be the name of the op that is being replaced.

The compute parameter enables you to compute a variable given the node and
it's inputs/outputs. This variable will be namespaced when parsing the function.
- it's inputs/outputs. This variable will be namespaced when parsing the function.
+ its inputs/outputs. This variable will be namespaced when parsing the function.
dead_dataflow_elimination.DeadDataflowElimination()
]).apply_pass(sdfg, {})

# remove dangling nodes, this can happen iwth non-transients
- # remove dangling nodes, this can happen iwth non-transients
+ # remove dangling nodes, this can happen with non-transients
state.remove_edge(e)

pass_pipeline.Pipeline([
    dead_dataflow_elimination.DeadDataflowElimination()
interesting use. won't it remove much more than that single node's dependencies?
Potentially, but in the NN SDFGs there's usually no dead code.
We can fix it later if we run into issues.
This PR makes the ORT dependency soft. If ORT is not installed, ORT node expansion will fail.
There is a new CI branch that runs tests without ORT.
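A soft dependency like this is typically implemented with an optional-import guard (a sketch under that assumption; the helper and function names here are hypothetical, not the daceml API):

```python
import importlib.util

def ort_available():
    """True if onnxruntime can be imported (hypothetical helper)."""
    return importlib.util.find_spec("onnxruntime") is not None

def expand_ort_node(node):
    """Hypothetical ORT-based node expansion entry point."""
    if not ort_available():
        # With the hard dependency removed, ORT-based expansion fails
        # with a clear error instead of the whole package failing to import.
        raise RuntimeError(
            "onnxruntime is required for ORT node expansion "
            "but is not installed")
    ...
```

Checking `find_spec` rather than wrapping the import in try/except avoids actually loading the module until expansion really needs it.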
Since our constant folding depends on ORT, we add more cleanup to the ONNX importer. It now uses onnxsim to do some preliminary cleanup, doing a lot of the heavy lifting that our CF would do. To keep tests from failing when ORT is missing, we also add some new pure op implementations. This includes improvements to the python_pure_op_implementation decorator to make these more compact. In a later PR, I will apply these improvements to other pure impls (I expect to be able to simplify at least 11 implementations).