Skip to content
This repository has been archived by the owner on Jan 13, 2024. It is now read-only.

Commit

Permalink
Fixes issue #26, run tests in separate processes (but not parallelized)
Browse files Browse the repository at this point in the history
  • Loading branch information
sdpython committed Aug 1, 2019
1 parent 7dab4c9 commit 81536bf
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 12 deletions.
31 changes: 31 additions & 0 deletions _unittests/ut_cli/test_cli_validate_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
@brief test tree node (time=30s)
"""
import os
import unittest
from pyquickhelper.loghelper import BufferedPrint
from pyquickhelper.pycode import ExtTestCase, get_temp_folder
from mlprodict.__main__ import main


class TestCliValidateProcess(ExtTestCase):
    """Command-line test: runs ``validate_runtime`` with ``-se 1`` so that
    every model is validated in its own subprocess."""

    def test_cli_validate_model_process_csv(self):
        # Dedicated temporary folder for the two CSV outputs.
        temp = get_temp_folder(__file__, "temp_validate_model_process_csv")
        raw_name = os.path.join(temp, "raw.csv")
        sum_name = os.path.join(temp, "sum.csv")
        log = BufferedPrint()
        args = ["validate_runtime", "--out_raw", raw_name,
                "--out_summary", sum_name, "--models",
                "LogisticRegression,LinearRegression",
                '-o', '10', '-op', '10', '-v', '1', '-b', '1',
                '-se', '1']
        main(args=args, fLOG=log.fprint)
        captured = str(log)
        # Both models contain 'Linear' in their name; the log must mention it.
        self.assertIn('Linear', captured)
        # Both output files must have been written.
        self.assertExists(raw_name)
        self.assertExists(sum_name)


if __name__ == "__main__":
    unittest.main()
126 changes: 114 additions & 12 deletions mlprodict/cli/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import os
from logging import getLogger
import warnings
from multiprocessing import Pool
from pandas import DataFrame
from sklearn.exceptions import ConvergenceWarning
from ..onnxrt.validate import enumerate_validated_operator_opsets, summary_report # pylint: disable=E0402


def validate_runtime(verbose=1, opset_min=9, opset_max="",
Expand All @@ -17,7 +17,8 @@ def validate_runtime(verbose=1, opset_min=9, opset_max="",
dump_folder=None, benchmark=False,
catch_warnings=True, assume_finite=True,
versions=False, skip_models=None,
extended_list=True, fLOG=print):
extended_list=True, separate_process=False,
fLOG=print):
"""
Walks through most of :epkg:`scikit-learn` operators
or model or predictor or transformer, tries to convert
Expand Down Expand Up @@ -53,6 +54,9 @@ def validate_runtime(verbose=1, opset_min=9, opset_max="",
:epkg:`sklearn-onnx`
:param extended_list: extends the list of :epkg:`scikit-learn` converters
with converters implemented in this module
:param separate_process: run every model in a separate process,
        this option must be used to run all models in a row
        even if one of them crashes
:param fLOG: logging function
.. cmdref::
Expand All @@ -78,6 +82,20 @@ def validate_runtime(verbose=1, opset_min=9, opset_max="",
python -m mlprodict validate_runtime -v 1 -o 10 -op 10 -c 1 -r onnxruntime1
-m RandomForestRegressor,DecisionTreeRegressor -out bench_onnxruntime.xlsx -b 1
"""
if separate_process:
return _validate_runtime_separate_process(
verbose=verbose, opset_min=opset_min, opset_max=opset_max,
check_runtime=check_runtime, runtime=runtime, debug=debug,
models=models, out_raw=out_raw,
out_summary=out_summary,
dump_folder=dump_folder, benchmark=benchmark,
catch_warnings=catch_warnings, assume_finite=assume_finite,
versions=versions, skip_models=skip_models,
extended_list=extended_list,
fLOG=fLOG)

from ..onnxrt.validate import enumerate_validated_operator_opsets # pylint: disable=E0402

models = None if models in (None, "") else models.strip().split(',')
skip_models = {} if skip_models in (
None, "") else skip_models.strip().split(',')
Expand All @@ -99,24 +117,32 @@ def validate_runtime(verbose=1, opset_min=9, opset_max="",
if isinstance(extended_list, str):
extended_list = extended_list in ('1', 'True', 'true')

def build_rows():
def build_rows(models_):
rows = list(enumerate_validated_operator_opsets(
verbose, models=models, fLOG=fLOG, runtime=runtime, debug=debug,
verbose, models=models_, fLOG=fLOG, runtime=runtime, debug=debug,
dump_folder=dump_folder, opset_min=opset_min, opset_max=opset_max,
benchmark=benchmark, assume_finite=assume_finite, versions=versions,
extended_list=extended_list,
filter_exp=lambda m, s: str(m) not in skip_models))
return rows

if catch_warnings:
with warnings.catch_warnings():
warnings.simplefilter("ignore",
(UserWarning, ConvergenceWarning,
RuntimeWarning, FutureWarning))
rows = build_rows()
else:
rows = build_rows()
def catch_build_rows(models_):
if catch_warnings:
with warnings.catch_warnings():
warnings.simplefilter("ignore",
(UserWarning, ConvergenceWarning,
RuntimeWarning, FutureWarning))
rows = build_rows(models_)
else:
rows = build_rows(models_)
return rows

rows = catch_build_rows(models)
return _finalize(rows, out_raw, out_summary, verbose, models, fLOG)


def _finalize(rows, out_raw, out_summary, verbose, models, fLOG):
from ..onnxrt.validate import summary_report # pylint: disable=E0402
df = DataFrame(rows)
if os.path.splitext(out_raw)[-1] == ".xlsx":
df.to_excel(out_raw, index=False)
Expand All @@ -129,3 +155,79 @@ def build_rows():
piv.to_csv(out_summary, index=False)
if verbose > 0 and models is not None:
fLOG(piv.T)

# Drops data which cannot be serialized.
for row in rows:
keys = []
for k in row:
if 'lambda' in k:
keys.append(k)
for k in keys:
del row[k]
return rows


def _validate_runtime_dict(kwargs):
    """
    Picklable top-level wrapper used as the target of
    :class:`multiprocessing.Pool.map`: unpacks *kwargs* and forwards
    them to :func:`validate_runtime`.

    @param      kwargs      dictionary of keyword arguments for
                            :func:`validate_runtime`
    @return                 whatever :func:`validate_runtime` returns
                            (a list of result rows)
    """
    return validate_runtime(**kwargs)


def _validate_runtime_separate_process(**kwargs):
    """
    Runs :func:`validate_runtime` once per model, each run in its own
    subprocess (a one-worker :class:`multiprocessing.Pool`), so that a
    crash in one model does not abort the whole validation campaign.

    @param      kwargs      same keyword arguments as
                            :func:`validate_runtime` (must contain at
                            least ``models``, ``skip_models``,
                            ``verbose``, ``fLOG``, ``out_raw``,
                            ``out_summary``)
    @return                 output of ``_finalize`` on the accumulated
                            rows (crashed models are reported as rows
                            with scenario ``'CRASH'``)
    """
    models = kwargs['models']
    if models in (None, ""):
        # No explicit list: validate every supported scikit-learn operator.
        from ..onnxrt.validate_helper import sklearn_operators
        models = [_['name'] for _ in sklearn_operators(extended=True)]
    else:
        models = models.strip().split(',')

    skip_models = kwargs['skip_models']
    skip_models = {} if skip_models in (
        None, "") else skip_models.strip().split(',')

    verbose = kwargs['verbose']
    fLOG = kwargs['fLOG']
    all_rows = []
    skls = [m for m in models if m not in skip_models]
    skls.sort()

    if verbose > 0:
        from tqdm import tqdm
        pbar = tqdm(skls)
    else:
        pbar = skls

    for op in pbar:
        if not isinstance(pbar, list):
            # tqdm progress bar: show the current model, padded for alignment.
            pbar.set_description("[%s]" % (op + " " * (25 - len(op))))

        # Derive per-model output file names by inserting the model name
        # before the extension, e.g. raw.csv -> raw_LinearRegression.csv.
        if kwargs['out_raw']:
            out_raw = os.path.splitext(kwargs['out_raw'])
            out_raw = "".join([out_raw[0], "_", op, out_raw[1]])
        else:
            out_raw = None

        if kwargs['out_summary']:
            out_summary = os.path.splitext(kwargs['out_summary'])
            out_summary = "".join([out_summary[0], "_", op, out_summary[1]])
        else:
            out_summary = None

        new_kwargs = kwargs.copy()
        if 'fLOG' in new_kwargs:
            # fLOG may not be picklable; the subprocess uses the default.
            del new_kwargs['fLOG']
        new_kwargs['out_raw'] = out_raw
        new_kwargs['out_summary'] = out_summary
        new_kwargs['models'] = op
        new_kwargs['verbose'] = 0  # tqdm fails

        p = Pool(1)
        try:
            lrows = p.map(_validate_runtime_dict, [new_kwargs])
            all_rows.extend(lrows[0])
        except Exception as e:  # pylint: disable=W0703
            # The subprocess crashed (segfault, unpicklable result, ...):
            # record the failure as a row instead of stopping the loop.
            all_rows.append({
                'name': op, 'scenario': 'CRASH',
                'ERROR-msg': str(e).replace("\n", " -- ")
            })
        finally:
            # Always release the worker process; without this every model
            # leaked one live subprocess (and its pipes) until interpreter
            # exit.
            p.close()
            p.join()

    return _finalize(all_rows, kwargs['out_raw'], kwargs['out_summary'],
                     verbose, models, fLOG)

0 comments on commit 81536bf

Please sign in to comment.