268 changes: 268 additions & 0 deletions _doc/examples/plot_orttraining_benchmark_fwbw_cls.py
@@ -0,0 +1,268 @@
"""

.. _l-orttraining-benchmark-fwbw-cls:

Benchmark, comparison sklearn - forward-backward - classification
=================================================================

The benchmark compares the processing time between :epkg:`scikit-learn`
and :epkg:`onnxruntime-training` on a logistic regression and
a neural network for classification.
It replicates the benchmark implemented in :ref:`l-orttraining-benchmark-fwbw`.

.. contents::
    :local:

First comparison: neural network
++++++++++++++++++++++++++++++++

"""
import warnings
import time
import numpy
import matplotlib.pyplot as plt
from pandas import DataFrame
from onnxruntime import get_device
from pyquickhelper.pycode.profiling import profile, profile2graph
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from mlprodict.onnx_conv import to_onnx
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from mlprodict.onnx_tools.onnx_manipulations import select_model_inputs_outputs
from onnxcustom.utils.onnx_helper import onnx_rename_weights
from onnxcustom.training.optimizers_partial import (
    OrtGradientForwardBackwardOptimizer)
from onnxcustom.training.sgd_learning_rate import LearningRateSGDNesterov
from onnxcustom.training.sgd_learning_loss import NegLogLearningLoss
from onnxcustom.training.sgd_learning_penalty import ElasticLearningPenalty


X, y = make_classification(1000, n_features=100, n_classes=2)
X = X.astype(numpy.float32)
y = y.astype(numpy.int64)
X_train, X_test, y_train, y_test = train_test_split(X, y)

########################################
# Benchmark function.


def benchmark(X, y, skl_model, train_session, name, verbose=True):
"""
:param skl_model: model from scikit-learn
:param train_session: instance of OrtGradientForwardBackwardOptimizer
:param name: experiment name
:param verbose: to debug
"""
print("[benchmark] %s" % name)
begin = time.perf_counter()
skl_model.fit(X, y)
duration_skl = time.perf_counter() - begin
length_skl = len(skl_model.loss_curve_)
print("[benchmark] skl=%r iterations - %r seconds" % (
length_skl, duration_skl))

begin = time.perf_counter()
train_session.fit(X, y)
duration_ort = time.perf_counter() - begin
length_ort = len(train_session.train_losses_)
print("[benchmark] ort=%r iteration - %r seconds" % (
length_ort, duration_ort))

return dict(skl=duration_skl, ort=duration_ort, name=name,
iter_skl=length_skl, iter_ort=length_ort,
losses_skl=skl_model.loss_curve_,
losses_ort=train_session.train_losses_)


########################################
# Common parameters and model

batch_size = 15
max_iter = 100

nn = MLPClassifier(hidden_layer_sizes=(50, 10), max_iter=max_iter,
                   solver='sgd', learning_rate_init=1e-1, alpha=1e-4,
                   n_iter_no_change=max_iter * 3, batch_size=batch_size,
                   nesterovs_momentum=True, momentum=0.9,
                   learning_rate="invscaling")

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    nn.fit(X_train, y_train)

########################################
# Conversion to ONNX and trainer initialization.
# The process is slightly different from a regression model.
# Probabilities usually come from raw scores transformed
# through a function such as the sigmoid function.
# The gradient of the loss is computed against the raw scores
# because its expression is simpler than the gradient
# onnxruntime would compute through the probabilities.

onx = to_onnx(nn, X_train[:1].astype(numpy.float32), target_opset=15,
              options={'zipmap': False, 'nocl': True})
print(onnx_simple_text_plot(onx))

##########################################
# Raw scores are the input of operator *Sigmoid*.

onx = select_model_inputs_outputs(
    onx, outputs=["add_result2"], infer_shapes=True)
print(onnx_simple_text_plot(onx))

#########################################
# The weights are then renamed so that they follow alphabetical
# order (see :class:`OrtGradientForwardBackward
# <onnxcustom.training.ortgradient.OrtGradientForwardBackward>`).

onx = onnx_rename_weights(onx)
print(onnx_simple_text_plot(onx))

################################################
# We select the log loss (see :class:`NegLogLearningLoss
# <onnxcustom.training.sgd_learning_loss.NegLogLearningLoss>`),
# a simple penalty defined with :class:`ElasticLearningPenalty
# <onnxcustom.training.sgd_learning_penalty.ElasticLearningPenalty>`,
# and the Nesterov algorithm to update the weights with
# :class:`LearningRateSGDNesterov
# <onnxcustom.training.sgd_learning_rate.LearningRateSGDNesterov>`.

train_session = OrtGradientForwardBackwardOptimizer(
    onx, device='cpu', warm_start=False,
    max_iter=max_iter, batch_size=batch_size,
    learning_loss=NegLogLearningLoss(),
    learning_rate=LearningRateSGDNesterov(
        1e-7, nesterov=True, momentum=0.9),
    learning_penalty=ElasticLearningPenalty(l1=0, l2=1e-4))


benches = [benchmark(X_train, y_train, nn, train_session, name='NN-CPU')]

######################################
# Profiling
# +++++++++


def clean_name(text):
    pos = text.find('onnxruntime')
    if pos >= 0:
        return text[pos:]
    pos = text.find('sklearn')
    if pos >= 0:
        return text[pos:]
    pos = text.find('onnxcustom')
    if pos >= 0:
        return text[pos:]
    pos = text.find('site-packages')
    if pos >= 0:
        return text[pos:]
    return text
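
######################################
# A quick illustration of ``clean_name`` on a made-up path (hypothetical,
# for demonstration only): it keeps only the part starting at a known
# package name.

print(clean_name("/usr/lib/python3/site-packages/sklearn/neural_network.py"))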


ps = profile(lambda: benchmark(X_train, y_train,
                               nn, train_session, name='NN-CPU'))[0]
root, nodes = profile2graph(ps, clean_text=clean_name)
text = root.to_text()
print(text)

######################################
# If GPU is available
# +++++++++++++++++++

if get_device().upper() == 'GPU':

    train_session = OrtGradientForwardBackwardOptimizer(
        onx, device='cuda', warm_start=False,
        max_iter=max_iter, batch_size=batch_size,
        learning_loss=NegLogLearningLoss(),
        learning_rate=LearningRateSGDNesterov(
            1e-7, nesterov=False, momentum=0.9),
        learning_penalty=ElasticLearningPenalty(l1=0, l2=1e-4))

    benches.append(benchmark(X_train, y_train, nn,
                             train_session, name='NN-GPU'))

#######################################
# A simple linear layer
# +++++++++++++++++++++

nn = MLPClassifier(hidden_layer_sizes=tuple(), max_iter=max_iter,
                   solver='sgd', learning_rate_init=1e-1, alpha=1e-4,
                   n_iter_no_change=max_iter * 3, batch_size=batch_size,
                   nesterovs_momentum=True, momentum=0.9,
                   learning_rate="invscaling", activation='identity')


with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    nn.fit(X_train, y_train)

onx = to_onnx(nn, X_train[:1].astype(numpy.float32), target_opset=15,
              options={'zipmap': False, 'nocl': True})
print(onnx_simple_text_plot(onx))

onx = select_model_inputs_outputs(
    onx, outputs=["add_result"], infer_shapes=True)
print(onnx_simple_text_plot(onx))

onx = onnx_rename_weights(onx)
print(onnx_simple_text_plot(onx))

train_session = OrtGradientForwardBackwardOptimizer(
    onx, device='cpu', warm_start=False,
    max_iter=max_iter, batch_size=batch_size,
    learning_loss=NegLogLearningLoss(),
    learning_rate=LearningRateSGDNesterov(
        1e-5, nesterov=True, momentum=0.9),
    learning_penalty=ElasticLearningPenalty(l1=0, l2=1e-4))


benches.append(benchmark(X_train, y_train, nn, train_session, name='LR-CPU'))

if get_device().upper() == 'GPU':

    train_session = OrtGradientForwardBackwardOptimizer(
        onx, device='cuda', warm_start=False,
        max_iter=max_iter, batch_size=batch_size,
        learning_loss=NegLogLearningLoss(),
        learning_rate=LearningRateSGDNesterov(
            1e-5, nesterov=False, momentum=0.9),
        learning_penalty=ElasticLearningPenalty(l1=0, l2=1e-4))

    benches.append(benchmark(X_train, y_train, nn,
                             train_session, name='LR-GPU'))


######################################
# Graphs
# ++++++
#
# Dataframe first.

df = DataFrame(benches).set_index('name')
df

#######################################
# Text output

print(df)

#######################################
# Graphs.

fig, ax = plt.subplots(1, 2, figsize=(10, 4))
df[['skl', 'ort']].plot.bar(title="Processing time", ax=ax[0])
ax[0].tick_params(axis='x', rotation=30)
for bench in benches:
    ax[1].plot(bench['losses_skl'][1:], label='skl-' + bench['name'])
    ax[1].plot(bench['losses_ort'][1:], label='ort-' + bench['name'])
ax[1].set_yscale('log')
ax[1].set_title("Losses")
ax[1].legend()

########################################
# The gradient updates are not exactly the same.
# They should be aligned for a fair comparison.

# plt.show()
@@ -127,6 +127,23 @@ used by the three components mentioned above. They cache the binded pointers
if the method is called again with a different `OrtValue` but a same pointer
returned by `data_ptr()`.

Binary classification
+++++++++++++++++++++

Probabilities are computed from raw scores with a function such as the
`sigmoid function <https://en.wikipedia.org/wiki/Sigmoid_function>`_.
A binary classifier produces two probabilities: :math:`sigmoid(s)` and
:math:`1 - sigmoid(s)`, where *s* is the raw score. The associated loss
function is usually the log loss: :math:`loss(y, X) =
-(1-y) \log(1-p(s)) - y \log p(s)`, where *y* is the expected class (0 or 1),
*s=s(X)* is the raw score and *p(s)* is the probability.
We could compute the gradient of the loss
against the probability and let :epkg:`onnxruntime-training` handle the
computation of the gradient from the probability down to the input.
However, the gradient of the loss against the raw score has the simple
closed form :math:`grad(loss(y, s)) = p(s) - y`. This second
option is implemented in example :ref:`l-orttraining-benchmark-fwbw-cls`.
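
A quick finite-difference check of that closed form, written as a minimal
numpy sketch for illustration (it does not use the library)::

    import numpy

    def neg_log_loss(s, y):
        p = 1. / (1. + numpy.exp(-s))
        return -((1 - y) * numpy.log(1 - p) + y * numpy.log(p))

    s = numpy.array([-2., 0.5, 3.])
    y = numpy.array([0., 1., 1.])
    p = 1. / (1. + numpy.exp(-s))
    eps = 1e-6
    num = (neg_log_loss(s + eps, y) - neg_log_loss(s - eps, y)) / (2 * eps)
    print(numpy.abs(num - (p - y)).max())  # close to 0: gradient is p(s) - y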

Examples
++++++++

@@ -143,3 +160,4 @@ with ONNX and :epkg:`onnxruntime-training`.
../gyexamples/plot_orttraining_nn_gpu_fwbw
../gyexamples/plot_orttraining_nn_gpu_fwbw_nesterov
../gyexamples/plot_orttraining_benchmark_fwbw
../gyexamples/plot_orttraining_benchmark_fwbw_cls
30 changes: 24 additions & 6 deletions onnxcustom/training/data_loader.py
@@ -197,12 +197,30 @@ def local_bind(bind, offset, n):
     shape_X = (n, n_col_x)
     shape_y = (n, n_col_y)

-    bind.bind_input(
-        names[0], self.device, self.desc[0][1], shape_X,
-        self.X_ort.data_ptr() + offset * n_col_x * size_x)
-    bind.bind_input(
-        names[1], self.device, self.desc[1][1], shape_y,
-        self.y_ort.data_ptr() + offset * n_col_y * size_y)
+    try:
+        bind.bind_input(
+            names[0], self.device, self.desc[0][1], shape_X,
+            self.X_ort.data_ptr() + offset * n_col_x * size_x)
+    except RuntimeError as e:
+        raise RuntimeError(
+            "Unable to bind data input (X) %r, device=%r desc=%r "
+            "data_ptr=%r offset=%r n_col_x=%r size_x=%r "
+            "type(bind)=%r" % (
+                names[0], self.device, self.desc[0][1],
+                self.X_ort.data_ptr(), offset, n_col_x, size_x,
+                type(bind))) from e
+    try:
+        bind.bind_input(
+            names[1], self.device, self.desc[1][1], shape_y,
+            self.y_ort.data_ptr() + offset * n_col_y * size_y)
+    except RuntimeError as e:
+        raise RuntimeError(
+            "Unable to bind data input (y) %r, device=%r desc=%r "
+            "data_ptr=%r offset=%r n_col_y=%r size_y=%r "
+            "type(bind)=%r" % (
+                names[1], self.device, self.desc[1][1],
+                self.y_ort.data_ptr(), offset, n_col_y, size_y,
+                type(bind))) from e

 def local_bindw(bind, offset, n):
     # This function assumes the data is contiguous.
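
The byte-offset arithmetic above assumes the bound arrays are contiguous
and row-major: row ``offset`` starts ``offset * n_cols * itemsize`` bytes
into the buffer. A minimal numpy sketch of the same computation (variable
names mirror the diff, but the snippet is illustrative, not the library's
API):

    import numpy

    # a contiguous row-major float32 matrix: 4 rows, 3 columns
    X = numpy.arange(12, dtype=numpy.float32).reshape(4, 3)

    offset = 1                 # the batch starts at row 1
    n_col_x = X.shape[1]       # 3 columns
    size_x = X.dtype.itemsize  # 4 bytes per float32

    # the byte offset added to data_ptr(): start of row `offset`
    byte_offset = offset * n_col_x * size_x
    print(byte_offset)                   # 12: row 1 starts 12 bytes in
    print(X.ravel()[offset * n_col_x])   # 3.0, first element of that row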
2 changes: 1 addition & 1 deletion onnxcustom/utils/orttraining_helper.py
@@ -498,7 +498,7 @@ def _replace(ens):
     elem = output_onx.type.tensor_type.elem_type
     if elem == 0:
         raise TypeError(  # pragma: no cover
-            "Unable to guess inut tensor type from %r."
+            "Unable to guess input tensor type from %r."
             "" % output_onx)
     shape = []
     for d in output_onx.type.tensor_type.shape.dim: