Add `summary` function for PyG/PyTorch models #5859

Conversation
Codecov Report
```
@@            Coverage Diff             @@
##           master    #5859      +/-   ##
==========================================
+ Coverage   84.36%   84.42%   +0.05%
==========================================
  Files         365      366       +1
  Lines       20510    20586      +76
==========================================
+ Hits        17303    17379      +76
  Misses       3207     3207
==========================================
```
```python
stack = [(get_name(model), model, depth)]

info_list = []
input_shape = defaultdict(list)
```
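The `stack` line seeds an iterative, depth-limited walk over the module tree. Here is a torch-free sketch of that traversal pattern; the `Module` stand-in and the `walk` helper are hypothetical, illustrating only the stack/depth bookkeeping, not the PR's actual code:

```python
class Module:
    """Minimal stand-in for torch.nn.Module, for illustration only."""
    def __init__(self, name, children=()):
        self.name = name
        self._children = list(children)

    def named_children(self):
        return [(c.name, c) for c in self._children]

def walk(model, max_depth=3):
    # Iterative DFS: each stack entry is (name, module, remaining depth),
    # mirroring `stack = [(get_name(model), model, depth)]` above.
    stack = [(model.name, model, max_depth)]
    visited = []
    while stack:
        name, module, depth = stack.pop()
        visited.append((name, depth))
        if depth > 0:
            # Push children in reverse so they pop in definition order.
            for child_name, child in reversed(module.named_children()):
                stack.append((child_name, child, depth - 1))
    return visited

model = Module('Net', [Module('lin1'), Module('act'), Module('lin2')])
print(walk(model, max_depth=1))  # children visited, grandchildren cut off
```

The remaining-depth counter is what lets a `max_depth` argument truncate the summary at a chosen nesting level.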
Any reason these need to be lists?
That seems weird, but it makes sense for cases where a module is reused. For example:

```python
relu = torch.nn.ReLU()
model = torch.nn.Sequential(torch.nn.Linear(10, 10), relu, torch.nn.Linear(10, 10), relu)
```

Here the `relu` module is called twice during the forward pass, but we only register one hook for it. A list is therefore used to share the input/output information between `info_list[0]` (first `relu`) and `info_list[3]` (second `relu`). That also gets tricky, but currently I don't have a better idea to tackle this case. WDYT?
torch_geometric/nn/summary.py (Outdated)
```python
info = {}
info['name'] = var_name
info['layer'] = module
info['input_shape'] = input_shape[id(module)]
```
We can also set these in the forward hook, such that this loop only registers the hook.
That was my first attempt, but the problem is that some modules (e.g. `torch.nn.ModuleList`) don't implement a `forward` method and would then be ignored.
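The point about container modules can be sketched in plain Python: a forward hook only fires when a module is actually called, and containers like `ModuleList` are iterated over rather than called, so populating `info` inside the hook would leave them empty. The classes below are torch-free stand-ins, not PyG code:

```python
class Leaf:
    """Stand-in for a module with a real forward; the 'hook' fires here."""
    def __init__(self):
        self.info = {}

    def __call__(self, x):
        self.info['input_shape'] = len(x)  # populated on call, like a hook
        return x

class ModuleListLike:
    """Stand-in for torch.nn.ModuleList: holds children but is never
    called itself during the forward pass, so a hook would never fire."""
    def __init__(self, children):
        self.children = children
        self.info = {}

leaf1, leaf2 = Leaf(), Leaf()
container = ModuleListLike([leaf1, leaf2])

# Typical forward pass: children are called, the container is bypassed.
x = [1, 2, 3]
for layer in container.children:
    x = layer(x)

print(leaf1.info, container.info)  # the container's info stays empty: {}
```

This is why the registration loop, which sees every module regardless of whether it is ever called, is the safer place to build the `info` entries.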
This is amazing, thank you! Sorry it took me so long to have a final look.
As discussed in #5727, this PR adds a `summary` function for all PyTorch models, and particularly for PyG-defined models. The current implementation summarizes the following information:
Usage

Set `max_depth=1` for a better view.

Update
Any suggestions would be highly appreciated :)