
Commit 6423912

Gather more summary data when running result_analyzer.py. (#6067)
1 parent a80c1e7 commit 6423912

File tree: 5 files changed, +253 -35 lines changed

  .circleci/common.sh
  benchmarks/result_analyzer.py
  benchmarks/test/example.json
  benchmarks/test/run_tests.sh
  benchmarks/test/test_result_analyzer.py

.circleci/common.sh

Lines changed: 5 additions & 0 deletions

@@ -150,6 +150,11 @@ function run_torch_xla_python_tests() {
       # echo "Running MNIST Test"
       # python test/test_train_mp_mnist_amp.py --fake_data --num_epochs=1
     fi
+  elif [[ "$RUN_XLA_OP_TESTS1" == "xla_op1" ]]; then
+    # Benchmark tests.
+    # Only run on CPU, for xla_op1.
+    echo "Running Benchmark tests."
+    ./benchmarks/test/run_tests.sh
   fi
   fi
   popd

benchmarks/result_analyzer.py

Lines changed: 26 additions & 35 deletions

@@ -74,41 +74,32 @@ def run_csv(self):
     self.export_metric_report(metric_df)

   def get_calculated_metrics(self, d, dataline):
-    total_time = np.asarray(dataline["metrics"]["total_time"], dtype="float")
-    d["median_total_time"] = np.median(total_time)
-    per_iter_time = np.asarray(
-        dataline["metrics"]["per_iter_time"], dtype="float")
-    d["median_per_iter_time"] = np.median(per_iter_time)
-    if dataline["experiment"]["xla"]:
-      trace_per_iter_time = np.asarray(
-          dataline["metrics"]["trace_per_iter_time"], dtype="float")
-      d["xla_median_trace_per_iter_time"] = np.median(trace_per_iter_time)
-      d["xla_compile_time"] = np.max(total_time) - np.median(total_time)
-    else:
-      d["xla_median_trace_per_iter_time"] = -1
-      d["xla_compile_time"] = -1
-
-    if "total_cpu_time_s" in dataline["metrics"]:
-      total_cpu_time = np.asarray(
-          dataline["metrics"]["total_cpu_time_s"], dtype="float")
-      d["median_total_cpu_time_s"] = np.median(total_cpu_time)
-    if "per_iter_cpu_time_s" in dataline["metrics"]:
-      per_iter_cpu_time = np.asarray(
-          dataline["metrics"]["per_iter_cpu_time_s"], dtype="float")
-      d["median_per_iter_cpu_time_s"] = np.median(per_iter_cpu_time)
-    if "total_cuda_time_s" in dataline["metrics"]:
-      total_cuda_time = np.asarray(
-          dataline["metrics"]["total_cuda_time_s"], dtype="float")
-      d["median_total_cuda_time_s"] = np.median(total_cuda_time)
-    if "per_iter_cuda_time_s" in dataline["metrics"]:
-      per_iter_cuda_time = np.asarray(
-          dataline["metrics"]["per_iter_cuda_time_s"], dtype="float")
-      d["median_per_iter_cuda_time_s"] = np.median(per_iter_cuda_time)
-
-    if dataline["experiment"]["dynamo"]:
-      d["dynamo_compile_time"] = np.max(total_time) - np.median(total_time)
-    else:
-      d["dynamo_compile_time"] = -1
+    MAX_TOTAL_TIME = f"{np.max.__name__}_total_time"
+    MEDIAN_TOTAL_TIME = f"{np.median.__name__}_total_time"
+
+    for metric, raw_values in dataline["metrics"].items():
+      values = np.asarray(raw_values, dtype="float")
+
+      is_valid = (
+          dataline["experiment"]["xla"] or metric != "trace_per_iter_time")
+
+      for fn in (np.min, np.median, np.max):
+        d[f"{fn.__name__}_{metric}"] = fn(values) if is_valid else -1
+
+      # Remove first measurement.
+      # Assumption: the first measurement has tracing + compilation times
+      # embedded into it. Therefore, we remove it from our data for computing
+      # the average and standard deviation.
+      skip_head = values[1:]
+
+      if len(skip_head) > 0:
+        for fn in (np.mean, np.std):
+          d[f"{fn.__name__}_{metric}"] = fn(skip_head) if is_valid else -1
+
+    compile_time = d[MAX_TOTAL_TIME] - d[MEDIAN_TOTAL_TIME]
+    d["dynamo_compile_time"] = compile_time if dataline["experiment"][
+        "dynamo"] else -1
+    d["xla_compile_time"] = compile_time if dataline["experiment"]["xla"] else -1
     return d

   # TODO: handle error message properly (database length restriction)
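
For reference, the summary scheme introduced above, as a minimal standalone sketch (not part of the commit): every metric now gets min/median/max over the whole series, plus mean/std with the first, compile-heavy measurement dropped, and the keys are derived from the NumPy function names. The values below are trimmed from the example data; only NumPy is assumed.

import numpy as np

metrics = {"total_time": [120.46, 0.083, 0.075, 0.073]}  # trimmed example values
d = {}
for metric, raw_values in metrics.items():
  values = np.asarray(raw_values, dtype="float")
  # Whole-series summaries, keyed by NumPy function name, e.g. "median_total_time".
  for fn in (np.min, np.median, np.max):
    d[f"{fn.__name__}_{metric}"] = fn(values)
  # Mean/std skip the first measurement, which embeds tracing and compilation time.
  skip_head = values[1:]
  if len(skip_head) > 0:
    for fn in (np.mean, np.std):
      d[f"{fn.__name__}_{metric}"] = fn(skip_head)
print(d)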

benchmarks/test/example.json

Lines changed: 106 additions & 0 deletions

@@ -0,0 +1,106 @@
+{
+  "model": {
+    "suite_name": "torchbench",
+    "model_name": "DALLE2_pytorch"
+  },
+  "experiment": {
+    "experiment_name": "run_all",
+    "accelerator": "cuda",
+    "accelerator_model": "One of NVIDIA GeForce RTX 2060, NVIDIA GeForce RTX 2060",
+    "xla": "PJRT",
+    "xla_flags": null,
+    "dynamo": "openxla",
+    "test": "eval",
+    "batch_size": 1
+  },
+  "repeat": 10,
+  "iterations_per_run": 1,
+  "metrics": {
+    "total_cpu_time_s": [
+      81.853362,
+      0.065951,
+      0.056186,
+      0.055567,
+      0.055391,
+      0.055835,
+      0.055767,
+      0.058623,
+      0.055612,
+      0.058594
+    ],
+    "total_cuda_time_s": [
+      81.852574,
+      0.065956,
+      0.056192,
+      0.055573,
+      0.055396,
+      0.055841,
+      0.055773,
+      0.058629,
+      0.055617,
+      0.0586
+    ],
+    "per_iter_cpu_time_s": [
+      81.853362,
+      0.065951,
+      0.056186,
+      0.055567,
+      0.055391,
+      0.055835,
+      0.055767,
+      0.058623,
+      0.055612,
+      0.058594
+    ],
+    "per_iter_cuda_time_s": [
+      81.852574,
+      0.065956,
+      0.056192,
+      0.055573,
+      0.055396,
+      0.055841,
+      0.055773,
+      0.058629,
+      0.055617,
+      0.0586
+    ],
+    "total_time": [
+      120.4606251809746,
+      0.08297968655824661,
+      0.0747979823499918,
+      0.07257041148841381,
+      0.0746086947619915,
+      0.07293416373431683,
+      0.07472928613424301,
+      0.07585464790463448,
+      0.07447021268308163,
+      0.07592942006886005
+    ],
+    "per_iter_time": [
+      120.4606251809746,
+      0.08297968655824661,
+      0.0747979823499918,
+      0.07257041148841381,
+      0.0746086947619915,
+      0.07293416373431683,
+      0.07472928613424301,
+      0.07585464790463448,
+      0.07447021268308163,
+      0.07592942006886005
+    ],
+    "trace_per_iter_time": [
+      81.8553378880024,
+      0.06609796732664108,
+      0.05630519799888134,
+      0.05567726120352745,
+      0.055552588775753975,
+      0.05595448426902294,
+      0.05587966553866863,
+      0.058729616925120354,
+      0.05571739934384823,
+      0.058703068643808365
+    ],
+    "single_value": [1]
+  },
+  "outputs_file": null
+}
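
As a rough worked example (not part of the commit): the analyzer estimates compile time as max(total_time) minus median(total_time), and in the data above only the first measurement includes compilation, so the estimate is about 120.4 seconds. A quick check, assuming it is run from the repository root:

import json
import numpy as np

# Mirror the compile_time = max - median rule from get_calculated_metrics
# against the example data above.
with open("benchmarks/test/example.json") as f:
  total_time = np.asarray(json.load(f)["metrics"]["total_time"], dtype="float")
print(np.max(total_time) - np.median(total_time))  # roughly 120.39 seconds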

benchmarks/test/run_tests.sh

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+BASEDIR="$(dirname $(dirname $(realpath $0)))"
+export PYTHONPATH="$BASEDIR"
+
+function run_test {
+  pushd "$BASEDIR"
+  python3 "$@"
+  popd
+}
+
+if [[ "$RUN_XLA_OP_TESTS1" == "xla_op1" ]]; then
+  run_test test/test_result_analyzer.py
+fi
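
For context, the hook added to .circleci/common.sh above runs this script under the xla_op1 shard. A roughly equivalent local invocation (an assumption, with the repository root as the working directory) would be:

RUN_XLA_OP_TESTS1=xla_op1 ./benchmarks/test/run_tests.sh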
benchmarks/test/test_result_analyzer.py

Lines changed: 104 additions & 0 deletions

@@ -0,0 +1,104 @@
+import argparse
+import functools
+import numpy
+import unittest
+import os
+
+from result_analyzer import ResultAnalyzer, parse_args
+
+fns_whole = (numpy.min, numpy.median, numpy.max)
+fns_skip_head = (numpy.mean, numpy.std)
+fns_all = fns_whole + fns_skip_head
+
+
+def apply(fn, data):
+  return fn(data)
+
+
+def apply_skip_head(fn, data):
+  return fn(data[1:])
+
+
+@functools.cache
+def get_dirname():
+  return os.path.dirname(__file__)
+
+
+@functools.cache
+def get_dataline():
+  import json
+  example_json = os.path.join(get_dirname(), "example.json")
+  with open(example_json, "r") as f:
+    return json.load(f)
+
+
+class TestResultAnalyzer(unittest.TestCase):
+
+  def _key(self, fn, metric):
+    return f"{fn.__name__}_{metric}"
+
+  def _check(self, dataline, output, fns, metric, output_value_fn):
+    for fn in fns:
+      key = self._key(fn, metric)
+      self.assertIn(key, output)
+      self.assertEqual(output[key],
+                       output_value_fn(fn, dataline["metrics"][metric]))
+
+  def _test_calculate_metrics(self, xla, dynamo):
+    dataline = get_dataline()
+    dataline["experiment"]["xla"] = xla
+    dataline["experiment"]["dynamo"] = dynamo
+
+    r = ResultAnalyzer(parse_args(["--output-dirname", get_dirname()]))
+    output = r.get_calculated_metrics({}, dataline)
+
+    # Check that output has data for each metric, summarized by
+    # each of its corresponding summary functions.
+
+    # - metrics with more than one measurement
+    for metric in ("total_cpu_time_s", "total_cuda_time_s",
+                   "per_iter_cpu_time_s", "per_iter_cuda_time_s", "total_time",
+                   "per_iter_time"):
+      self._check(dataline, output, fns_whole, metric, apply)
+      self._check(dataline, output, fns_skip_head, metric, apply_skip_head)
+
+    # - single_value: since it has only one value, we only check it for
+    #   the fns_whole set of statistical functions
+    self._check(dataline, output, fns_whole, "single_value", apply)
+
+    # Check that there is no mean and std for single-valued timings.
+    for fn in fns_skip_head:
+      self.assertNotIn(self._key(fn, "single_value"), output)
+
+    return output, dataline
+
+  def test_calculate_metrics_inductor(self):
+    output, _ = self._test_calculate_metrics(xla=None, dynamo="inductor")
+
+    # There should be a dynamo_compile_time key, even if it's not an XLA run.
+    self.assertIn("dynamo_compile_time", output)
+
+    # All trace_per_iter_time summary data inside the output
+    # should be -1.
+    for fn in fns_all:
+      k = self._key(fn, "trace_per_iter_time")
+
+      # It's ok not to have it in the output, since it's not XLA data anyway.
+      if k in output:
+        self.assertEqual(output[k], -1)
+
+  def test_calculate_metrics_xla(self):
+    output, dataline = self._test_calculate_metrics(
+        xla="PJRT", dynamo="openxla")
+
+    # There should be an xla_compile_time key.
+    self.assertIn("xla_compile_time", output)
+
+    # The trace_per_iter_time summary data should be populated.
+    self._check(dataline, output, fns_whole, "trace_per_iter_time", apply)
+    self._check(dataline, output, fns_skip_head, "trace_per_iter_time",
+                apply_skip_head)
+
+
+if __name__ == "__main__":
+  unittest.main()
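
The test can also be run on its own; a manual invocation roughly equivalent to what run_tests.sh does (an assumption, starting from the repository root) would be:

cd benchmarks
PYTHONPATH="$PWD" python3 test/test_result_analyzer.py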
