Skip to content
This repository has been archived by the owner on Jan 13, 2024. It is now read-only.

Commit

Permalink
increase code coverage, code refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
sdpython committed Jun 26, 2020
1 parent fd8299e commit e4bb085
Show file tree
Hide file tree
Showing 18 changed files with 278 additions and 215 deletions.
31 changes: 31 additions & 0 deletions _unittests/ut_onnxrt/test_coverage_runtime_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
@brief test log(time=2s)
"""
import unittest
import numpy
from onnx.numpy_helper import from_array
from pyquickhelper.pycode import ExtTestCase
from skl2onnx.algebra.onnx_ops import ( # pylint: disable=E0611
OnnxConstant, OnnxAdd)
from mlprodict.onnxrt import OnnxInference


class TestCoverageRuntimeOps(ExtTestCase):

def test_op_constant(self):
for opv in [9, 10, 11, 12]:
for dtype in [numpy.float32, numpy.float64,
numpy.int32, numpy.int64]:
with self.subTest(opv=opv, dtype=dtype):
X = numpy.array([1], dtype=dtype)
pX = from_array(X)
op = OnnxAdd('X', OnnxConstant(op_version=opv, value=pX),
output_names=['Y'], op_version=opv)
onx = op.to_onnx({'X': X})
oinf = OnnxInference(onx)
res = oinf.run({'X': X})
self.assertEqualArray(res['Y'], X + X)


if __name__ == "__main__":
unittest.main()
22 changes: 13 additions & 9 deletions mlprodict/asv_benchmark/asv_exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
@file
@brief Functions to help exporting json format into text.
"""
import pprint
import copy
import os
import json
Expand Down Expand Up @@ -42,7 +43,7 @@ def _coor_to_str(cc):
c = c.replace("'", '"').replace("True", "1").replace("False", "0")
try:
d = json.loads(c)
except JSONDecodeError as e:
except JSONDecodeError as e: # pragma: no cover
raise RuntimeError(
"Unable to interpret '{}'.".format(c)) from e

Expand Down Expand Up @@ -81,9 +82,9 @@ def _figures2dict(metrics, coor, baseline=None):
base_j = i, base.index(quoted_base)
break
if base_j is None:
import pprint
raise ValueError("Unable to find value baseline '{}' or [{}] in {}".format(
baseline, quoted_base, pprint.pformat(coor)))
raise ValueError( # pragma: no cover
"Unable to find value baseline '{}' or [{}] in {}".format(
baseline, quoted_base, pprint.pformat(coor)))
m_bases = {}
ind = [0 for c in coor]
res = {}
Expand Down Expand Up @@ -134,7 +135,8 @@ def enumerate_export_asv_json(folder, as_df=False, last_one=False,
meta_class = None
if conf is not None:
if not os.path.exists(conf):
raise FileNotFoundError("Unable to find '{}'.".format(conf))
raise FileNotFoundError( # pragma: no cover
"Unable to find '{}'.".format(conf))
with open(conf, "r", encoding='utf-8') as f:
meta = json.load(f)
bdir = os.path.join(os.path.dirname(conf), meta['benchmark_dir'])
Expand All @@ -143,7 +145,8 @@ def enumerate_export_asv_json(folder, as_df=False, last_one=False,

bench = os.path.join(folder, 'benchmarks.json')
if not os.path.exists(bench):
raise FileNotFoundError("Unable to find '{}'.".format(bench))
raise FileNotFoundError( # pragma: no cover
"Unable to find '{}'.".format(bench))
with open(bench, 'r', encoding='utf-8') as f:
content = json.load(f)

Expand Down Expand Up @@ -181,10 +184,11 @@ def enumerate_export_asv_json(folder, as_df=False, last_one=False,
results = test_content['results']
for kk, vv in results.items():
if vv is None:
raise RuntimeError('Unexpected empty value for vv')
raise RuntimeError( # pragma: no cover
'Unexpected empty value for vv')
try:
metrics, coord, hash = vv[:3]
except ValueError as e:
except ValueError as e: # pragma: no cover
raise ValueError(
"Test '{}', unable to interpret: {}.".format(
kk, vv)) from e
Expand Down Expand Up @@ -284,7 +288,7 @@ def _enumerate_classes(filename):

try:
exec(cp, gl, loc) # pylint: disable=W0122
except NameError as e:
except NameError as e: # pragma: no cover
raise NameError(
"An import is probably missing from function 'fix_missing_imports'"
".") from e
Expand Down
25 changes: 15 additions & 10 deletions mlprodict/asv_benchmark/common_asv_skl.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def _get_xdtype(self, dtype):
return numpy.float32
elif dtype in ('double', numpy.float64):
return numpy.float64
raise ValueError("Unknown dtype '{}'.".format(dtype))
raise ValueError( # pragma: no cover
"Unknown dtype '{}'.".format(dtype))

def _get_dataset(self, nf, dtype):
xdtype = self._get_xdtype(dtype)
Expand All @@ -98,7 +99,7 @@ def _to_onnx(self, model, X, opset, dtype, optim):
if optim is None or len(optim) == 0:
options = self.par_convopts
elif self.par_convopts and len(self.par_convopts) > 0:
raise NotImplementedError(
raise NotImplementedError( # pragma: no cover
"Conflict between par_convopts={} and optim={}".format(
self.par_convopts, optim))
else:
Expand All @@ -120,7 +121,7 @@ def _create_onnx_inference(self, onx, runtime):

try:
res = OnnxInference(onx, runtime=runtime)
except RuntimeError as e:
except RuntimeError as e: # pragma: no cover
if "[ONNXRuntimeError]" in str(e):
return RuntimeError("onnxruntime fails due to {}".format(str(e)))
raise e
Expand All @@ -145,7 +146,8 @@ def runtime_name(self, runtime):
elif runtime == 'pyrtc':
name = 'python_compiled'
else:
raise ValueError("Unknown runtime '{}'.".format(runtime))
raise ValueError( # pragma: no cover
"Unknown runtime '{}'.".format(runtime))
return name

def _name(self, nf, opset, dtype):
Expand All @@ -168,8 +170,9 @@ def setup_cache(self):
with open(filename, "wb") as f:
pickle.dump(stored, f)
if not os.path.exists(filename):
raise RuntimeError("Unable to dump model %r into %r." % (
model, filename))
raise RuntimeError( # pragma: no cover
"Unable to dump model %r into %r." % (
model, filename))

def setup(self, runtime, N, nf, opset, dtype, optim):
"asv API"
Expand Down Expand Up @@ -227,16 +230,18 @@ def track_vort(self, runtime, N, nf, opset, dtype, optim):
try:
from onnxruntime import __version__
return version2number(__version__)
except ImportError:
except ImportError: # pragma: no cover
return 0

def check_method_name(self, method_name):
"Does some verifications. Fails if inconsistencies."
if getattr(self, 'chk_method_name', None) not in (None, method_name):
raise RuntimeError("Method name must be '{}'.".format(method_name))
raise RuntimeError( # pragma: no cover
"Method name must be '{}'.".format(method_name))
if getattr(self, 'chk_method_name', None) is None:
raise RuntimeError(
"Unable to check that the method name is correct (expected is '{}')".format(
raise RuntimeError( # pragma: no cover
"Unable to check that the method name is correct "
"(expected is '{}')".format(
method_name))


Expand Down
12 changes: 6 additions & 6 deletions mlprodict/onnx_conv/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pandas
try:
from sklearn.metrics._scorer import _PredictScorer
except ImportError:
except ImportError: # pragma: no cover
# scikit-learn < 0.22
from sklearn.metrics.scorer import _PredictScorer
from sklearn import __all__ as sklearn__all__, __version__ as sklearn_version
Expand Down Expand Up @@ -161,7 +161,7 @@ def guess_schema_from_model(model, tensor_type=None, schema=None):
if schema is not None:
try:
guessed = guess_schema_from_model(model)
except NotImplementedError:
except NotImplementedError: # pragma: no cover
return _replace_tensor_type(schema, tensor_type)
if len(guessed) != len(schema):
raise RuntimeError(
Expand All @@ -183,15 +183,15 @@ def guess_schema_from_model(model, tensor_type=None, schema=None):
import pprint
data = pprint.pformat(model.__dict__)
dirs = pprint.pformat(dir(model))
if hasattr(model, 'dump_model'):
if hasattr(model, 'dump_model'): # pragma: no cover
dumped = model.dump_model()
keys = list(sorted(dumped))
last = pprint.pformat([keys, dumped])
if len(last) >= 200000:
last = last[:200000] + "\n..."
else:
last = ""
raise NotImplementedError(
raise NotImplementedError( # pragma: no cover
"Unable to guess schema for model {}\n{}\n----\n{}\n------\n{}".format(
model.__class__, data, dirs, last))

Expand Down Expand Up @@ -307,7 +307,7 @@ def to_onnx(model, X=None, name=None, initial_types=None,
"""
if isinstance(model, OnnxOperatorMixin):
if not hasattr(model, 'op_version'):
raise RuntimeError(
raise RuntimeError( # pragma: no cover
"Missing attribute 'op_version' for type '{}'.".format(
type(model)))
return model.to_onnx(X=X, name=name, dtype=dtype,
Expand All @@ -334,7 +334,7 @@ def _guess_type_(X, itype, dtype):
new_dtype = numpy.float32
if new_dtype not in (numpy.float32, numpy.float64, numpy.int64,
numpy.int32):
raise NotImplementedError(
raise NotImplementedError( # pragma: no cover
"dtype should be real not {} ({})".format(new_dtype, dtype))
return initial_types, dtype, new_dtype

Expand Down
2 changes: 1 addition & 1 deletion mlprodict/onnx_conv/operator_converters/conv_lightgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def calculate_lightgbm_output_shapes(operator):
return calculate_linear_classifier_output_shapes(operator)
if objective.startswith('regression'):
return calculate_linear_regressor_output_shapes(operator)
raise NotImplementedError(
raise NotImplementedError( # pragma: no cover
"Objective '{}' is not implemented yet.".format(objective))


Expand Down
Loading

0 comments on commit e4bb085

Please sign in to comment.