New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support symbolic tracing of SchNet
#5938
Conversation
Codecov Report
@@ Coverage Diff @@
## master #5938 +/- ##
=======================================
Coverage 84.48% 84.48%
=======================================
Files 358 358
Lines 19721 19721
=======================================
Hits 16662 16662
Misses 3059 3059
📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! LGTM.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool. Thanks!
SchNet
This PR removes/modifies a couple control flow patterns that were used in the `SchNet` implementation that are not traceable with `torch.fx`: * removing the assertion in the forward pass * Replaced calling `scatter` directly with an `Aggregation` module Without these changes, attempting symbolic tracing: ``` from torch_geometric.nn.fx import Transformer from torch_geometric.nn import SchNet model = SchNet() tx = Transformer(model, debug=True) tx.transform() ``` fails with: ``` TraceError: symbolically traced variables cannot be used as inputs to control flow ``` Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
This PR removes/modifies a couple control flow patterns that were used in the
SchNet
implementation that are not traceable withtorch.fx
:scatter
directly with anAggregation
moduleWithout these changes, attempting symbolic tracing:
fails with: