Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions scripts/grab_graphs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
#!/usr/bin/env python
# Parses the output of XLA_SAVE_TENSORS_FILE and produces statistics about graph
# types and Python frames.

from __future__ import print_function

import argparse
import collections
import difflib
import os
import re
import shutil
import sys

# One parsed graph dump: sequential id, the raw IR lines, the normalized IR
# lines (SSA ids / call arguments stripped, see normalize()), and the Python
# frame lines ('TensorsGraphInfo:' section) that preceded it in the dump.
GraphInfo = collections.namedtuple('GraphInfo', 'id, graph, ngraph, frame')


def save_graph(graph, path):
  """Write the graph's lines to *path*, joined with newlines."""
  text = '\n'.join(graph)
  with open(path, 'w') as out:
    out.write(text)


def normalize(graph):
  """Return graph lines with SSA ids and call arguments stripped.

  Example input line:
    %397 = f32[128]{0} xla::cross_replica_sum(%396), scale=0.125, groups=()
  becomes:
    f32[128]{0} xla::cross_replica_sum(), scale=0.125, groups=()
  Lines that do not look like op assignments pass through unchanged.
  """
  op_re = re.compile(r'(\s*)%\d+\s*=\s*(.*::[^(]+\()[^)]*(.*)')
  normalized = []
  for raw in graph:
    match = op_re.match(raw)
    if match:
      normalized.append(''.join(match.group(1, 2, 3)))
    else:
      normalized.append(raw)
  return normalized


def prase_graphs(gfile, dest_dir, graphs=None):
  """Parse an XLA_SAVE_TENSORS_FILE dump into a list of GraphInfo records.

  NOTE(review): the name looks like a typo for parse_graphs; kept as-is
  because callers in this file use it.

  Args:
    gfile: iterable of text lines (an open file object or sys.stdin).
    dest_dir: if truthy, a directory to create; each raw graph is also saved
      there as 'graph_NNNN'. Raises RuntimeError if the directory exists.
    graphs: optional list to append to, allowing accumulation across files.

  Returns:
    The list of GraphInfo records (the *graphs* argument if it was given).
  """
  if dest_dir:
    if os.path.isdir(dest_dir):
      raise RuntimeError('Folder already exists: {}'.format(dest_dir))
    os.mkdir(dest_dir)

  if graphs is None:
    graphs = []
  # Line-driven state machine:
  #   frame is not None -> collecting frame lines until a blank line;
  #   graph is not None -> collecting IR lines until the closing '}';
  #   otherwise         -> scanning for 'TensorsGraphInfo:' or 'IR {'.
  graph, frame, last_frame = None, None, None
  for line in gfile:
    line = line.rstrip('\n')
    if frame is not None:
      if re.match(r'\s*$', line):
        # Blank line terminates the frame; remember it for the next graph.
        last_frame = frame
        frame = None
      else:
        frame.append(line)
    elif graph is not None:
      graph.append(line)
      m = re.match(r'}\s*$', line)
      if m:
        # Closing brace ends the IR body; emit a record paired with the
        # most recently seen frame (which may be None).
        if dest_dir:
          save_graph(graph,
                     os.path.join(dest_dir, 'graph_{:04d}'.format(len(graphs))))
        graphs.append(
            GraphInfo(
                id=len(graphs),
                graph=graph,
                ngraph=normalize(graph),
                frame=last_frame))
        graph = None
        last_frame = None
    else:
      m = re.match(r'TensorsGraphInfo:', line)
      if m:
        frame = []
      else:
        m = re.match(r'IR {\s*', line)
        if m:
          graph = [line]
  return graphs


def group_by_frame(graphs):
  """Group GraphInfo records whose Python frames are textually identical.

  Returns a dict mapping the newline-joined frame text to the list of
  graphs generated from that frame, in input order.
  """
  by_frame = collections.defaultdict(list)
  for info in graphs:
    key = '\n'.join(info.frame)
    by_frame[key].append(info)
  return by_frame


def set_add(s, i):
  """Add *i* to set *s*; return True iff it was not already present."""
  if i in s:
    return False
  s.add(i)
  return True


def diff_graphs(g1, g2, name1, name2, prefix=''):
  """Return a unified diff of the two graphs' normalized IR as one string.

  Every diff line is prepended with *prefix*; lines difflib emits without a
  trailing newline (the content lines, since ngraph entries carry none) get
  one appended so the result prints cleanly.
  """
  pieces = []
  for dline in difflib.unified_diff(g1.ngraph, g2.ngraph, name1, name2):
    terminator = '' if dline.endswith('\n') else '\n'
    pieces.append('{}{}{}'.format(prefix, dline, terminator))
  return ''.join(pieces)


def process_graphs(args):
  """Parse graph dumps from stdin or args.files and print per-frame stats.

  For every frame group, reports the total and unique graph counts, then a
  pairwise diff between consecutive unique graphs.

  NOTE(review): with --graphdir and more than one input file, the second
  prase_graphs call raises because the folder already exists — confirm
  whether that is intended.
  """
  graphs = []
  if args.files:
    for fname in args.files:
      with open(fname, 'r') as fd:
        prase_graphs(fd, args.graphdir, graphs=graphs)
  else:
    graphs = prase_graphs(sys.stdin, args.graphdir)
  print('Parsed {} graph(s)'.format(len(graphs)))
  fgroup = group_by_frame(graphs)
  print('{} frame group(s)'.format(len(fgroup)))
  for frame_text, fgraphs in fgroup.items():
    seen = set()
    # Keep only the first occurrence of each distinct raw graph.
    uniq_graphs = [g for g in fgraphs if set_add(seen, '\n'.join(g.graph))]
    print('Frame has {} graph(s) ({} unique):\n{}\n'.format(
        len(fgraphs), len(seen), frame_text))
    for i in range(len(uniq_graphs) - 1, 0, -1):
      prev, cur = uniq_graphs[i - 1], uniq_graphs[i]
      print(' Frame {} (len={}) vs {} (len={})'.format(
          i - 1, len(prev.graph), i, len(cur.graph)))
      print(
          diff_graphs(
              prev,
              cur,
              'frame-{}'.format(i - 1),
              'frame-{}'.format(i),
              prefix='    '))


