wip stats
lint

ghstack-source-id: 9bfc70d6a5fc3288cc749648356d730ed90199da
Pull Request resolved: #93013

rm unused
voznesenskym committed Jan 25, 2023
1 parent e292ddf commit 7b70431
Showing 8 changed files with 93 additions and 6 deletions.
7 changes: 6 additions & 1 deletion benchmarks/dynamo/common.py
@@ -1518,9 +1518,14 @@ def run_one_model(
)
print(status)
if self.args.timing:
from torch._dynamo.utils import print_time_report
from torch._dynamo.utils import op_count, print_time_report
from torch.utils._stats import simple_call_counter

print_time_report()
stats = f"STATS: call_* op count: {op_count}"
for key, value in simple_call_counter.items():
stats = f"{stats} | {key}:{value}"
print(stats)

end_calls_captured = torch._dynamo.utils.counters["stats"]["calls_captured"]
end_unique_graphs = torch._dynamo.utils.counters["stats"]["unique_graphs"]
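
For context, this timing path now emits a single "STATS:" line per run, which benchmarks/dynamo/parse_logs.py (below) picks apart into CSV columns. A minimal sketch of the resulting line, using the illustrative counts from the example comment in parse_logs.py rather than real measurements:

# Sketch only; the counts come from the example comment in parse_logs.py.
op_count = 970
simple_call_counter = {
    "FakeTensor.__torch_dispatch__": 35285,
    "ProxyTorchDispatchMode.__torch_dispatch__": 13339,
}
stats = f"STATS: call_* op count: {op_count}"
for key, value in simple_call_counter.items():
    stats = f"{stats} | {key}:{value}"
print(stats)
# STATS: call_* op count: 970 | FakeTensor.__torch_dispatch__:35285 | ProxyTorchDispatchMode.__torch_dispatch__:13339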
48 changes: 46 additions & 2 deletions benchmarks/dynamo/parse_logs.py
@@ -46,7 +46,22 @@ def chunker(seq, size):
i = 0

out = csv.writer(sys.stdout, dialect="excel")
out.writerow(["", hash, "", "", "", "", gist_url])
out.writerow(
[
"",
hash,
"",
"",
"",
"",
gist_url,
"frame_time",
"backend_time",
"total_ops",
"fake_tensor_dispatch_calls",
"proxy_torch_dispatch_calls",
]
)

# Sometimes backtraces will be in third party code, which results
# in very long file names. Delete the absolute path in this case.
@@ -130,6 +145,22 @@ def normalize_file(f):
if len(split_str) == 2:
backend_time = float(split_str[1])
frame_time = float(split_str[0].split("entire_frame_compile:")[1])

tot_ops = None
fm_dispatches = None
pm_dispatches = None
if "STATS:" in log:
result = re.search("STATS:(.*)\n", log).group(1)
# call_* op count: 970 | FakeTensor.__torch_dispatch__:35285 | ProxyTorchDispatchMode.__torch_dispatch__:13339
split_all = result.split("|")

if len(split_all) == 3:
tot_ops = int(split_all[0].split("call_* op count:")[1])
fm_dispatches = int(split_all[1].split("FakeTensor.__torch_dispatch__:")[1])
pm_dispatches = int(
    split_all[2].split("ProxyTorchDispatchMode.__torch_dispatch__:")[1]
)

# If the context string is too long, don't put it in the CSV.
# This is a hack to try to make it more likely that Google Sheets will
# offer to split columns
@@ -143,7 +174,20 @@ def normalize_file(f):
context = ""

out.writerow(
[bench, name, "", r, component, context, explain, frame_time, backend_time]
[
bench,
name,
"",
r,
component,
context,
explain,
frame_time,
backend_time,
tot_ops,
fm_dispatches,
pm_dispatches,
]
)
i += 1
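
To make the parsing step above concrete, here is a hedged, self-contained sketch of how one STATS line decomposes into the three new columns; parse_stats_line is a hypothetical helper, not part of parse_logs.py:

def parse_stats_line(line):
    # Split off the payload, then the three "|"-separated fields.
    parts = line.split("STATS:")[1].split("|")
    tot_ops = int(parts[0].split("call_* op count:")[1])
    fm_dispatches = int(parts[1].split("FakeTensor.__torch_dispatch__:")[1])
    pm_dispatches = int(parts[2].split("ProxyTorchDispatchMode.__torch_dispatch__:")[1])
    return tot_ops, fm_dispatches, pm_dispatches

example = (
    "STATS: call_* op count: 970 | FakeTensor.__torch_dispatch__:35285"
    " | ProxyTorchDispatchMode.__torch_dispatch__:13339"
)
print(parse_stats_line(example))  # (970, 35285, 13339)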

5 changes: 5 additions & 0 deletions torch/_dynamo/output_graph.py
@@ -627,6 +627,11 @@ def compile_and_call_fx_graph(self, tx, rv, root):

@dynamo_timed(phase_name="backend_compile")
def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
tot = 0
for node in gm.graph.nodes:
if node.op in ("call_function", "call_method", "call_module"):
tot += 1
torch._dynamo.utils.increment_op_count(tot)
try:
name = (
self.compiler_fn.__name__
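
The new counting logic in call_user_compiler simply tallies call_* nodes in the FX graph handed to the backend. A hedged illustration on a toy symbolic trace (the Toy module is made up for this example):

import torch
import torch.fx as fx

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.lin(x)) + x

gm = fx.symbolic_trace(Toy())
# Same predicate as the added code: count call_function/call_method/call_module nodes.
tot = sum(
    1
    for node in gm.graph.nodes
    if node.op in ("call_function", "call_method", "call_module")
)
print(tot)  # 3: call_module (lin), call_function (relu), call_function (add)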
9 changes: 8 additions & 1 deletion torch/_dynamo/utils.py
@@ -50,7 +50,6 @@
# profiling compilation time
compilation_metrics = collections.OrderedDict()


timer_counter = itertools.count()


@@ -103,6 +102,14 @@ def reset_frame_count():
curr_frame = 0


op_count = 0


def increment_op_count(cnt):
global op_count
op_count += cnt


# Print a report of time spent so far
# Ex:
# TIMING:
10 changes: 8 additions & 2 deletions torch/_ops.py
@@ -239,6 +239,7 @@ def __init__(self, overloadpacket, op, op_dk, schema, tags):
self._schema = schema
self._overloadpacket = overloadpacket
self._tags = tags
self.dk_res = None
self._overloadname = (
"default" if schema.overload_name == "" else schema.overload_name
)
@@ -301,9 +302,14 @@ def decompose(self, *args, **kwargs):
# apply Python CompositeImplicitAutograd *before* tracing
# using Python dispatcher (also taking advantage of the autograd
# formula). But it's included for completeness
return self.py_kernels[dk](*args, **kwargs)
self.dk_res = self.py_kernels[dk](*args, **kwargs)
return self.dk_res
elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk):
return self._op_dk(dk, *args, **kwargs)
self.dk_res = self._op_dk(dk, *args, **kwargs)
from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule

# with no_dispatch():
return self.dk_res
else:
return NotImplemented

2 changes: 2 additions & 0 deletions torch/_subclasses/fake_tensor.py
@@ -20,6 +20,7 @@
from torch.utils._python_dispatch import TorchDispatchMode

from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_map_only
from torch.utils._stats import count
from torch.utils.weak import WeakIdRef

pytree = torch.utils._pytree
@@ -623,6 +624,7 @@ def __repr__(self):
return f"FakeTensor({self_repr}, {self.fake_device})"

@classmethod
@count
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
# need to handle here to avoid infinite recursion
# see [in_kernel_invocation]
2 changes: 2 additions & 0 deletions torch/fx/experimental/proxy_tensor.py
@@ -18,6 +18,7 @@
from dataclasses import dataclass
import weakref
import operator
from torch.utils._stats import count

from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode_temporarily, _get_current_dispatch_mode
from torch._subclasses import FakeTensor
@@ -477,6 +478,7 @@ def __init__(self, tracer, tracing_mode):
self.trace_state = {}
self._managers = []

@count
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
with self.sym_mode.enable(False):
return self.inner_torch_dispatch(func, types, args, kwargs)
16 changes: 16 additions & 0 deletions torch/utils/_stats.py
@@ -0,0 +1,16 @@
# NOTE! PLEASE KEEP THIS FILE *FREE* OF TORCH DEPS! IT SHOULD BE IMPORTABLE ANYWHERE.
# IF YOU FEEL AN OVERWHELMING URGE TO ADD A TORCH DEP, MAKE A TRAMPOLINE FILE A LA torch._dynamo.utils
# AND SCRUB AWAY TORCH NOTIONS THERE.
import collections
import functools

simple_call_counter = collections.OrderedDict()

def count(fn):
@functools.wraps(fn)
def wrapper(*args, **kwargs):
if fn.__qualname__ not in simple_call_counter:
simple_call_counter[fn.__qualname__] = 0
simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1
return fn(*args, **kwargs)
return wrapper
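
A minimal usage sketch of the new decorator (greet is a made-up function; entries are keyed by __qualname__):

from torch.utils._stats import count, simple_call_counter

@count
def greet(name):
    return f"hello {name}"

greet("a")
greet("b")
print(simple_call_counter)  # OrderedDict([('greet', 2)]) when run in isolation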
