
Add summary function for PyG/PyTorch models #5859

Merged
merged 26 commits into master from summary on Dec 6, 2022

Conversation

@EdisonLeeeee (Contributor) commented Oct 31, 2022

As discussed in #5727, this PR adds a summary function for all PyTorch models, and PyG-defined models in particular. The current implementation summarizes the following information:

  • Layer names,
  • Input/output shapes,
  • Number of parameters,
  • Execution time.

Usage

import torch
from torch_geometric.nn import summary
from torch_geometric.nn.models import GCN

model = GCN(128, 64, 2, out_channels=32)
x = torch.randn(100, 128)
edge_index = torch.randint(100, size=(2, 20))
print(summary(model, x, edge_index))

+-----------------------------------------+---------------------+----------------+----------+--------+
| Layer                                   | Input Shape         | Output Shape   | #Param   | Time   |
|-----------------------------------------+---------------------+----------------+----------+--------|
| GCN                                     | [100, 128], [2, 20] | [100, 32]      | 10,336   | 0.0008 |
| ├─(act)ReLU                             | [100, 64]           | [100, 64]      | --       | 0.0000 |
| ├─(convs)ModuleList                     | --                  | --             | 10,336   | --     |
| │    └─(0)GCNConv                       | [100, 128], [2, 20] | [100, 64]      | 8,256    | 0.0006 |
| │    │    └─(aggr_module)SumAggregation | [119, 64], [119]    | [100, 64]      | --       | 0.0000 |
| │    │    └─(lin)Linear                 | [100, 128]          | [100, 64]      | 8,192    | 0.0001 |
| │    └─(1)GCNConv                       | [100, 64], [2, 20]  | [100, 32]      | 2,080    | 0.0002 |
| │    │    └─(aggr_module)SumAggregation | [119, 32], [119]    | [100, 32]      | --       | 0.0000 |
| │    │    └─(lin)Linear                 | [100, 64]           | [100, 32]      | 2,048    | 0.0000 |
+-----------------------------------------+---------------------+----------------+----------+--------+

Set max_depth=1 for a more compact view:

print(summary(model, x, edge_index, max_depth=1))

+---------------------+---------------------+----------------+----------+--------+
| Layer               | Input Shape         | Output Shape   | #Param   | Time   |
|---------------------+---------------------+----------------+----------+--------|
| GCN                 | [100, 128], [2, 20] | [100, 32]      | 10,336   | 0.0021 |
| ├─(act)ReLU         | [100, 64]           | [100, 64]      | --       | 0.0000 |
| ├─(convs)ModuleList | --                  | --             | 10,336   | --     |
+---------------------+---------------------+----------------+----------+--------+

Update: the output no longer includes the execution time column:

import torch
from torch_geometric.nn import summary
from torch_geometric.nn.models import GCN

model = GCN(128, 64, 2, out_channels=32)
x = torch.randn(100, 128)
edge_index = torch.randint(100, size=(2, 20))
print(summary(model, x, edge_index))

+---------------------+---------------------+----------------+----------+
| Layer               | Input Shape         | Output Shape   | #Param   |
|---------------------+---------------------+----------------+----------|
| GCN                 | [100, 128], [2, 20] | [100, 32]      | 10,336   |
| ├─(act)ReLU         | [100, 64]           | [100, 64]      | --       |
| ├─(convs)ModuleList | --                  | --             | 10,336   |
| │    └─(0)GCNConv   | [100, 128], [2, 20] | [100, 64]      | 8,256    |
| │    └─(1)GCNConv   | [100, 64], [2, 20]  | [100, 32]      | 2,080    |
+---------------------+---------------------+----------------+----------+

Any suggestions would be highly appreciated :)

codecov bot commented Oct 31, 2022

Codecov Report

Merging #5859 (70107d5) into master (82d8ad2) will increase coverage by 0.05%.
The diff coverage is 100.00%.

❗ Current head 70107d5 differs from pull request most recent head b875f01. Consider uploading reports for the commit b875f01 to get more accurate results.

@@            Coverage Diff             @@
##           master    #5859      +/-   ##
==========================================
+ Coverage   84.36%   84.42%   +0.05%     
==========================================
  Files         365      366       +1     
  Lines       20510    20586      +76     
==========================================
+ Hits        17303    17379      +76     
  Misses       3207     3207              
Impacted Files                         Coverage Δ
torch_geometric/nn/__init__.py         100.00% <100.00%> (ø)
torch_geometric/nn/summary.py          100.00% <100.00%> (ø)
torch_geometric/nn/conv/gcn_conv.py    97.93% <0.00%> (-0.20%) ⬇️


stack = [(get_name(model), model, depth)]

info_list = []
input_shape = defaultdict(list)
Member:

Any reason these need to be lists?

Contributor Author:

That may seem odd, but it is needed for cases where a module is reused. For example:

relu = torch.nn.ReLU()
model = torch.nn.Sequential(torch.nn.Linear(10, 10), relu,
                            torch.nn.Linear(10, 10), relu)

Here, the relu module is called twice during forward, but we only register one hook for it. Therefore, a list is used to share the input/output information between info_list[0] (first relu) and info_list[3] (second relu). That also gets tricky, but I currently don't have a better idea to tackle this case. WDYT?
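For illustration, here is a minimal, self-contained sketch of this bookkeeping (the hook body and variable names are hypothetical, not the exact PR code): a single forward hook appends the input shape of every call to a list keyed by id(module), so a module reused N times accumulates N entries in one shared list:

from collections import defaultdict

import torch

# Hypothetical bookkeeping: one shared list of input shapes per module id.
input_shape = defaultdict(list)

def hook(module, inputs, output):
    # Fires once per forward invocation, so a reused module appends
    # one entry per call to its shared list.
    input_shape[id(module)].append(list(inputs[0].shape))

relu = torch.nn.ReLU()
model = torch.nn.Sequential(torch.nn.Linear(10, 10), relu,
                            torch.nn.Linear(10, 10), relu)

# children() yields each unique child once, so relu gets a single hook:
handles = [m.register_forward_hook(hook) for m in model.children()]
model(torch.randn(4, 10))

print(input_shape[id(relu)])  # [[4, 10], [4, 10]] -- two calls, one shared list
for handle in handles:
    handle.remove()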

info = {}
info['name'] = var_name
info['layer'] = module
info['input_shape'] = input_shape[id(module)]
Member:

We can also set these in the forward hook, such that this loop only registers the hook.

Contributor Author:

That was my first attempt, but the problem is that some modules (e.g., torch.nn.ModuleList) don't implement a forward method and would then be skipped.
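To illustrate the problem (a minimal sketch using plain PyTorch, independent of the PR code): torch.nn.ModuleList is only a container without a forward implementation, so a forward hook registered on it never fires:

import torch

calls = []

mlist = torch.nn.ModuleList([torch.nn.Linear(10, 10)])
mlist.register_forward_hook(lambda module, inputs, output: calls.append(module))

# ModuleList has no forward; calling it raises NotImplementedError,
# so the forward hook above can never fire.
try:
    mlist(torch.randn(4, 10))
except NotImplementedError:
    pass

print(calls)  # [] -- the hook was never invoked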

@rusty1s (Member) left a comment:

This is amazing, thank you! Sorry it took me so long to take a final look.

rusty1s enabled auto-merge (squash) on December 6, 2022, 13:38
rusty1s merged commit de073c7 into master on Dec 6, 2022
rusty1s deleted the summary branch on December 6, 2022, 13:40