
Add flop counter utility #95751

Closed · wants to merge 9 commits
Conversation

Chillee (Contributor) commented Mar 1, 2023


Here's an example of overall usage. Note that this *also* captures backwards FLOPs.

```
import torchvision.models as models
import torch
from torch.utils.flop_counter import FlopCounterMode

inp = torch.randn(1, 3, 224, 224, device='cpu')
mod = models.resnet18()

flop_counter = FlopCounterMode(mod, depth=1)
with flop_counter:
    mod(inp).sum().backward()
```

<img width="326" alt="image" src="https://user-images.githubusercontent.com/6355099/222023068-3491e405-f195-4e11-b679-36b19a1380c7.png">

You can control the depth of the module hierarchy with the `depth` attribute (which defaults to 2). For example, if I don't limit it, this is what it outputs.

<img width="366" alt="image" src="https://user-images.githubusercontent.com/6355099/222023306-3d880bb6-f534-4f98-bf10-83c4353acefc.png">
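
The table can also be re-rendered at a different depth after the run, via the `get_table` API listed under "Other APIs" below (a minimal sketch, assuming `get_table` accepts the same `depth` argument described there):

```
# Re-render the counts collected above with a deeper module hierarchy.
print(flop_counter.get_table(depth=3))
```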

## Other APIs

- `FlopCounterMode(custom_mapping=...)`: Allows for custom flop counting functions (see the sketch below).
- `FlopCounterMode.get_table(depth=...)`: Explicitly get the table as a string.
- `FlopCounterMode.flop_counts`: Contains the flop information as a `Dict[hierarchy: str, Dict[Op, int]]`.
- `FlopCounterMode.register_hierarchy(f, name)`: Allows you to register additional "hierarchies" for a function.
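
A minimal sketch of `custom_mapping`, assuming (per the dispatch snippet discussed in the review below) that the mapping is keyed by `OpOverloadPacket`, that counting functions receive input shapes positionally plus an `out_shape` keyword, and that the module argument may be omitted; `my_mm_flop` is an illustrative helper, not part of this PR:

```
import torch
from torch.utils.flop_counter import FlopCounterMode

def my_mm_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
    # (m, k) @ (k, n) does m*k*n multiply-accumulates; this sketch counts
    # one flop per multiply-accumulate, adjust the convention as needed.
    m, k = a_shape
    _, n = b_shape
    return m * k * n

flop_counter = FlopCounterMode(custom_mapping={torch.ops.aten.mm: my_mm_flop})
with flop_counter:
    torch.mm(torch.randn(8, 16), torch.randn(16, 32))
```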

pytorch-bot commented Mar 1, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/95751

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d7256a4:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Chillee added the "topic: new features" label Mar 1, 2023
Chillee requested a review from ngimel March 1, 2023 01:53
Chillee added the "release notes: composability" label Mar 1, 2023
ezyang (Contributor) commented Mar 1, 2023

Who do you want to do a close code review on this?

Inline review on this hunk:

```
    return flop_count


flop_mapping = {
```
Contributor:

Can we also add support for scaled_dot_product_attention, since it's used more and more nowadays?

Chillee (Contributor Author):

Good idea.

Contributor:

And einsum

Chillee (Contributor Author), Mar 1, 2023:

Einsum isn't needed since it's a "CompositeImplicitAutograd" op and decomposes into other operators. There's an example of counting einsum flops in the tests.
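
A minimal sketch of what that decomposition means in practice, assuming the mode can be constructed without a module and that `flop_counts` is populated as described in the PR summary:

```
import torch
from torch.utils.flop_counter import FlopCounterMode

a = torch.randn(4, 8, 16)
b = torch.randn(4, 16, 32)

flop_counter = FlopCounterMode()
with flop_counter:
    torch.einsum('bik,bkj->bij', a, b)

# The counts are attributed to the decomposed op (e.g. aten.bmm),
# not to einsum itself.
print(flop_counter.flop_counts)
```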

Chillee (Contributor Author):

I'm gonna do the attention ones in a follow-up PR.
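
For reference, a hypothetical sketch of the kind of rule such a follow-up might register, following this PR's shape-based convention; the name `sdpa_flop` and the `(B, H, S, D)` layout are assumptions, not code from this PR:

```
def sdpa_flop(query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int:
    # q @ k^T costs b*h*s_q*s_k*d multiply-accumulates, and attn @ v costs
    # the same again; softmax and scaling are ignored here.
    b, h, s_q, d = query_shape
    _, _, s_k, _ = key_shape
    return 2 * b * h * s_q * s_k * d
```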

ppwwyyxx (Contributor):

This actually leads to an important design issue: I think there should be an API to return "unsupported" ops, similar to https://github.com/facebookresearch/fvcore/blob/51092b5515cbb493f73de079743dd6b11cc4bbf1/fvcore/nn/jit_analysis.py#L98

The reason is that users mainly interact with high-level modules and ops, and aren't aware of what low-level ops are called. Imagine one day torch adds a new op for transformer, or adds a different implementation of einsum that doesn't decompose: the change should be transparent, but it turns the flop counter into a silent error!

So in fvcore we provide unsupported ops and print a warning about them by default. To make the results more meaningful, we keep a list of trivial ops that are always ignored and will never appear in "unsupported ops" (https://github.com/facebookresearch/fvcore/blob/51092b5515cbb493f73de079743dd6b11cc4bbf1/fvcore/nn/jit_analysis.py#L28).
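
A minimal sketch of the proposed tracking, assuming a `TorchDispatchMode`-based counter like this PR's; the class and its fields are illustrative, not fvcore's or this PR's implementation:

```
import torch
from collections import Counter
from torch.utils._python_dispatch import TorchDispatchMode

class UnsupportedOpTracker(TorchDispatchMode):
    def __init__(self, supported):
        super().__init__()
        self.supported = supported      # ops that have a registered flop rule
        self.unsupported = Counter()    # op -> number of calls with no rule

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if func._overloadpacket not in self.supported:
            self.unsupported[func._overloadpacket] += 1
        return func(*args, **(kwargs or {}))

with UnsupportedOpTracker({torch.ops.aten.mm}) as tracker:
    torch.relu(torch.randn(4, 4) @ torch.randn(4, 4))
print(tracker.unsupported)  # every op hit that had no flop formula
```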

Chillee (Contributor Author):

@ppwwyyxx Hmm... I think I might just add all the ops as "unsupported ops".

I think one benefit of putting it in core is that, in principle, we should be able to keep the operator list up to date more easily.

Collaborator:

Yeah, impl changes should be caught by tests. We have an einsum test in this PR, but not matmul?

Chillee (Contributor Author) commented Mar 1, 2023

@ezyang Not sure I really need a close code review from anybody (in particular) 🤔 Maybe @ngimel

Inline review on this hunk:

```
func_packet = func._overloadpacket
if func_packet in self.flop_mapping:
    flop_count_func = self.flop_mapping[func_packet]
    args_shape, out_shape = tree_map(get_shape, (args, normalize_tuple(out)))
```
Contributor:

Is it a good idea to always apply get_shape on tensors? For example, I wonder if there could be an op that takes a scalar boolean tensor and conditions two different behaviors (with different flops) on it.

Chillee (Contributor Author):

I don't think it's a big deal - if that ever shows up we can just pass through scalar tensors.

Chillee requested review from ngimel and albanD March 2, 2023 02:48
Chillee (Contributor Author) commented Mar 2, 2023

@pytorchbot merge

pytorch-bot added the "ciflow/trunk" label Mar 2, 2023
pytorchmergebot (Collaborator):
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).

albanD (Collaborator) left a comment:

I need to upgrade the public API script again to pick these up...

Inline review on this hunk:

```
    aten.convolution_backward: conv_backward_flop,
}

def normalize_tuple(x):
```
albanD (Collaborator):

Please document these public APIs, or make them private if they shouldn't be public.

Chillee (Contributor Author):

ah that's what you meant

Chillee (Contributor Author):

I can also just set `__all__`, right?

albanD (Collaborator):

Yes, but then you'll see that you can't have these functions that look public but are not in `__all__`. So you will have to prepend them with `_` anyway.

Chillee (Contributor Author):

Why will I see that? It seems to work for me.

albanD (Collaborator):

You can try `python test/test_public_bindings.py`; I would expect it to fail.

Chillee (Contributor Author):

Seems to work for me 🤔
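
For context, a minimal sketch of the `__all__` approach under discussion; which names the module actually exports is an assumption here:

```
# At the top of torch/utils/flop_counter.py: only names in __all__ count as
# public, so helpers like normalize_tuple stay off the public API surface.
__all__ = ["FlopCounterMode"]
```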

pytorchmergebot (Collaborator):
The merge job was canceled. If you believe this is a mistake, then you can re-trigger it through pytorch-bot.

albanD (Collaborator) left a comment:

Ok if CI is ok!

Chillee (Contributor Author) commented Mar 2, 2023

@pytorchbot merge

pytorchmergebot (Collaborator):
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).

Inline review on this hunk:

```
def forward(ctx, *args):
    assert(self.parents[-1] == name)
    self.parents.pop()
    args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
```
Contributor:

quick question: why are we calling clone() here?

Chillee (Contributor Author):

If the output of an `autograd.Function` is a view of the input, there are some restrictions on whether you can use in-place operations or not.
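
A minimal illustration of the restriction being worked around; `Identity` is a toy function, and the exact error message depends on the PyTorch version:

```
import torch

class Identity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x  # output aliases the input, so autograd treats it as a view

    @staticmethod
    def backward(ctx, grad):
        return grad

x = torch.randn(3, requires_grad=True)
y = Identity.apply(x)
y.add_(1)  # in recent PyTorch this raises a RuntimeError: in-place on a view
           # returned by a custom Function is forbidden; cloning avoids this
```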

Pull Request resolved: pytorch/pytorch#95751
Approved by: https://github.com/ngimel, https://github.com/albanD
facebook-github-bot deleted the gh/chillee/189/head branch June 8, 2023 15:54