
Commit 8344bb6

wip stats
lint

ghstack-source-id: ce037814c385b03ae7180b9af7cbd0d2d11ed20e
Pull Request resolved: #93013

rm unused

stats fix

Undo
voznesenskym committed Jan 25, 2023
1 parent e292ddf commit 8344bb6
Showing 9 changed files with 99 additions and 34 deletions.
7 changes: 6 additions & 1 deletion benchmarks/dynamo/common.py
@@ -1518,9 +1518,14 @@ def run_one_model(
             )
             print(status)
         if self.args.timing:
-            from torch._dynamo.utils import print_time_report
+            from torch._dynamo.utils import op_count, print_time_report
+            from torch.utils._stats import simple_call_counter

             print_time_report()
+            stats = f"STATS: call_* op count: {op_count}"
+            for key, value in simple_call_counter.items():
+                stats = f"{stats} | {key}:{value}"
+            print(stats)

         end_calls_captured = torch._dynamo.utils.counters["stats"]["calls_captured"]
         end_unique_graphs = torch._dynamo.utils.counters["stats"]["unique_graphs"]
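With the change above, the timing path prints a single STATS line alongside the timing report. A minimal sketch of how that line is assembled, using placeholder values for op_count and simple_call_counter (the sample numbers are taken from the comment in parse_logs.py further down, not measured here):

from collections import OrderedDict

# Stand-ins for torch._dynamo.utils.op_count and torch.utils._stats.simple_call_counter
# after a benchmark run; the numbers are illustrative only.
op_count = 970
simple_call_counter = OrderedDict(
    [
        ("FakeTensor.__torch_dispatch__", 35285),
        ("ProxyTorchDispatchMode.__torch_dispatch__", 13339),
    ]
)

stats = f"STATS: call_* op count: {op_count}"
for key, value in simple_call_counter.items():
    stats = f"{stats} | {key}:{value}"
print(stats)
# STATS: call_* op count: 970 | FakeTensor.__torch_dispatch__:35285 | ProxyTorchDispatchMode.__torch_dispatch__:13339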
59 changes: 57 additions & 2 deletions benchmarks/dynamo/parse_logs.py
@@ -46,7 +46,24 @@ def chunker(seq, size):
 i = 0

 out = csv.writer(sys.stdout, dialect="excel")
-out.writerow(["", hash, "", "", "", "", gist_url])
+out.writerow(
+    [
+        "",
+        hash,
+        "",
+        "",
+        "",
+        "",
+        gist_url,
+        "frame_time",
+        "backend_time",
+        "total_ops",
+        "fake_tensor_dispatch_calls",
+        "proxy_torch_dispatch_calls",
+        "time_per_op",
+        "dispatches_per_op"
+    ]
+)

 # Sometimes backtraces will be in third party code, which results
 # in very long file names. Delete the absolute path in this case.
@@ -130,6 +147,29 @@ def normalize_file(f):
         if len(split_str) == 2:
             backend_time = float(split_str[1])
             frame_time = float(split_str[0].split("entire_frame_compile:")[1])
+
+    tot_ops = None
+    fm_dispatches = None
+    pm_dispatches = None
+    if "STATS:" in log:
+        result = re.search("STATS:(.*)\n", log).group(1)
+        # call_* op count: 970 | FakeTensor.__torch_dispatch__:35285 | ProxyTorchDispatchMode.__torch_dispatch__:13339
+        split_all = result.split("|")
+
+        if len(split_all) == 3:
+            tot_ops = int(split_all[0].split("call_* op count:")[1])
+            fm_dispatches = int(split_all[1].split("FakeTensor.__torch_dispatch__:")[1])
+            pm_dispatches = int(
+                split_all[2].split("ProxyTorchDispatchMode.__torch_dispatch__:")[1]
+            )
+    time_per_op = None
+    if frame_time is not None and tot_ops is not None:
+        time_per_op = frame_time / tot_ops * 1000  # ms
+
+    dispatches_per_op = None
+    if fm_dispatches is not None and pm_dispatches is not None and tot_ops is not None:
+        dispatches_per_op = (fm_dispatches + pm_dispatches) / tot_ops
+
     # If the context string is too long, don't put it in the CSV.
     # This is a hack to try to make it more likely that Google Sheets will
     # offer to split columns
@@ -143,7 +183,22 @@ def normalize_file(f):
         context = ""

     out.writerow(
-        [bench, name, "", r, component, context, explain, frame_time, backend_time]
+        [
+            bench,
+            name,
+            "",
+            r,
+            component,
+            context,
+            explain,
+            frame_time,
+            backend_time,
+            tot_ops,
+            fm_dispatches,
+            pm_dispatches,
+            time_per_op,
+            dispatches_per_op,
+        ]
     )
     i += 1

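For reference, a self-contained sketch (not part of the commit) of the parsing added above, run against the sample STATS line from the in-code comment and an assumed frame_time of 12.5 seconds:

import re

log = "STATS: call_* op count: 970 | FakeTensor.__torch_dispatch__:35285 | ProxyTorchDispatchMode.__torch_dispatch__:13339\n"
frame_time = 12.5  # seconds; assumed value for illustration only

tot_ops = fm_dispatches = pm_dispatches = None
if "STATS:" in log:
    result = re.search("STATS:(.*)\n", log).group(1)
    split_all = result.split("|")
    if len(split_all) == 3:
        tot_ops = int(split_all[0].split("call_* op count:")[1])
        fm_dispatches = int(split_all[1].split("FakeTensor.__torch_dispatch__:")[1])
        pm_dispatches = int(split_all[2].split("ProxyTorchDispatchMode.__torch_dispatch__:")[1])

time_per_op = frame_time / tot_ops * 1000  # ~12.9 ms of compile time per traced op
dispatches_per_op = (fm_dispatches + pm_dispatches) / tot_ops  # ~50.1 dispatcher calls per op
print(time_per_op, dispatches_per_op)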
5 changes: 5 additions & 0 deletions torch/_dynamo/output_graph.py
@@ -627,6 +627,11 @@ def compile_and_call_fx_graph(self, tx, rv, root):

     @dynamo_timed(phase_name="backend_compile")
     def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
+        tot = 0
+        for node in gm.graph.nodes:
+            if node.op in ("call_function", "call_method", "call_module"):
+                tot += 1
+        torch._dynamo.utils.increment_op_count(tot)
         try:
             name = (
                 self.compiler_fn.__name__
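The count skips placeholder, get_attr and output nodes, so it reflects only the ops a backend actually has to compile. A standalone sketch of the same node-counting logic on a toy traced module (not taken from the PR):

import torch
import torch.fx as fx

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + x.sum()

gm = fx.symbolic_trace(Toy())
# Same filter as call_user_compiler above: only call_* nodes count as ops.
tot = sum(
    node.op in ("call_function", "call_method", "call_module")
    for node in gm.graph.nodes
)
print(tot)  # 3: relu (call_function), sum (call_method), add (call_function)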
9 changes: 8 additions & 1 deletion torch/_dynamo/utils.py
@@ -50,7 +50,6 @@
 # profiling compilation time
 compilation_metrics = collections.OrderedDict()

-
 timer_counter = itertools.count()

@@ -103,6 +102,14 @@ def reset_frame_count():
     curr_frame = 0


+op_count = 0
+
+
+def increment_op_count(cnt):
+    global op_count
+    op_count += cnt
+
+
 # Print a report of time spent so far
 # Ex:
 # TIMING:
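op_count is module-global, so successive graph compiles within one run accumulate into the single total the benchmark harness prints at the end. A trivial usage sketch (values are made up):

from torch._dynamo import utils

utils.increment_op_count(3)  # e.g. the first compiled graph
utils.increment_op_count(7)  # e.g. a second graph after a graph break
print(utils.op_count)        # 10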
2 changes: 2 additions & 0 deletions torch/_subclasses/fake_tensor.py
@@ -20,6 +20,7 @@
 from torch.utils._python_dispatch import TorchDispatchMode

 from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_map_only
+from torch.utils._stats import count
 from torch.utils.weak import WeakIdRef

 pytree = torch.utils._pytree
@@ -623,6 +624,7 @@ def __repr__(self):
         return f"FakeTensor({self_repr}, {self.fake_device})"

     @classmethod
+    @count
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
         # need to handle here to avoid infinite recursion
         # see [in_kernel_invocation]
2 changes: 2 additions & 0 deletions torch/fx/experimental/proxy_tensor.py
@@ -18,6 +18,7 @@
 from dataclasses import dataclass
 import weakref
 import operator
+from torch.utils._stats import count

 from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode_temporarily, _get_current_dispatch_mode
 from torch._subclasses import FakeTensor
@@ -477,6 +478,7 @@ def __init__(self, tracer, tracing_mode):
         self.trace_state = {}
         self._managers = []

+    @count
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         with self.sym_mode.enable(False):
             return self.inner_torch_dispatch(func, types, args, kwargs)
27 changes: 0 additions & 27 deletions torch/nn/functional.py
@@ -4831,33 +4831,6 @@ def _in_projection(
     assert b_v is None or b_v.shape == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
     return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)

-scaled_dot_product_attention = _add_docstr(
-    torch._C._nn.scaled_dot_product_attention, r"""
-Computes scaled dot product attention on query, key and value tensors, using
-an optional attention mask if passed, and applying dropout if a probability
-greater than 0.0 is specified.
-Args:
-    query (Tensor): Query tensor; shape (N, ..., L, E)
-    key (Tensor): Key tensor; shape (N, ..., S, E)
-    value (Tensor): Value tensor; shape (N, ..., S, E)
-    attn_mask (optional Tensor): Attention mask; shape (N, ..., L, S) or (L, S). Currently, only a boolean mask
-        is supported, where a value of True indicates that the element *should* take part in attention.
-    dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied
-    is_causal (bool): If true, assumes causal attention masking and ignores attn_mask.
-Returns a tuple containing:
-    output (Tensor): Attention output; shape (N, ..., L, E)
-Shape legend:
-    N: Batch size
-    ...: Any number of other batch dimensions (optional)
-    S: Source sequence length
-    L: Target sequence length
-    E: Embedding dimension
-""")


 def _scaled_dot_product_attention(
     query: Tensor,
6 changes: 3 additions & 3 deletions torch/overrides.py
@@ -295,8 +295,8 @@ def get_ignored_functions() -> Set[Callable]:
         Tensor._conj_physical,
         Tensor._neg_view,
         Tensor._is_zerotensor,
-        Tensor._is_all_true,
-        Tensor._is_any_true,
+        # Tensor._is_all_true,
+        # Tensor._is_any_true,
         Tensor._addmm_activation,
         Tensor.to_padded_tensor,
     }
}
Expand Down Expand Up @@ -861,7 +861,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.nn.functional.selu: lambda input, inplace=False: -1,
torch.nn.functional.silu: lambda input, inplace=False: -1,
torch.nn.functional.mish: lambda input, inplace=False: -1,
torch.nn.functional.scaled_dot_product_attention: lambda query, key, value, attn_mask=None, dropout_p=0.0: -1,
# torch.nn.functional.scaled_dot_product_attention: lambda query, key, value, attn_mask=None, dropout_p=0.0: -1,
torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction='mean', beta=1.: -1,
torch.nn.functional.huber_loss: lambda input, target, reduction='mean', delta=1.: -1,
torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
16 changes: 16 additions & 0 deletions torch/utils/_stats.py
@@ -0,0 +1,16 @@
+# NOTE! PLEASE KEEP THIS FILE *FREE* OF TORCH DEPS! IT SHOULD BE IMPORTABLE ANYWHERE.
+# IF YOU FEEL AN OVERWHELMING URGE TO ADD A TORCH DEP, MAKE A TRAMPOLINE FILE A LA torch._dynamo.utils
+# AND SCRUB AWAY TORCH NOTIONS THERE.
+import collections
+import functools
+
+simple_call_counter = collections.OrderedDict()
+
+def count(fn):
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if fn.__qualname__ not in simple_call_counter:
+            simple_call_counter[fn.__qualname__] = 0
+        simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1
+        return fn(*args, **kwargs)
+    return wrapper
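The decorator records one entry per function __qualname__, which is exactly the key format ("FakeTensor.__torch_dispatch__", "ProxyTorchDispatchMode.__torch_dispatch__") that the STATS line and parse_logs.py rely on above. A quick behavioral sketch with a toy class, assuming this commit is applied so that torch.utils._stats exists:

from torch.utils._stats import count, simple_call_counter

class Demo:
    @count
    def dispatch(self, x):
        return x + 1

d = Demo()
d.dispatch(1)
d.dispatch(2)
# Entries are keyed by __qualname__, mirroring the real "FakeTensor.__torch_dispatch__" keys.
print(simple_call_counter)  # OrderedDict([('Demo.dispatch', 2)])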
