Skip to content

Commit

Permalink
split unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
sdpython committed Jan 25, 2022
1 parent 38db7cb commit 80d8d84
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class TestDocumentationExampleBenchmark(ExtTestCase):
ortt is None, reason="onnxruntime-training not installed.")
@skipif_circleci("stuck")
@skipif_appveyor("too long")
def test_documentation_examples_training(self):
def test_documentation_examples(self):

this = os.path.abspath(os.path.dirname(__file__))
onxc = os.path.normpath(os.path.join(this, '..', '..'))
Expand All @@ -55,6 +55,8 @@ def test_documentation_examples_training(self):
for name in sorted(found):
if 'benchmark' not in name:
continue
if 'orttraining' in name:
continue
if not name.startswith("plot_") or not name.endswith(".py"):
continue

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""
@brief test log(time=60s)
"""
import unittest
import os
import sys
import importlib
import subprocess
from datetime import datetime
try:
import onnxruntime.capi.ort_trainer as ortt
except ImportError:
ortt = None
from pyquickhelper.pycode import skipif_circleci, ExtTestCase, skipif_appveyor
from pyquickhelper.texthelper import compare_module_version
from mlprodict import __version__ as mlp_version


def import_source(module_file_path, module_name):
if not os.path.exists(module_file_path):
raise FileNotFoundError(module_file_path)
module_spec = importlib.util.spec_from_file_location(
module_name, module_file_path)
if module_spec is None:
raise FileNotFoundError(
"Unable to find '{}' in '{}', cwd='{}'.".format(
module_name, module_file_path,
os.path.abspath(__file__)))
module = importlib.util.module_from_spec(module_spec)
return module_spec.loader.exec_module(module)


class TestDocumentationExampleBenchmarkTraining(ExtTestCase):

@unittest.skipIf(
compare_module_version(mlp_version, "0.7.1642") <= 0,
reason="plot_onnx was updated.")
@unittest.skipIf(
ortt is None, reason="onnxruntime-training not installed.")
@skipif_circleci("stuck")
@skipif_appveyor("too long")
def test_documentation_examples_training(self):

this = os.path.abspath(os.path.dirname(__file__))
onxc = os.path.normpath(os.path.join(this, '..', '..'))
pypath = os.environ.get('PYTHONPATH', None)
sep = ";" if sys.platform == 'win32' else ':'
pypath = "" if pypath in (None, "") else (pypath + sep)
pypath += onxc
os.environ['PYTHONPATH'] = pypath
fold = os.path.normpath(
os.path.join(this, '..', '..', '_doc', 'examples'))
found = os.listdir(fold)
tested = 0
for name in sorted(found):
if 'benchmark' not in name:
continue
if 'orttraining' not in name:
continue
if not name.startswith("plot_") or not name.endswith(".py"):
continue

with self.subTest(name=name):
if __name__ == "__main__" or "-v" in sys.argv:
print("%s: run %r" % (
datetime.now().strftime("%d-%m-%y %H:%M:%S"),
name))
sys.path.insert(0, fold)
try:
mod = import_source(fold, os.path.splitext(name)[0])
assert mod is not None
except FileNotFoundError:
# try another way
cmds = [sys.executable, "-u",
os.path.join(fold, name)]
p = subprocess.Popen(
cmds, stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
res = p.communicate()
_, err = res
st = err.decode('ascii', errors='ignore')
if len(st) > 0 and 'Traceback' in st:
if "No such file or directory: 'dot': 'dot'" in st:
# dot not installed, this part
# is tested in onnx framework
pass
elif '"dot" not found in path.' in st:
# dot not installed, this part
# is tested in onnx framework
pass
elif ('Please fix either the inputs or '
'the model.') in st:
# onnxruntime datasets changed in master
# branch, still the same in released
# version on pypi
pass
elif 'dot: graph is too large' in st:
# graph is too big
pass
else:
raise RuntimeError( # pylint: disable=W0707
"Example '{}' (cmd: {} - exec_prefix="
"'{}') failed due to\n{}"
"".format(name, cmds, sys.exec_prefix, st))
finally:
if sys.path[0] == fold:
del sys.path[0]
with open(
os.path.join(os.path.dirname(__file__),
"_test_example.txt"), "a",
encoding='utf-8') as f:
f.write(name + "\n")
tested += 1
if tested == 0:
raise RuntimeError("No example was tested.")


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions _unittests/ut_plotting/test_plotting_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def setUpClass(cls):
logger = logging.getLogger('skl2onnx')
logger.setLevel(logging.WARNING)
logging.basicConfig(level=logging.WARNING)
logging.getLogger('matplotlib.font_manager').disabled = True

@skipif_travis('graphviz is not installed')
@skipif_circleci('graphviz is not installed')
Expand Down
2 changes: 2 additions & 0 deletions _unittests/ut_training/test_optimizers_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class TestOptimizersClassification(ExtTestCase):
def setUpClass(cls):
logger = logging.getLogger('skl2onnx')
logger.setLevel(logging.WARNING)
logger = logging.getLogger('onnxcustom')
logger.setLevel(logging.WARNING)
logging.basicConfig(level=logging.WARNING)

@unittest.skipIf(TrainingSession is None, reason="not training")
Expand Down
10 changes: 10 additions & 0 deletions _unittests/ut_training/test_optimizers_forward_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ class TestOptimizersForwardBackward(ExtTestCase):
def setUpClass(cls):
logger = logging.getLogger('skl2onnx')
logger.setLevel(logging.WARNING)
logger = logging.getLogger('onnxcustom')
logger.setLevel(logging.WARNING)
logging.basicConfig(level=logging.WARNING)

@classmethod
def tearDownClass(cls):
logger = logging.getLogger('skl2onnx')
logger.setLevel(logging.WARNING)
logger = logging.getLogger('onnxcustom')
logger.setLevel(logging.WARNING)
logging.basicConfig(level=logging.WARNING)

@unittest.skipIf(TrainingSession is None, reason="not training")
Expand Down
16 changes: 16 additions & 0 deletions _unittests/ut_training/test_orttraining_forward_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,22 @@

class TestOrtTrainingForwardBackward(ExtTestCase):

@classmethod
def setUpClass(cls):
logger = logging.getLogger('skl2onnx')
logger.setLevel(logging.WARNING)
logger = logging.getLogger('onnxcustom')
logger.setLevel(logging.WARNING)
logging.basicConfig(level=logging.WARNING)

@classmethod
def tearDownClass(cls):
logger = logging.getLogger('skl2onnx')
logger.setLevel(logging.WARNING)
logger = logging.getLogger('onnxcustom')
logger.setLevel(logging.WARNING)
logging.basicConfig(level=logging.WARNING)

def forward_no_training(self, exc=None, verbose=False):
if exc is None:
exc = __name__ != '__main__'
Expand Down

0 comments on commit 80d8d84

Please sign in to comment.