diff --git a/test/bench.py b/test/bench.py index 00b1c3767b43..08d6ce314b8e 100644 --- a/test/bench.py +++ b/test/bench.py @@ -12,6 +12,7 @@ import argparse import inspect import re +import time import torch import torch.nn as nn import torch.optim as optim @@ -23,42 +24,82 @@ import torch_xla_py.xla_model as xm -def _use_result(*args): - for v in args: - v.cpu() +class BaseBench(object): + def __init__(self): + self.device = xm.xla_device() + self.test_time = xu.getenv_as('BENCH_TEST_TIME', float, 5.0) + torch.manual_seed(42) -def bench_add_mul_div(args): - device = xm.xla_device() - a = torch.rand(8, 8) - b = torch.rand(8, 8).abs() + 1.0 - xla_a = a.to(device) - xla_b = b.to(device) - for i in range(0, xu.getenv_as('ADD_MUL_DIV_LOOPS', int, 1000)): - xla_c = xla_a * xla_b - xla_a / xla_b - _use_result(xla_c) - xu.get_print_fn()(torch_xla._XLAC._xla_metrics_report()) - - -def bench_add_mul_div_transfer(args): - device = xm.xla_device() - size = xu.getenv_as('ADD_MUL_DIV_SIZE', int, 100) - a = torch.rand(size, size) - b = torch.rand(size, size).abs() + 1.0 - for i in range(0, xu.getenv_as('ADD_MUL_DIV_LOOPS', int, 1000)): - xla_a = a.to(device) - xla_b = b.to(device) + def _get_parent_class(self): + return inspect.getmro(self.__class__)[0] + + def setup(self, args): + pass + + def bench(self): + raise RuntimeError('Not implemented') + + def use_results(self, results): + for v in results: + v.cpu() + + def run(self, args): + bench_name = self._get_parent_class().__name__ + try: + self.setup(args) + # Do one warmup run. + self.bench() + except Exception as e: + xu.eprint('Failed running benchmark "{}": {}'.format(bench_name, e)) + return + try: + start = time.time() + now = start + count = 0 + while self.test_time > (now - start): + self.bench() + count += 1 + now = time.time() + print('{}: {:.3f}ms per loop'.format(bench_name, + 1000.0 * (now - start) / count)) + xu.get_print_fn()(torch_xla._XLAC._xla_metrics_report()) + except Exception as e: + xu.eprint('Failed running benchmark "{}": {}'.format(bench_name, e)) + + +class BenchAddMulDiv(BaseBench): + + def setup(self, args): + self.a = torch.rand(8, 8) + self.b = torch.rand(8, 8).abs() + 1.0 + self.xla_a = self.a.to(self.device) + self.xla_b = self.b.to(self.device) + + def bench(self): + xla_c = self.xla_a * self.xla_b - self.xla_a / self.xla_b + self.use_results([xla_c]) + + +class BenchAddMulDivTransfer(BaseBench): + + def setup(self, args): + self.size = xu.getenv_as('ADD_MUL_DIV_SIZE', int, 100) + self.a = torch.rand(self.size, self.size) + self.b = torch.rand(self.size, self.size).abs() + 1.0 + + def bench(self): + xla_a = self.a.to(self.device) + xla_b = self.b.to(self.device) xla_c = xla_a * xla_b - xla_a / xla_b - _use_result(xla_c) - xu.get_print_fn()(torch_xla._XLAC._xla_metrics_report()) + self.use_results([xla_c]) def run_benchmarks(args): benchs = {} - for name, func in inspect.getmembers(sys.modules[__name__], - inspect.isfunction): - if re.match(r'bench_', name): - benchs[name] = func + for name, cls in inspect.getmembers(sys.modules[__name__], inspect.isclass): + if re.match(r'Bench', name): + benchs[name] = cls if args.benchs: run_benchs = [] bench_keys = benchs.keys() @@ -71,11 +112,8 @@ def run_benchmarks(args): else: run_benchs = benchs.keys() for name in sorted(run_benchs): - with xu.TimedScope(msg='Benchmark "{}": '.format(name)): - try: - benchs[name](args) - except Exception as e: - print('Failed running benchmark "{}": {}'.format(name, e)) + bench = benchs[name]() + bench.run(args) if __name__ == '__main__': @@ -84,7 +122,6 @@ def run_benchmarks(args): args.benchs = benchs torch.set_default_tensor_type('torch.FloatTensor') - torch.manual_seed(42) torch_xla._XLAC._xla_set_use_full_mat_mul_precision( use_full_mat_mul_precision=True) run_benchmarks(args) diff --git a/torch_xla_py/utils.py b/torch_xla_py/utils.py index 5c8580daa01a..6badebcb728c 100644 --- a/torch_xla_py/utils.py +++ b/torch_xla_py/utils.py @@ -150,11 +150,16 @@ def __init__(self, msg='', printfn=eprint): printfn = get_print_fn() self._msg = msg self._printfn = printfn + self._error = None def __enter__(self): self._start = time.time() return self def __exit__(self, type, value, traceback): - self._printfn('{}{:.3f}ms'.format(self._msg, - 1000.0 * (time.time() - self._start))) + if self._error is None: + self._printfn('{}{:.3f}ms'.format(self._msg, + 1000.0 * (time.time() - self._start))) + + def set_error(self, error): + self._error = error