diff --git a/.circleci/config.yml b/.circleci/config.yml index e44ec247..80b1ecaf 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -2,7 +2,7 @@ version: 2 jobs: build: docker: - - image: cimg/python:3.10.5 + - image: cimg/python:3.9.5 working_directory: ~/repo diff --git a/.gitignore b/.gitignore index 1cc45db6..367ddcc6 100644 --- a/.gitignore +++ b/.gitignore @@ -61,3 +61,6 @@ _unittests/ut_documentation/summary.csv _unittests/ut_documentation/_test_example.txt _unittests/ut_documentation/_test_example.txt something +_doc/examples/ort_cpu_bind.csv +_doc/examples/ort_cpu_gpu.csv +_doc/examples/ort_cpu.csv diff --git a/_doc/examples/data/ort_cpu_gpu.csv b/_doc/examples/data/ort_cpu_gpu.csv new file mode 100644 index 00000000..562954e5 --- /dev/null +++ b/_doc/examples/data/ort_cpu_gpu.csv @@ -0,0 +1,11 @@ +index,N,n_imgs_seq_cpu,time_seq_cpu,n_imgs_seq_gpu,time_seq_gpu,n_imgs_par,time_par +0,1,2,0.00826602200686466,2,0.5539164490037365,2,0.008887254502042197 +1,3,6,0.019666356995003298,6,0.010879299501539208,6,0.02050348349439446 +2,5,10,0.03761136099637952,10,0.024563621496781707,10,0.025345650006784126 +3,7,14,0.05642429149884265,14,0.03381696599535644,14,0.03412535750248935 +4,9,18,0.061089862501830794,18,0.032227409988990985,18,0.051062139493296854 +5,11,22,0.08963397399929818,22,0.03988744800153654,22,0.049899421486770734 +6,13,26,0.08532479300629348,26,0.059844546995009296,26,0.06177879350434523 +7,15,30,0.099295007501496,30,0.060137088992632926,30,0.07637454000359867 +8,17,34,0.11972474250069354,34,0.08182918949751183,34,0.0649502059968654 +9,19,38,0.1271352384937927,38,0.06829091400140896,38,0.07087059249170125 diff --git a/_doc/examples/data/ort_gpus.csv b/_doc/examples/data/ort_gpus.csv new file mode 100644 index 00000000..1447d1e4 --- /dev/null +++ b/_doc/examples/data/ort_gpus.csv @@ -0,0 +1,11 @@ +index,N,n_imgs_seq_cpu,time_seq_cpu,n_imgs_seq_gpu,time_seq_gpu,n_imgs_par,time_par 
+0,1,2,0.009441936999792233,2,0.27651189200696535,2,0.4723829315043986 +1,3,6,0.02211003350384999,6,0.010779099495266564,6,0.006075980491004884 +2,5,10,0.03130707699165214,10,0.018668279502890073,10,0.009614631999284029 +3,7,14,0.050966479000635445,14,0.02530089499487076,14,0.013326078507816419 +4,9,18,0.051026727494900115,18,0.03217840299475938,18,0.016715533987735398 +5,11,22,0.06609680749534164,22,0.03952224800013937,22,0.02055577700957656 +6,13,26,0.07877145400561858,26,0.046802844997728243,26,0.023463245990569703 +7,15,30,0.12377040000865236,30,0.05456937898998149,30,0.027040796499932185 +8,17,34,0.09810231548908632,34,0.06142483800067566,34,0.030876389500917867 +9,19,38,0.12459791499713901,38,0.07050566599355079,38,0.034905767504824325 diff --git a/_doc/examples/plot_parallel_execution.py b/_doc/examples/plot_parallel_execution.py new file mode 100644 index 00000000..3a544e83 --- /dev/null +++ b/_doc/examples/plot_parallel_execution.py @@ -0,0 +1,447 @@ +""" +.. _l-plot-parallel-execution: + +=============================== +Multithreading with onnxruntime +=============================== + +.. index:: thread, parallel + +Python implements multithreading but it is not efficient in practice due to the GIL +(see :epkg:`Le GIL`). However, if most of the parallelized code is not creating +python objects, this option becomes more interesting than creating several processes +trying to exchange data through sockets. :epkg:`onnxruntime` falls into that category. +For a big model such as a deep learning model, this might be interesting. +However, :epkg:`onnxruntime` already parallelizes the computation of +every operator (Gemm, MatMul) using all the CPU it can get so this approach +should show significant results when used on different processors (CPU, GPU) +in parallel. + +.. contents:: + :local: + +A model +======= + +Let's retrieve a not so big model.
+""" +import gc +import multiprocessing +import os +import urllib.request +import threading +import time +import tqdm +import numpy +import pandas +from cpyquickhelper.numbers import measure_time +import torch.cuda +from onnxruntime import InferenceSession, get_all_providers +from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611 + SessionIOBinding, OrtDevice as C_OrtDevice) + + +def download_file(url, name, min_size): + if not os.path.exists(name): + print(f"download '{url}'") + with urllib.request.urlopen(url) as u: + content = u.read() + if len(content) < min_size: + raise RuntimeError( + f"Unable to download '{url}' due to\n{content}") + print(f"downloaded {len(content)} bytes.") + with open(name, "wb") as f: + f.write(content) + else: + print(f"'{name}' already downloaded") + + +small = True +if small == "custom": + model_name = "custom.onnx" + url_name = None +elif small: + model_name = "mobilenetv2-10.onnx" + url_name = ("https://github.com/onnx/models/raw/main/vision/" + "classification/mobilenet/model") +else: + model_name = "resnet18-v1-7.onnx" + url_name = ("https://github.com/onnx/models/raw/main/vision/" + "classification/resnet/model") +url_name += "/" + model_name +download_file(url_name, model_name, 100000) + +############################################# +# Measuring inference time when parallelizing on CPU +# ================================================== +# +# Sequence +# ++++++++ +# +# Let's create a random image. 
+ +sess1 = InferenceSession(model_name, providers=["CPUExecutionProvider"]) +for i in sess1.get_inputs(): + print(f"input {i}, name={i.name!r}, type={i.type}, shape={i.shape}") + input_name = i.name + input_shape = list(i.shape) + if input_shape[0] in [None, "batch_size", "N"]: + input_shape[0] = 1 +for i in sess1.get_outputs(): + print(f"output {i}, name={i.name!r}, type={i.type}, shape={i.shape}") + output_name = i.name + + +rnd_img = numpy.random.rand(*input_shape).astype(numpy.float32) + +res = sess1.run(None, {input_name: rnd_img}) +print(f"output: type={res[0].dtype}, shape={res[0].shape}") + +print(measure_time(lambda: sess1.run(None, {input_name: rnd_img}), + div_by_number=True, repeat=10, number=10)) + +############################################# +# Parallelization +# +++++++++++++++ +# +# We define a number of threads lower than the number of cores. + +n_threads = min(8, multiprocessing.cpu_count() - 1) +print(f"n_threads={n_threads}") + + +imgs = [numpy.random.rand(*input_shape).astype(numpy.float32) + for i in range(n_threads)] + +sesss = [InferenceSession(model_name, providers=["CPUExecutionProvider"]) + for i in range(n_threads)] + +################################ +# Let's measure the time for a sequence of images. + + +def sequence(sess, imgs, N=1): + res = [] + for img in imgs: + for i in range(N): + res.append(sess.run(None, {input_name: img})[0]) + return res + + +print(measure_time(lambda: sequence(sesss[0], imgs), + div_by_number=True, repeat=2, number=2)) + +################################# +# And then with multithreading. 
+ + +class MyThread(threading.Thread): + + def __init__(self, sess, imgs): + threading.Thread.__init__(self) + self.sess = sess + self.imgs = imgs + self.q = [] + + def run(self): + for img in self.imgs: + r = self.sess.run(None, {input_name: img})[0] + self.q.append(r) + + +def parallel(sesss, imgs, N=1): + threads = [MyThread(sess, [img] * N) + for sess, img in zip(sesss, imgs)] + for t in threads: + t.start() + res = [] + for t in threads: + t.join() + res.extend(t.q) + return res + + +print(measure_time(lambda: parallel(sesss, imgs), + div_by_number=True, repeat=2, number=2)) + + +################################### +# It is worse for one image. It is expected as mentioned in the introduction. +# Let's check for different number of images to parallelize. + +print("ORT // CPU") +data = [] +rep = 2 +maxN = 21 +for N in tqdm.tqdm(range(1, maxN, 2)): + begin = time.perf_counter() + for i in range(rep): + res1 = sequence(sesss[0], imgs, N) + end = (time.perf_counter() - begin) / rep + obs = dict(N=N, n_imgs_seq=len(res1), time_seq=end) + + begin = time.perf_counter() + for i in range(rep): + res2 = parallel(sesss, imgs, N) + end = (time.perf_counter() - begin) / rep + obs.update(dict(n_imgs_par=len(res2), time_par=end)) + + data.append(obs) + +df = pandas.DataFrame(data) +df.reset_index(drop=False).to_csv("ort_cpu.csv", index=False) +df + +########################################## +# Plots +# +++++ + + +def make_plot(df, title): + + kwargs = dict(title=title, logy=True) + if "time_seq" in df.columns: + df["time_seq_img"] = df["time_seq"] / df["n_imgs_seq"] + df["time_par_img"] = df["time_par"] / df["n_imgs_par"] + columns = ["n_imgs_seq", "time_seq_img", "time_par_img"] + else: + df["time_seq_img_cpu"] = df["time_seq_cpu"] / df["n_imgs_seq_cpu"] + df["time_seq_img_gpu"] = df["time_seq_gpu"] / df["n_imgs_seq_gpu"] + df["time_par_img"] = df["time_par"] / df["n_imgs_par"] + columns = ["n_imgs_seq_cpu", "time_seq_img_cpu", + "time_seq_img_gpu", "time_par_img"] + + ax = 
df[columns].set_index(columns[0]).plot(**kwargs) + ax.set_xlabel("batch size") + ax.set_ylabel("s") + return ax + + +make_plot(df, "Time per image / batch size") + +####################################### +# As expected, it does not improve. It is like parallelizing using +# two strategies, per kernel and per image, both trying to access all +# the process cores at the same time. The time spent to synchronize +# is significant. + +################################################### +# Same with another API based on OrtValue +# +++++++++++++++++++++++++++++++++++++++ +# +# See :ref:`l-ortvalue-doc`. + + +class MyThreadBind(threading.Thread): + + def __init__(self, sess, imgs, ort_device): + threading.Thread.__init__(self) + self.sess = sess + self.imgs = imgs + self.q = [] + self.bind = SessionIOBinding(self.sess._sess) + self.ort_device = ort_device + + def run(self): + bind = self.bind + ort_device = self.ort_device + bind.bind_output(output_name, ort_device) + sess = self.sess._sess + q = self.q + for img in self.imgs: + bind.bind_input(input_name, ort_device, + img.dtype, img.shape, + img.__array_interface__['data'][0]) + sess.run_with_iobinding(bind, None) + ortvalues = bind.get_outputs() + q.append(ortvalues) + + +def parallel_bind(sesss, imgs, N=1): + ort_device = C_OrtDevice( + C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0) + threads = [MyThreadBind(sess, [img] * N, ort_device) + for sess, img in zip(sesss, imgs)] + for t in threads: + t.start() + res = [] + for t in threads: + t.join() + res.extend(t.q) + return res + + +print("ORT (bind) // CPU") +data = [] +for N in tqdm.tqdm(range(1, maxN, 2)): + begin = time.perf_counter() + for i in range(rep): + res1 = sequence(sesss[0], imgs, N) + end = (time.perf_counter() - begin) / rep + obs = dict(N=N, n_imgs_seq=len(res1), time_seq=end) + + begin = time.perf_counter() + for i in range(rep): + res2 = parallel_bind(sesss, imgs, N) + end = (time.perf_counter() - begin) / rep +
obs.update(dict(n_imgs_par=len(res2), time_par=end)) + + data.append(obs) + +df = pandas.DataFrame(data) +df.reset_index(drop=False).to_csv("ort_cpu_bind.csv", index=False) +df + +##################################### +# Let's free the memory. + +del sesss[:] +gc.collect() + +############################ +# Plots. + +make_plot(df, "Time per image / batch size\nrun_with_iobinding") + +######################################## +# It leads to the same conclusion. It is no use to parallelize +# on CPU as onnxruntime is already doing that per kernel. + + +######################################## +# GPU +# === +# +# Let's check first if it is possible. + +has_cuda = "CUDAExecutionProvider" in get_all_providers() +if not has_cuda: + print(f"No CUDA provider was detected in {get_all_providers()}.") + +n_gpus = torch.cuda.device_count() if has_cuda else 0 +if n_gpus == 0: + print("No GPU or one GPU was detected.") +elif n_gpus == 1: + print("1 GPU was detected.") +else: + print(f"{n_gpus} GPUs were detected.") + + +######################################### +# Parallelization GPU + CPU +# +++++++++++++++++++++++++ + +if has_cuda and n_gpus > 0: + n_threads = 2 + sesss = [InferenceSession(model_name, providers=["CPUExecutionProvider"]), + InferenceSession(model_name, providers=["CUDAExecutionProvider", + "CPUExecutionProvider"])] + imgs = [numpy.random.rand(*input_shape).astype(numpy.float32) + for i in range(n_threads)] + + print("ORT // CPU + GPU") + data = [] + for N in tqdm.tqdm(range(1, maxN, 2)): + begin = time.perf_counter() + for i in range(rep): + res1 = sequence(sesss[0], imgs, N) + end = (time.perf_counter() - begin) / rep + obs = dict(N=N, n_imgs_seq_cpu=len(res1), time_seq_cpu=end) + + begin = time.perf_counter() + for i in range(rep): + res2 = sequence(sesss[1], imgs, N) + end = (time.perf_counter() - begin) / rep + obs.update(dict(n_imgs_seq_gpu=len(res2), time_seq_gpu=end)) + + begin = time.perf_counter() + for i in range(rep): + res2 = parallel_bind(sesss, 
imgs, N) + end = (time.perf_counter() - begin) / rep + obs.update(dict(n_imgs_par=len(res2), time_par=end)) + + data.append(obs) + + del sesss[:] + gc.collect() + df = pandas.DataFrame(data) + df.reset_index(drop=False).to_csv("ort_cpu_gpu.csv", index=False) +else: + print("No GPU is available but data should be like the following.") + df = pandas.read_csv("data/ort_cpu_gpu.csv").set_index("N") + +df + +#################################### +# Plots. + +ax = make_plot(df, "Time per image / batch size\nCPU + GPU") +ax + +#################################### +# The parallelization on multiple CPU + GPUs is working, it is faster than CPU +# but it is still slower than using a single GPU in that case. + +######################################### +# Parallelization on multiple GPUs +# ++++++++++++++++++++++++++++++++ +# +# This is the only case for which it should work as every GPU is independent. + +if n_gpus > 1: + n_threads = 2 + sesss = [] + for i in range(n_gpus): + print(f"Initialize device {i}") + sesss.append( + InferenceSession(model_name, providers=["CUDAExecutionProvider", + "CPUExecutionProvider"], + provider_options=[{"device_id": i}, {}])) + imgs = [numpy.random.rand(*input_shape).astype(numpy.float32) + for i in range(n_threads)] + + print("ORT // GPUs") + data = [] + for N in tqdm.tqdm(range(1, maxN, 2)): + begin = time.perf_counter() + for i in range(rep): + res1 = sequence(sess1, imgs, N) + end = (time.perf_counter() - begin) / rep + obs = dict(N=N, n_imgs_seq_cpu=len(res1), time_seq_cpu=end) + + begin = time.perf_counter() + for i in range(rep): + res2 = sequence(sesss[0], imgs, N) + end = (time.perf_counter() - begin) / rep + obs.update(dict(n_imgs_seq_gpu=len(res2), time_seq_gpu=end)) + + begin = time.perf_counter() + for i in range(rep): + res2 = parallel_bind(sesss, imgs, N) + end = (time.perf_counter() - begin) / rep + obs.update(dict(n_imgs_par=len(res2), time_par=end)) + + data.append(obs) + + del sesss[:] + gc.collect() + df =
pandas.DataFrame(data) + df.reset_index(drop=False).to_csv("ort_gpus.csv", index=False) +else: + print("No GPU is available but data should be like the following.") + df = pandas.read_csv("data/ort_gpus.csv").set_index("N") + +df + + +#################################### +# Plots. + +ax = make_plot(df, f"Time per image / batch size\n{n_gpus} GPUs") +ax + +#################################### +# The parallelization on multiple GPUs did work. + +# import matplotlib.pyplot as plt +# plt.show() diff --git a/_doc/sphinxdoc/source/api/onnxruntime_python/index.rst b/_doc/sphinxdoc/source/api/onnxruntime_python/index.rst index 45e6dcbf..e8ade2c9 100644 --- a/_doc/sphinxdoc/source/api/onnxruntime_python/index.rst +++ b/_doc/sphinxdoc/source/api/onnxruntime_python/index.rst @@ -4,7 +4,7 @@ Summary of onnxruntime and onnxruntime-training API Module :epkg:`onnxcustom` leverages :epkg:`onnxruntime-training` to train models. Next sections exposes frequent functions uses to run inference -and training with :epkg:`onnxruntime` and :epkg:`onnxruntume-training`. +and training with :epkg:`onnxruntime` and :epkg:`onnxruntime-training`. Most of the code in :epkg:`onnxruntime` is written in C++ and exposed in Python using :epkg:`pybind11`. 
For inference, the main class diff --git a/_doc/sphinxdoc/source/conf.py b/_doc/sphinxdoc/source/conf.py index e0f978ef..1f1e7268 100644 --- a/_doc/sphinxdoc/source/conf.py +++ b/_doc/sphinxdoc/source/conf.py @@ -129,6 +129,7 @@ def callback_begin(): 'DLPack': 'https://github.com/dmlc/dlpack', 'docker': 'https://en.wikipedia.org/wiki/Docker_(software)', 'DOT': 'https://www.graphviz.org/doc/info/lang.html', + 'Le GIL': 'http://www.xavierdupre.fr/app/teachpyx/helpsphinx/notebooks/gil_example.html', 'ImageNet': 'http://www.image-net.org/', 'LightGBM': 'https://lightgbm.readthedocs.io/en/latest/', 'lightgbm': 'https://lightgbm.readthedocs.io/en/latest/', diff --git a/_doc/sphinxdoc/source/tutorials/tutorial_onnxruntime/ortvalue_doc.rst b/_doc/sphinxdoc/source/tutorials/tutorial_onnxruntime/ortvalue_doc.rst index 8cca3286..65a1a6db 100644 --- a/_doc/sphinxdoc/source/tutorials/tutorial_onnxruntime/ortvalue_doc.rst +++ b/_doc/sphinxdoc/source/tutorials/tutorial_onnxruntime/ortvalue_doc.rst @@ -1,4 +1,6 @@ +.. 
_l-ortvalue-doc: + ======== OrtValue ======== diff --git a/_unittests/ut_documentation/test_documentation_examples_lightgbm.py b/_unittests/ut_documentation/test_documentation_examples_lightgbm.py index 279b43f0..ebbe7445 100644 --- a/_unittests/ut_documentation/test_documentation_examples_lightgbm.py +++ b/_unittests/ut_documentation/test_documentation_examples_lightgbm.py @@ -2,7 +2,6 @@ @brief test log(time=60s) """ import unittest -from distutils.version import StrictVersion import os import sys import importlib @@ -10,6 +9,7 @@ from datetime import datetime import onnxruntime from pyquickhelper.pycode import ExtTestCase, skipif_appveyor +from pyquickhelper.texthelper.version_helper import compare_module_version def import_source(module_file_path, module_name): @@ -60,8 +60,8 @@ def test_documentation_examples_lightgbm(self): name)) continue if (name == "plot_pipeline_lightgbm.py" and - StrictVersion(onnxruntime.__version__) < - StrictVersion('1.0.0')): + compare_module_version( + onnxruntime.__version__, '1.0.0') < 0): continue if not name.startswith("plot_") or not name.endswith(".py"): continue diff --git a/onnxcustom/utils/onnx_function.py b/onnxcustom/utils/onnx_function.py index 825b39be..ffa4fd52 100644 --- a/onnxcustom/utils/onnx_function.py +++ b/onnxcustom/utils/onnx_function.py @@ -760,7 +760,14 @@ def _onnx_grad_sigmoid_neg_log_loss_error(target_opset=None, print("DOT-SECTION", oinf.to_dot()) """ - from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE + try: + from onnx.helper import np_dtype_to_tensor_dtype + except ImportError: + from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE + + def np_dtype_to_tensor_dtype(dtype): + return NP_TYPE_TO_TENSOR_TYPE[dtype] + from skl2onnx.algebra.onnx_ops import ( OnnxSub, OnnxMul, OnnxSigmoid, OnnxLog, OnnxNeg, OnnxReduceSum, OnnxReshape, OnnxAdd, OnnxCast, OnnxClip) @@ -771,7 +778,7 @@ def _onnx_grad_sigmoid_neg_log_loss_error(target_opset=None, op_version=target_opset) p0 = OnnxSub(numpy.array([1], dtype=dtype), p1, 
op_version=target_opset) - y1 = OnnxCast('X1', to=NP_TYPE_TO_TENSOR_TYPE[numpy.dtype(dtype)], + y1 = OnnxCast('X1', to=np_dtype_to_tensor_dtype(numpy.dtype(dtype)), op_version=target_opset) y0 = OnnxSub(numpy.array([1], dtype=dtype), y1, op_version=target_opset)