Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 72 additions & 35 deletions test/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import argparse
import inspect
import re
import time
import torch
import torch.nn as nn
import torch.optim as optim
Expand All @@ -23,42 +24,82 @@
import torch_xla_py.xla_model as xm


def _use_result(*args):
for v in args:
v.cpu()
class BaseBench(object):

def __init__(self):
self.device = xm.xla_device()
self.test_time = xu.getenv_as('BENCH_TEST_TIME', float, 5.0)
torch.manual_seed(42)

def bench_add_mul_div(args):
device = xm.xla_device()
a = torch.rand(8, 8)
b = torch.rand(8, 8).abs() + 1.0
xla_a = a.to(device)
xla_b = b.to(device)
for i in range(0, xu.getenv_as('ADD_MUL_DIV_LOOPS', int, 1000)):
xla_c = xla_a * xla_b - xla_a / xla_b
_use_result(xla_c)
xu.get_print_fn()(torch_xla._XLAC._xla_metrics_report())


def bench_add_mul_div_transfer(args):
device = xm.xla_device()
size = xu.getenv_as('ADD_MUL_DIV_SIZE', int, 100)
a = torch.rand(size, size)
b = torch.rand(size, size).abs() + 1.0
for i in range(0, xu.getenv_as('ADD_MUL_DIV_LOOPS', int, 1000)):
xla_a = a.to(device)
xla_b = b.to(device)
def _get_parent_class(self):
return inspect.getmro(self.__class__)[0]

def setup(self, args):
pass

def bench(self):
raise RuntimeError('Not implemented')

def use_results(self, results):
for v in results:
v.cpu()

def run(self, args):
bench_name = self._get_parent_class().__name__
try:
self.setup(args)
# Do one warmup run.
self.bench()
except Exception as e:
xu.eprint('Failed running benchmark "{}": {}'.format(bench_name, e))
return
try:
start = time.time()
now = start
count = 0
while self.test_time > (now - start):
self.bench()
count += 1
now = time.time()
print('{}: {:.3f}ms per loop'.format(bench_name,
1000.0 * (now - start) / count))
xu.get_print_fn()(torch_xla._XLAC._xla_metrics_report())
except Exception as e:
xu.eprint('Failed running benchmark "{}": {}'.format(bench_name, e))


class BenchAddMulDiv(BaseBench):

def setup(self, args):
self.a = torch.rand(8, 8)
self.b = torch.rand(8, 8).abs() + 1.0
self.xla_a = self.a.to(self.device)
self.xla_b = self.b.to(self.device)

def bench(self):
xla_c = self.xla_a * self.xla_b - self.xla_a / self.xla_b
self.use_results([xla_c])


class BenchAddMulDivTransfer(BaseBench):

def setup(self, args):
self.size = xu.getenv_as('ADD_MUL_DIV_SIZE', int, 100)
self.a = torch.rand(self.size, self.size)
self.b = torch.rand(self.size, self.size).abs() + 1.0

def bench(self):
xla_a = self.a.to(self.device)
xla_b = self.b.to(self.device)
xla_c = xla_a * xla_b - xla_a / xla_b
_use_result(xla_c)
xu.get_print_fn()(torch_xla._XLAC._xla_metrics_report())
self.use_results([xla_c])


def run_benchmarks(args):
benchs = {}
for name, func in inspect.getmembers(sys.modules[__name__],
inspect.isfunction):
if re.match(r'bench_', name):
benchs[name] = func
for name, cls in inspect.getmembers(sys.modules[__name__], inspect.isclass):
if re.match(r'Bench', name):
benchs[name] = cls
if args.benchs:
run_benchs = []
bench_keys = benchs.keys()
Expand All @@ -71,11 +112,8 @@ def run_benchmarks(args):
else:
run_benchs = benchs.keys()
for name in sorted(run_benchs):
with xu.TimedScope(msg='Benchmark "{}": '.format(name)):
try:
benchs[name](args)
except Exception as e:
print('Failed running benchmark "{}": {}'.format(name, e))
bench = benchs[name]()
bench.run(args)


if __name__ == '__main__':
Expand All @@ -84,7 +122,6 @@ def run_benchmarks(args):
args.benchs = benchs

torch.set_default_tensor_type('torch.FloatTensor')
torch.manual_seed(42)
torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
use_full_mat_mul_precision=True)
run_benchmarks(args)
9 changes: 7 additions & 2 deletions torch_xla_py/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,16 @@ def __init__(self, msg='', printfn=eprint):
printfn = get_print_fn()
self._msg = msg
self._printfn = printfn
self._error = None

def __enter__(self):
self._start = time.time()
return self

def __exit__(self, type, value, traceback):
self._printfn('{}{:.3f}ms'.format(self._msg,
1000.0 * (time.time() - self._start)))
if self._error is None:
self._printfn('{}{:.3f}ms'.format(self._msg,
1000.0 * (time.time() - self._start)))

def set_error(self, error):
self._error = error