Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid a graph break in ModuleDict and ParameterDict #8363

Merged
merged 22 commits into from
Nov 14, 2023

Conversation

akihironitta
Copy link
Member

@akihironitta akihironitta commented Nov 10, 2023

We have a graph break at hasattr call in ModuleDict.to_internal_key.

Repro:

import torch
from torch_geometric.nn.module_dict import ModuleDict

edge_type = ("a", "to", "b")

class SomeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.module_dict = ModuleDict({
            edge_type: torch.nn.Linear(1, 1),
        })

    def forward(self, x):
        # need to convert tuple to string in advance to avoid a graph break
        # due to https://github.com/pytorch/pytorch/issues/111551
        key = ModuleDict.to_internal_key(edge_type)
        x = self.module_dict[key](x)
        return x

from torch._dynamo.utils import CompileProfiler
model = SomeModel()
with CompileProfiler() as prof:
    model = torch.compile(model)
    model(torch.randn(1, 1))
    print(prof.report())

@akihironitta akihironitta changed the title [Draft] Avoid a graph break in torc_geometric.nn.module_dict.ModuleDict Avoid a graph break in torc_geometric.nn.module_dict.ModuleDict Nov 13, 2023
@akihironitta akihironitta marked this pull request as ready for review November 13, 2023 03:52
Copy link

codecov bot commented Nov 13, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (b1f8535) 88.80% compared to head (171e494) 88.41%.

❗ Current head 171e494 differs from pull request most recent head c2f1378. Consider uploading reports for the commit c2f1378 to get more accurate results

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #8363      +/-   ##
==========================================
- Coverage   88.80%   88.41%   -0.39%     
==========================================
  Files         475      475              
  Lines       28841    28838       -3     
==========================================
- Hits        25611    25497     -114     
- Misses       3230     3341     +111     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@akihironitta akihironitta marked this pull request as draft November 13, 2023 11:14
@akihironitta
Copy link
Member Author

Converting this to draft as it's currently not a proper fix.

@rusty1s rusty1s marked this pull request as ready for review November 13, 2023 15:08
@akihironitta akihironitta changed the title Avoid a graph break in torc_geometric.nn.module_dict.ModuleDict Avoid a graph break in ModuleDict Nov 14, 2023
@rusty1s rusty1s changed the title Avoid a graph break in ModuleDict Avoid a graph break in ModuleDict and ParameterDict Nov 14, 2023
@rusty1s rusty1s enabled auto-merge (squash) November 14, 2023 07:30
@rusty1s rusty1s merged commit ccbbbdd into master Nov 14, 2023
14 checks passed
@rusty1s rusty1s deleted the akihironitta-patch-1 branch November 14, 2023 07:36
assert len(key) > 1
key = f"<{'___'.join(key)}>"
assert isinstance(key, str)

# ModuleDict cannot handle keys that exists as class attributes:
if hasattr(cls, key):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the note, graph breaks produeced by hasattr here were likely due to pytorch/pytorch#111522.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants