Skip to content

Commit

Permalink
Add option to flop counter formula registration to get raw values (#1…
Browse files Browse the repository at this point in the history
…10591)

Pull Request resolved: #110591
Approved by: https://github.com/awgu
ghstack dependencies: #110501, #110504
  • Loading branch information
Chillee authored and pytorchmergebot committed Oct 5, 2023
1 parent 9e72c9c commit ada6550
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
11 changes: 11 additions & 0 deletions test/test_flop_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,17 @@ def test_custom(self):

self.assertExpectedInline(get_total_flops(mode), """5""")

def count(*args, out):
return out.numel()
count._get_raw = True

mode = FlopCounterMode(custom_mapping={torch.ops.aten.add: count})
with mode:
a = T(4, 5)
a + a

self.assertExpectedInline(get_total_flops(mode), """20""")

def test_noop(self):
mode = FlopCounterMode()
with mode:
Expand Down
20 changes: 16 additions & 4 deletions torch/utils/flop_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch.utils.hooks import RemovableHandle
from torch._decomp import register_decomposition
from math import prod
from functools import wraps



Expand All @@ -21,8 +22,17 @@ def get_shape(i):

flop_registry: Dict[Any, Any] = {}

def register_flop_formula(targets):
def shape_wrapper(f):
@wraps(f)
def nf(*args, out=None, **kwargs):
args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out))
return f(*args, out_shape=out_shape, **kwargs)
return nf

def register_flop_formula(targets, get_raw=False):
def register_fun(flop_formula):
if not get_raw:
flop_formula = shape_wrapper(flop_formula)
register_decomposition(targets, registry=flop_registry, unsafe=True)(flop_formula)
return flop_formula

Expand Down Expand Up @@ -273,7 +283,10 @@ def __init__(
self.mods = mods
# Keys will include the modules in `mods` and their submodules
self._module_to_forward_hook_handles: Dict[nn.Module, _ForwardHookHandles] = {}
self.flop_registry = {**flop_registry, **custom_mapping}
self.flop_registry = {
**flop_registry,
**{k: v if getattr(v, "_get_raw", False) else shape_wrapper(v) for k, v in custom_mapping.items()}
}

def _register_forward_hooks(self):
if self.mods is None:
Expand Down Expand Up @@ -439,8 +452,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
func_packet = func._overloadpacket
if func_packet in self.flop_registry:
flop_count_func = self.flop_registry[func_packet]
args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out))
flop_count = flop_count_func(*args, **kwargs, out_shape=out_shape) # type: ignore[operator]
flop_count = flop_count_func(*args, **kwargs, out=out) # type: ignore[operator]
for par in self.parents:
self.flop_counts[par][func_packet] += flop_count

Expand Down

0 comments on commit ada6550

Please sign in to comment.