Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support symbolic tracing of SchNet #5938

Merged
merged 7 commits into from Nov 10, 2022
Merged

Conversation

hatemhelal
Copy link
Contributor

@hatemhelal hatemhelal commented Nov 9, 2022

This PR removes/modifies a couple control flow patterns that were used in the SchNet implementation that are not traceable with torch.fx:

  • removing the assertion in the forward pass
  • Replaced calling scatter directly with an Aggregation module

Without these changes, attempting symbolic tracing:

from torch_geometric.nn.fx import Transformer
from torch_geometric.nn import SchNet

model = SchNet()
tx = Transformer(model, debug=True)
tx.transform()

fails with:

TraceError: symbolically traced variables cannot be used as inputs to control flow

@codecov
Copy link

codecov bot commented Nov 9, 2022

Codecov Report

Merging #5938 (235523c) into master (1909dc6) will not change coverage.
The diff coverage is 100.00%.

❗ Current head 235523c differs from pull request most recent head 6de48d7. Consider uploading reports for the commit 6de48d7 to get more accurate results

@@           Coverage Diff           @@
##           master    #5938   +/-   ##
=======================================
  Coverage   84.48%   84.48%           
=======================================
  Files         358      358           
  Lines       19721    19721           
=======================================
  Hits        16662    16662           
  Misses       3059     3059           
Impacted Files Coverage Δ
torch_geometric/nn/models/schnet.py 70.42% <100.00%> (ø)

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

Copy link
Contributor

@lightaime lightaime left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! LGTM.

torch_geometric/nn/models/schnet.py Show resolved Hide resolved
Copy link
Contributor

@lightaime lightaime left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool. Thanks!

@rusty1s rusty1s added the feature label Nov 9, 2022
@rusty1s rusty1s changed the title Support symbolic tracing of SchNet Support symbolic tracing of SchNet Nov 10, 2022
@rusty1s rusty1s enabled auto-merge (squash) November 10, 2022 07:40
@rusty1s rusty1s merged commit 06a995e into pyg-team:master Nov 10, 2022
JakubPietrakIntel pushed a commit to JakubPietrakIntel/pytorch_geometric that referenced this pull request Nov 25, 2022
This PR removes/modifies a couple control flow patterns that were used
in the `SchNet` implementation that are not traceable with `torch.fx`:

* removing the assertion in the forward pass
* Replaced calling `scatter` directly with an `Aggregation` module

Without these changes, attempting symbolic tracing:
```
from torch_geometric.nn.fx import Transformer
from torch_geometric.nn import SchNet

model = SchNet()
tx = Transformer(model, debug=True)
tx.transform()
```

fails with:
```
TraceError: symbolically traced variables cannot be used as inputs to control flow
```

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants