This repository has been archived by the owner on Jan 13, 2024. It is now read-only.
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fixes #112, fix number of features for kmeans when validating the run…
…time
- Loading branch information
Showing
10 changed files
with
261 additions
and
53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
""" | ||
@brief test tree node (time=30s) | ||
""" | ||
import os | ||
import unittest | ||
import pandas | ||
from pyquickhelper.loghelper import BufferedPrint | ||
from pyquickhelper.pycode import ExtTestCase, get_temp_folder | ||
from mlprodict.onnxrt.validate.validate_summary import ( | ||
merge_benchmark, summary_report) | ||
from mlprodict.__main__ import main | ||
|
||
|
||
class TestCliValidateRuntime(ExtTestCase): | ||
|
||
def test_cli_validate_kmeans(self): | ||
temp = get_temp_folder(__file__, "temp_validate_runtime_kmeans") | ||
out1 = os.path.join(temp, "raw.csv") | ||
out2 = os.path.join(temp, "sum.csv") | ||
gr = os.path.join(temp, 'gr.png') | ||
st = BufferedPrint() | ||
main(args=["validate_runtime", "--n_features", "4,50", "-nu", "3", | ||
"-re", "3", "-o", "11", "-op", "11", "-v", "2", "--out_raw", | ||
out1, "--out_summary", out2, "-b", "1", | ||
"--runtime", "python_compiled,onnxruntime1", | ||
"--models", "KMeans", "--out_graph", gr, "--dtype", "32"], | ||
fLOG=st.fprint) | ||
res = str(st) | ||
self.assertIn('KMeans', res) | ||
self.assertExists(out1) | ||
self.assertExists(out2) | ||
self.assertExists(gr) | ||
df1 = pandas.read_csv(out1) | ||
merged = merge_benchmark({'r1-': df1, 'r2-': df1.copy()}, | ||
baseline='r1-onnxruntime1') | ||
add_cols = list( | ||
sorted(c for c in merged.columns if c.endswith('-base'))) | ||
suma = summary_report(merged, add_cols=add_cols) | ||
self.assertEqual(merged.shape[0], suma.shape[0]) | ||
self.assertIn('N=10-base', suma.columns) | ||
outdf = os.path.join(temp, "merged.xlsx") | ||
suma.to_excel(outdf, index=False) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
""" | ||
@brief test log(time=3s) | ||
""" | ||
import unittest | ||
from logging import getLogger | ||
from pandas import DataFrame | ||
from pyquickhelper.loghelper import fLOG | ||
from pyquickhelper.pycode import ExtTestCase | ||
from pyquickhelper.pandashelper import df2rst | ||
from sklearn.exceptions import ConvergenceWarning | ||
try: | ||
from sklearn.utils._testing import ignore_warnings | ||
except ImportError: | ||
from sklearn.utils.testing import ignore_warnings | ||
from skl2onnx import __version__ as skl2onnx_version | ||
from mlprodict.onnxrt.validate import enumerate_validated_operator_opsets, summary_report | ||
from mlprodict.onnxrt.doc.doc_write_helper import split_columns_subsets | ||
|
||
|
||
class TestRtValidateKMeans(ExtTestCase): | ||
|
||
@ignore_warnings(category=(UserWarning, ConvergenceWarning, RuntimeWarning)) | ||
def test_rt_KMeans_python(self): | ||
fLOG(__file__, self._testMethodName, OutputPrint=__name__ == "__main__") | ||
logger = getLogger('skl2onnx') | ||
logger.disabled = True | ||
verbose = 2 if __name__ == "__main__" else 0 | ||
|
||
debug = False | ||
buffer = [] | ||
|
||
def myprint(*args, **kwargs): | ||
buffer.append(" ".join(map(str, args))) | ||
|
||
rows = list(enumerate_validated_operator_opsets( | ||
verbose, models={"KMeans"}, opset_min=11, | ||
opset_max=11, fLOG=myprint, | ||
runtime='python', debug=debug)) | ||
self.assertGreater(len(rows), 1) | ||
self.assertIn('skl_nop', rows[-1]) | ||
keys = set() | ||
for row in rows: | ||
keys.update(set(row)) | ||
self.assertIn('onx_size', keys) | ||
piv = summary_report(DataFrame(rows)) | ||
opset = [c for c in piv.columns if 'opset' in c] | ||
self.assertTrue('opset11' in opset or 'opset10' in opset) | ||
self.assertGreater(len(buffer), 1 if debug else 0) | ||
common, subsets = split_columns_subsets(piv) | ||
try: | ||
conv = df2rst(piv, split_col_common=common, # pylint: disable=E1123 | ||
split_col_subsets=subsets) | ||
self.assertIn('| KMeans |', conv) | ||
except TypeError as e: | ||
if "got an unexpected keyword argument 'split_col_common'" in str(e): | ||
return | ||
raise e | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.