Skip to content

Commit

Permalink
split out flop counting its own method (#125061)
Browse files Browse the repository at this point in the history
Summary: Modularizing code for reuse by splitting __torch_dispatch__ to move flop counting to its own method.

Test Plan: unit tests

Differential Revision: D56644523

Pull Request resolved: #125061
Approved by: https://github.com/842974287
  • Loading branch information
mrajpal authored and pytorchmergebot committed Apr 29, 2024
1 parent e5e623a commit da44d2f
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion torch/utils/flop_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,9 @@ def __exit__(self, *args):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
kwargs = kwargs if kwargs else {}
out = func(*args, **kwargs)
func_packet = func._overloadpacket
return self._count_flops(func._overloadpacket, out, args, kwargs)

def _count_flops(self, func_packet, out, args, kwargs):
if func_packet in self.flop_registry:
flop_count_func = self.flop_registry[func_packet]
flop_count = flop_count_func(*args, **kwargs, out_val=out) # type: ignore[operator]
Expand Down

0 comments on commit da44d2f

Please sign in to comment.