Navigation Menu

Skip to content
This repository has been archived by the owner on Jan 13, 2024. It is now read-only.

Commit

Permalink
better display while benchmarking
Browse files Browse the repository at this point in the history
  • Loading branch information
sdpython committed Jul 9, 2019
1 parent ca7c7ab commit fd9dd89
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 9 deletions.
4 changes: 2 additions & 2 deletions _unittests/ut_onnxrt/test_onnxrt_validate_benchmark.py
Expand Up @@ -6,7 +6,7 @@
from logging import getLogger
from pandas import DataFrame
from pyquickhelper.loghelper import fLOG
from pyquickhelper.pycode import get_temp_folder, ExtTestCase
from pyquickhelper.pycode import get_temp_folder, ExtTestCase, is_travis_or_appveyor
from sklearn.exceptions import ConvergenceWarning
from sklearn.utils.testing import ignore_warnings
from mlprodict.onnxrt.validate import enumerate_validated_operator_opsets, summary_report
Expand Down Expand Up @@ -57,7 +57,7 @@ def test_validate_sklearn_operators_benchmark_all(self):
verbose, opset_min=10, benchmark=True,
fLOG=fLOG, runtime="onnxruntime1"):
rows.append(row)
if len(rows) > 40:
if is_travis_or_appveyor() and len(rows) > 40:
break
self.assertGreater(len(rows), 1)
df = DataFrame(rows)
Expand Down
Expand Up @@ -6,7 +6,7 @@
from logging import getLogger
from pandas import DataFrame
from pyquickhelper.loghelper import fLOG
from pyquickhelper.pycode import get_temp_folder, ExtTestCase
from pyquickhelper.pycode import get_temp_folder, ExtTestCase, is_travis_or_appveyor
from pyquickhelper.texthelper.version_helper import compare_module_version
from sklearn.exceptions import ConvergenceWarning
from sklearn.utils.testing import ignore_warnings
Expand Down Expand Up @@ -111,7 +111,7 @@ def test_validate_sklearn_operators_all_onnxruntime(self):
for row in enumerate_validated_operator_opsets(verbose, debug=None, fLOG=fLOG,
runtime='onnxruntime2', dump_folder=temp):
rows.append(row)
if __name__ != "__main__" and len(rows) >= 30:
if is_travis_or_appveyor() and len(rows) > 30:
break

self.assertGreater(len(rows), 1)
Expand All @@ -127,6 +127,4 @@ def test_validate_sklearn_operators_all_onnxruntime(self):


if __name__ == "__main__":
TestOnnxrtValidateOnnxRuntime(
).test_validate_sklearn_operators_onnxruntime_AdaBoostRegressor()
unittest.main()
15 changes: 12 additions & 3 deletions mlprodict/onnxrt/validate.py
Expand Up @@ -576,10 +576,19 @@ def iterate():
loop = iterate()
else:
try:
from tqdm import tqdm
loop = tqdm(ops)
except ImportError:
from tqdm import trange

def iterate_tqdm():
with trange(len(ops)) as t:
for i in t:
row = ops[i]
disp = row['name'] + " " * (28 - len(row['name']))
t.set_description("%s" % disp)
yield row

loop = iterate_tqdm()

except ImportError:
loop = iterate()
else:
loop = ops
Expand Down

0 comments on commit fd9dd89

Please sign in to comment.