if __name__ == '__main__':
  # Any arguments not recognized by the parser are treated as input files.
  parser = argparse.ArgumentParser()
  parser.add_argument('--graphdir', type=str)
  parsed, extra = parser.parse_known_args()
  parsed.files = extra
  process_graphs(parsed)
97 changes: 97 additions & 0 deletions scripts/grab_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#!/usr/bin/env python
# Given a log file in which the XLA metrics report has been dumped, extracts the
# different metrics across multiple points and produces data in a format which
# can be graphed.
# Can also produce data which is a combination of other metric, using the
# --synth parameters:
#
# --synth 'LiveDataHandles:CreateDataHandles - DestroyDataHandles'
#

from __future__ import print_function

import argparse
import collections
import re
import sys


def parse_metrics(lines):
  """Collect counter values from a metrics report.

  Expected layout:
    Counter: CreateCompileHandles
      Value: 1631

  Returns a dict mapping counter name -> list of value strings, in order
  of appearance (one entry per report snapshot in the log).
  """
  counter_re = re.compile(r'Counter: ([^\s]+)')
  value_re = re.compile(r'\s*Value: ([^\s]+)')
  metrics = collections.defaultdict(list)
  pending = None
  for line in lines:
    if pending is None:
      m = counter_re.match(line)
      if m:
        pending = m.group(1)
    else:
      # Keep scanning until the counter's Value line shows up.
      m = value_re.match(line)
      if m:
        metrics[pending].append(m.group(1))
        pending = None
  return metrics


def create_metric_report(args, metric, metric_data):
  """Print one metric as '[name]' followed by '<index>\\t<value>' rows."""
  print('[{}]'.format(metric))
  for index, value in enumerate(metric_data):
    print(index, value, sep='\t')


def process_synth(args, synth, metrics):
  """Print a synthesized metric computed from other metrics.

  *synth* has the form 'Name:EXPR' where EXPR is a Python expression over
  metric names, e.g. 'LiveDataHandles:CreateDataHandles - DestroyDataHandles'.
  One '<index>\\t<value>' row is printed for each index present in every
  referenced metric.

  Raises:
    RuntimeError: if EXPR references a name not present in *metrics*.
  """
  name, expr = synth.split(':', 1)
  # Every identifier in the expression must name a known metric.
  xvars = sorted({m.group(0)
                  for m in re.finditer(r'[a-zA-Z_][a-zA-Z_0-9]*', expr)})
  xmetrics = []
  for v in xvars:
    metric_data = metrics.get(v, None)
    if metric_data is None:
      raise RuntimeError('Unknown metric: {}'.format(v))
    xmetrics.append(metric_data)
  print('[{}]'.format(name))
  if not xmetrics:
    # A constant expression references no metric series, so there is nothing
    # to iterate (the previous while-True loop never terminated here).
    return
  # Only indices present in every referenced metric can be evaluated.
  npoints = min(len(data) for data in xmetrics)
  for x in range(npoints):
    env = {v: float(xmetrics[i][x]) for i, v in enumerate(xvars)}
    # SECURITY: eval() of the user-supplied --synth expression. Acceptable
    # only because it comes from the local command line, never from
    # untrusted input.
    y = eval(expr, env)
    print('{}\t{}'.format(x, y))


def create_report(args, metrics):
  """Print the requested metric (or all metrics), then any --synth series.

  Raises:
    RuntimeError: if args.metric names a metric not present in *metrics*.
  """
  if args.metric:
    metric_data = metrics.get(args.metric, None)
    if metric_data is None:
      raise RuntimeError('Unknown metric: {}'.format(args.metric))
    create_metric_report(args, args.metric, metric_data)
  else:
    for metric, metric_data in metrics.items():
      create_metric_report(args, metric, metric_data)
  # args.synth is None when no --synth flag is given (argparse 'append'
  # default); iterating None raised TypeError before this guard.
  for synth in args.synth or ():
    process_synth(args, synth, metrics)


def process_metrics(args):
  """Parse the metrics report from --input (or stdin) and print the report."""
  if args.input is None:
    metrics = parse_metrics(sys.stdin)
  else:
    # Context manager ensures the file is closed even if parsing fails
    # (the previous open() leaked the descriptor).
    with open(args.input, 'r') as fd:
      metrics = parse_metrics(fd)
  create_report(args, metrics)


if __name__ == '__main__':
  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument('--input', type=str)
  arg_parser.add_argument('--metric', type=str)
  # default=[] so downstream code can iterate args.synth even when no
  # --synth flag is passed (argparse's default for 'append' is None).
  arg_parser.add_argument('--synth', action='append', type=str, default=[])
  # Unrecognized arguments are collected as args.files (currently unused).
  args, files = arg_parser.parse_known_args()
  args.files = files
  process_metrics(args)