From 614673e6d8261e6bc0484d9963c99d7f1878c46b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?xavier=20dupr=C3=A9?=
Date: Tue, 16 Apr 2019 17:32:08 +0200
Subject: [PATCH] refactoring, add parameter number

---
 _unittests/ut_benchmark/test_bench_helper.py  | 19 +---------
 .../ut_benchmark/test_benchmark_perf.py       | 38 ++++++-------------
 src/pymlbenchmark/benchmark/benchmark_perf.py | 30 +++++++++++----
 3 files changed, 35 insertions(+), 52 deletions(-)

diff --git a/_unittests/ut_benchmark/test_bench_helper.py b/_unittests/ut_benchmark/test_bench_helper.py
index 5fc5e36..5cb1ec7 100644
--- a/_unittests/ut_benchmark/test_bench_helper.py
+++ b/_unittests/ut_benchmark/test_bench_helper.py
@@ -2,28 +2,11 @@
 """
 @brief test log(time=2s)
 """
-import sys
 import os
 import unittest
 import pandas
 from pyquickhelper.pycode import ExtTestCase
-
-
-try:
-    import src
-except ImportError:
-    path = os.path.normpath(
-        os.path.abspath(
-            os.path.join(
-                os.path.split(__file__)[0],
-                "..",
-                "..")))
-    if path not in sys.path:
-        sys.path.append(path)
-    import src
-
-
-from src.pymlbenchmark.benchmark.bench_helper import bench_pivot
+from pymlbenchmark.benchmark.bench_helper import bench_pivot
 
 
 class TestBenchHelper(ExtTestCase):
diff --git a/_unittests/ut_benchmark/test_benchmark_perf.py b/_unittests/ut_benchmark/test_benchmark_perf.py
index 63c6909..4b53e89 100644
--- a/_unittests/ut_benchmark/test_benchmark_perf.py
+++ b/_unittests/ut_benchmark/test_benchmark_perf.py
@@ -4,30 +4,13 @@
 """
 import io
 import contextlib
-import sys
 import os
 import unittest
 import pickle
 import numpy
 from pyquickhelper.pycode import ExtTestCase, get_temp_folder
-
-
-try:
-    import src
-except ImportError:
-    path = os.path.normpath(
-        os.path.abspath(
-            os.path.join(
-                os.path.split(__file__)[0],
-                "..",
-                "..")))
-    if path not in sys.path:
-        sys.path.append(path)
-    import src
-
-
-from src.pymlbenchmark.benchmark import BenchPerf, BenchPerfTest
-from src.pymlbenchmark.datasets import random_binary_classification
+from pymlbenchmark.benchmark import BenchPerf, BenchPerfTest
+from pymlbenchmark.datasets import random_binary_classification
 
 
 class TestBenchPerf(ExtTestCase):
@@ -167,14 +150,15 @@ def validate(self, results, **kwargs):
             pafter = dict(method=["predict", "predict_proba"],
                           N=[1, 10])
             bp = BenchPerf(pbefore, pafter, myBenchPerfTest)
-            list(bp.enumerate_run_benchs())
-            name = os.path.join(temp, "BENCH-ERROR-myBenchPerfTest-0.pk")
-            with open(name, 'rb') as f:
-                content = pickle.load(f)
-            self.assertIsInstance(content, dict)
-            self.assertIn('msg', content)
-            self.assertIn('data', content)
-            self.assertIsInstance(content['data'], dict)
+            for number in [1, 2]:
+                list(bp.enumerate_run_benchs(repeat=5, number=number))
+                name = os.path.join(temp, "BENCH-ERROR-myBenchPerfTest-0.pk")
+                with open(name, 'rb') as f:
+                    content = pickle.load(f)
+                self.assertIsInstance(content, dict)
+                self.assertIn('msg', content)
+                self.assertIn('data', content)
+                self.assertIsInstance(content['data'], dict)
 
 
 if __name__ == "__main__":
diff --git a/src/pymlbenchmark/benchmark/benchmark_perf.py b/src/pymlbenchmark/benchmark/benchmark_perf.py
index 91c9e5c..6e1aa0b 100644
--- a/src/pymlbenchmark/benchmark/benchmark_perf.py
+++ b/src/pymlbenchmark/benchmark/benchmark_perf.py
@@ -136,7 +136,8 @@ def enumerate_tests(self, options):
             yield row
 
     def enumerate_run_benchs(self, repeat=10, verbose=False,
-                             stop_if_error=True, validate=True):
+                             stop_if_error=True, validate=True,
+                             number=1):
         """
         Runs the benchmark.
 
@@ -146,6 +147,8 @@ def enumerate_run_benchs(self, repeat=10, verbose=False,
         @param      stop_if_error   by default, it stops when method *validate* fails,
                                     if False, the function stores the exception
         @param      validate        compare the outputs against the baseline
+        @param      number          number of times to call the same function,
+                                    the measured time then covers this number of calls
         @return                     yields dictionaries with all the metrics
         """
         all_opts = self.pbefore.copy()
@@ -180,6 +183,7 @@ def enumerate_run_benchs(self, repeat=10, verbose=False,
                 raise ValueError(
                     "Method *data* must return a list or a tuple.")
             obs["repeat"] = len(data)
+            obs["number"] = number
 
             results = []
             stores = []
@@ -200,15 +204,27 @@ def enumerate_run_benchs(self, repeat=10, verbose=False,
                     f1, f2 = f
                     for dt in data:
                         dt2 = f1(*dt)
-                        st = time_perf()
-                        r = f2(*dt2)
-                        d = time_perf() - st
+                        if number == 1:
+                            st = time_perf()
+                            r = f2(*dt2)
+                            d = time_perf() - st
+                        else:
+                            st = time_perf()
+                            for _ in range(number):
+                                r = f2(*dt2)
+                            d = time_perf() - st
                         times.append(d)
                 else:
                     for dt in data:
-                        st = time_perf()
-                        r = f(*dt)
-                        d = time_perf() - st
+                        if number == 1:
+                            st = time_perf()
+                            r = f(*dt)
+                            d = time_perf() - st
+                        else:
+                            st = time_perf()
+                            for _ in range(number):
+                                r = f(*dt)
+                            d = time_perf() - st
                         times.append(d)
                 results.append((fct, r))
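
Usage note (not part of the patch): a minimal sketch of a call using the new
*number* argument. It assumes a BenchPerfTest subclass named MyTest, similar to
the myBenchPerfTest class defined in the unit test above; MyTest and the two
parameter grids below are placeholders, not taken from this changeset.

    from pymlbenchmark.benchmark import BenchPerf

    # two parameter grids, following the pattern of the unit test above
    pbefore = dict(dim=[5, 10])
    pafter = dict(N=[1, 10])
    bp = BenchPerf(pbefore, pafter, MyTest)  # MyTest: hypothetical BenchPerfTest subclass

    # number=5 makes every measured duration cover 5 consecutive calls of the
    # tested function; the observation dictionary built for each configuration
    # now also stores this value under the key "number".
    for obs in bp.enumerate_run_benchs(repeat=10, number=5):
        print(obs)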