Add flop counter utility #95751
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/95751.
Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit d7256a4. This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: 477bdd302c406e1c4c7e85ab9eb67dc2c9ce5da2 Pull Request resolved: #95751
Overall, an example usage. Note that this *also* captures backwards FLOPs.

```
import torchvision.models as models
import torch
from torch.utils.flop_counter import FlopCounterMode

inp = torch.randn(1, 3, 224, 224, device='cpu')
mod = models.resnet18()
flop_counter = FlopCounterMode(mod, depth=1)
with flop_counter:
    mod(inp).sum().backward()
```

<img width="326" alt="image" src="https://user-images.githubusercontent.com/6355099/222023068-3491e405-f195-4e11-b679-36b19a1380c7.png">

You can control the depth of the module hierarchy with the `depth` attribute (which defaults to 2). For example, if I don't limit it, this is what it outputs.

<img width="366" alt="image" src="https://user-images.githubusercontent.com/6355099/222023306-3d880bb6-f534-4f98-bf10-83c4353acefc.png">

## Other APIs

- `FlopCounterMode(custom_mapping=...)`: Allows for custom flop counting functions
- `FlopCounterMode.get_table(depth=...)`: Explicitly gets the table as a string
- `FlopCounterMode.flop_counts`: Contains the flop information as a `Dict[hierarchy: str, Dict[Op, int]]`
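For intuition, a `custom_mapping` entry is essentially a function that derives a FLOP count from input/output shapes. Below is a rough, self-contained sketch of two such formulas; the function names and signatures here are illustrative assumptions, not the counter's internal API.

```python
def matmul_flop(a_shape, b_shape):
    """FLOPs for A @ B: each of the M*N output elements needs K
    multiply-adds, i.e. 2*M*K*N operations total."""
    m, k = a_shape
    k2, n = b_shape
    assert k == k2, "inner dimensions must match"
    return 2 * m * k * n

def conv2d_flop(x_shape, w_shape, out_shape):
    """FLOPs for a 2D convolution: one multiply-add per kernel element
    per output element."""
    batch, c_in, _, _ = x_shape
    c_out, c_in_w, kh, kw = w_shape
    _, _, h_out, w_out = out_shape
    return 2 * batch * c_out * h_out * w_out * c_in_w * kh * kw

# Example: the 7x7, stride-2, 3->64 stem conv of a ResNet on a 224x224 input
stem_flops = conv2d_flop((1, 3, 224, 224), (64, 3, 7, 7), (1, 64, 112, 112))
print(stem_flops)
```

A real mapping would key these functions by `OpOverloadPacket` (e.g. `aten.mm`, `aten.convolution`) and extract the shapes from the dispatched args.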
Who do you want to do a close code review on this?
```python
    return flop_count

flop_mapping = {
```
Can we also add support for scaled_dot_product_attention, since it's used more and more nowadays?
Good idea.
And einsum
Einsum isn't needed since it's a "CompositeImplicit" op and decomposes into other operators. There's an example of counting einsum flops in the tests.
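Since einsum decomposes into matmul-like contractions that the counter already handles, its FLOP count can be estimated from the index sizes alone. A hedged back-of-the-envelope helper (not part of this PR, and valid only for single-contraction equations without repeated-index tricks):

```python
from math import prod

def einsum_flops(equation, *shapes):
    """Estimate FLOPs for a single-contraction einsum: two operations
    (one multiply, one add) per element of the full index space."""
    inputs = equation.split('->')[0].split(',')
    sizes = {}
    for labels, shape in zip(inputs, shapes):
        for label, dim in zip(labels, shape):
            sizes[label] = dim
    return 2 * prod(sizes.values())

# 'bik,bkj->bij' is a batched matmul: 2 * B * I * K * J flops,
# matching what counting the decomposed bmm would report.
print(einsum_flops('bik,bkj->bij', (4, 8, 16), (4, 16, 32)))
```

This agrees with the plain matmul formula for `'ik,kj->ij'`, which is why decomposition keeps the counts consistent.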
I'm gonna do the attention ones in a follow-up PR.
This actually leads to an important design issue: I think there should be an API to return "unsupported" ops. Similar to https://github.com/facebookresearch/fvcore/blob/51092b5515cbb493f73de079743dd6b11cc4bbf1/fvcore/nn/jit_analysis.py#L98
The reason is that users mainly interact with high-level modules and ops, and aren't aware of which low-level ops get called. Imagine one day torch adds a new op for transformers, or adds a different implementation of einsum that doesn't decompose: the change should be transparent, but it silently makes the flop counter wrong!
So in fvcore we report unsupported ops and print a warning about them by default. To make the results more meaningful, we have a list of trivial ops that are always ignored and will not appear in "unsupported ops" (https://github.com/facebookresearch/fvcore/blob/51092b5515cbb493f73de079743dd6b11cc4bbf1/fvcore/nn/jit_analysis.py#L28).
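The bookkeeping being proposed can be sketched in plain Python. All names here (the tracker class, the ignore list) are hypothetical illustrations of the fvcore-style approach, not this PR's or fvcore's actual API:

```python
# Trivial ops that should never be flagged as "unsupported"
# (hypothetical list, in the spirit of fvcore's _IGNORED_OPS).
IGNORED_OPS = {"aten.size", "aten.view", "aten.t", "aten.detach"}

class OpTracker:
    """Records ops seen during dispatch that have no flop formula."""

    def __init__(self, flop_mapping):
        self.flop_mapping = flop_mapping
        self.unsupported = {}  # op name -> call count

    def record(self, op_name):
        if op_name in self.flop_mapping or op_name in IGNORED_OPS:
            return
        self.unsupported[op_name] = self.unsupported.get(op_name, 0) + 1

tracker = OpTracker(flop_mapping={"aten.mm", "aten.convolution"})
for op in ["aten.mm", "aten.view",
           "aten.my_new_fused_op", "aten.my_new_fused_op"]:
    tracker.record(op)
print(tracker.unsupported)  # ops that were silently skipped, with counts
```

Surfacing `tracker.unsupported` (with a warning by default) would make a newly added, undecomposed op a loud error instead of a silent undercount.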
@ppwwyyxx Hmm... I think I might just add all the ops as "unsupported ops".
I think one benefit of putting it in core is that in principle, we should hopefully be able to keep the operator list up to date better.
Yeah, impl changes should be caught by tests. We have an einsum test in this PR, but not matmul?
```python
func_packet = func._overloadpacket
if func_packet in self.flop_mapping:
    flop_count_func = self.flop_mapping[func_packet]
    args_shape, out_shape = tree_map(get_shape, (args, normalize_tuple(out)))
```
Is it a good idea to always apply get_shape on tensors? For example, I wonder if there could be an op that takes a scalar boolean tensor and conditions two different behaviors (with different flop counts) on it.
I don't think it's a big deal - if that ever shows up we can just pass through scalar tensors.
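A pass-through `get_shape` along those lines might look like the sketch below. This is a hedged illustration of the idea, not the PR's actual implementation; `FakeTensor` is a stand-in class so the snippet runs without torch:

```python
def get_shape(x):
    # Map anything tensor-like to its shape; forward everything else
    # (ints, bools, None, scalar flags) unchanged, so data-dependent
    # arguments survive the mapping.
    if hasattr(x, "shape"):
        return tuple(x.shape)
    return x

class FakeTensor:
    """Stand-in with a .shape attribute, so the sketch needs no torch."""
    def __init__(self, shape):
        self.shape = shape

args = (FakeTensor((2, 3)), True, None, 4)
print(tuple(get_shape(a) for a in args))
```

With this shape-or-pass-through convention, a flop formula that needs a scalar flag can still receive it alongside the tensor shapes.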
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
I need to upgrade the public API script again to pick these up...
```python
    aten.convolution_backward: conv_backward_flop,
}

def normalize_tuple(x):
```
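For context, `normalize_tuple` presumably wraps a single return value into a 1-tuple so downstream shape-mapping code can treat every op's output uniformly. A minimal sketch of that behavior (an assumption about the helper's intent, not a copy of the PR's code):

```python
def normalize_tuple(x):
    # Ops may return a single tensor or a tuple of tensors; make the
    # single-output case look like the multi-output case so callers
    # can always iterate over outputs.
    if isinstance(x, tuple):
        return x
    return (x,)

print(normalize_tuple(1))       # single value becomes a 1-tuple
print(normalize_tuple((1, 2)))  # tuples pass through unchanged
```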
Please document public API or make them private if they shouldn't be public.
ah that's what you meant
I can also just set `__all__`, right?
Yes, but then you'll find that you can't have functions that look public yet aren't in `__all__`. So you will have to prepend them with `_` anyway.
Why would that be? It seems to work for me.
You can try `python test/test_public_bindings.py`; I would expect that it will fail.
Seems to work for me 🤔
The merge job was canceled. If you believe this is a mistake, then you can re-trigger it through pytorch-bot.
Ok if CI is ok!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
```python
def forward(ctx, *args):
    assert self.parents[-1] == name
    self.parents.pop()
    args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
```
Quick question: why are we calling `clone()` here?
If the output of an autograd.Function is a view of the input, there are some restrictions on whether you can use inplace operations or not.
Stack from ghstack (oldest at bottom):

Overall, an example usage. Note that this also captures backwards FLOPs.

You can control the depth of the module hierarchy with the `depth` attribute (which defaults to 2). For example, if I don't limit it, this is what it outputs.

## Other APIs

- `FlopCounterMode(custom_mapping=...)`: Allows for custom flop counting functions
- `FlopCounterMode.get_table(depth=...)`: Explicitly gets the table as a string
- `FlopCounterMode.flop_counts`: Contains the flop information as a `Dict[hierarchy: str, Dict[Op, int]]`
- `FlopCounterMode.register_hierarchy(f, name)`: Allows you to register additional "hierarchies" for a function