6 changes: 6 additions & 0 deletions .gitignore
@@ -29,3 +29,9 @@ examples/onnxruntime_profile*.json
version.txt
_doc/bench/*.svg
_doc/examples/*.svg
_doc/bench/*.html
_doc/examples/*.html
_doc/examples/*.json
_doc/sphinxdoc/source/phdoc_static/*.js
_doc/sphinxdoc/source/phdoc_static/reveal.js/*
_doc/sphinxdoc/source/phdoc_static/style_notebook_snippet.css
218 changes: 218 additions & 0 deletions _doc/bench/bench_ortmodule_nn_gpu.py
@@ -0,0 +1,218 @@
"""

.. _l-orttraining-nn-benchmark:

Benchmark ORTModule on a neural network
=======================================

To make it work, you may need to run:

::

python -c "from onnxruntime.training.ortmodule.torch_cpp_extensions import install as ortmodule_install;ortmodule_install.build_torch_cpp_extensions()"

You may profile the full example on CPU with :epkg:`py-spy`:

::

py-spy record -o bench_ortmodule_nn_gpu.svg -r 10 --native -- python bench_ortmodule_nn_gpu.py
py-spy record -o bench_ortmodule_nn_gpu.svg -r 20 -- python bench_ortmodule_nn_gpu.py --n_features 100 --hidden_layer_sizes "30,30"

The Python code can be profiled with :epkg:`pyinstrument`.

::

python -m pyinstrument --show-all -r html -o bench_ortmodule_nn_gpu.html bench_ortmodule_nn_gpu.py --n_features 100 --hidden_layer_sizes "30,30"

And with ``nvprof`` on GPU:

::

nvprof -o bench_ortmodule_nn_gpu.nvprof python bench_ortmodule_nn_gpu.py --run_torch 0 --device cuda --opset 14

.. contents::
:local:

A neural network with pytorch
+++++++++++++++++++++++++++++

"""
import warnings
from pprint import pprint
import time
import numpy
from pandas import DataFrame
from onnxruntime import get_device
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
import torch
import torch.nn.functional as F
from onnxruntime.training import ORTModule


def benchmark(N=1000, n_features=20, hidden_layer_sizes="26,25", max_iter=1000,
learning_rate_init=1e-4, batch_size=100, run_torch=True,
device='cpu', opset=12):
"""
Compares :epkg:`onnxruntime-training` to :epkg:`pytorch` for
training. The training algorithm is SGD.

:param N: number of observations to train on
:param n_features: number of features
:param hidden_layer_sizes: hidden layer sizes, comma separated values
:param max_iter: number of iterations
:param learning_rate_init: initial learning rate
:param batch_size: batch size
:param run_torch: train with :epkg:`pytorch` under the same conditions (True)
or skip the :epkg:`pytorch` training and only run :epkg:`onnxruntime-training` (False)
:param device: `'cpu'` or `'cuda'`
:param opset: opset to choose for the conversion
"""
N = int(N)
n_features = int(n_features)
max_iter = int(max_iter)
learning_rate_init = float(learning_rate_init)
batch_size = int(batch_size)
run_torch = run_torch in (1, True, '1', 'True')

print("N=%d" % N)
print("n_features=%d" % n_features)
print("hidden_layer_sizes=%r" % (hidden_layer_sizes, ))
print("max_iter=%d" % max_iter)
print("learning_rate_init=%f" % learning_rate_init)
print("batch_size=%d" % batch_size)
print("run_torch=%r" % run_torch)
print("opset=%r (unused)" % opset)
print("device=%r" % device)
device0 = device
device = torch.device(
"cuda:0" if device in ('cuda', 'cuda:0', 'gpu') else "cpu")
print("fixed device=%r" % device)
print('------------------')

if not isinstance(hidden_layer_sizes, tuple):
hidden_layer_sizes = tuple(map(int, hidden_layer_sizes.split(",")))
X, y = make_regression(N, n_features=n_features, bias=2)
X = X.astype(numpy.float32)
y = y.astype(numpy.float32)
X_train, X_test, y_train, y_test = train_test_split(X, y)


class Net(torch.nn.Module):
def __init__(self, n_features, hidden, n_output):
super(Net, self).__init__()
self.hidden = []

size = n_features
for i, hid in enumerate(hidden_layer_sizes):
self.hidden.append(torch.nn.Linear(size, hid))
size = hid
setattr(self, "hid%d" % i, self.hidden[-1])
self.hidden.append(torch.nn.Linear(size, n_output))
setattr(self, "predict", self.hidden[-1])

def forward(self, x):
# relu on the hidden layers only, the output layer stays linear
# so the network can predict negative regression targets
for hid in self.hidden[:-1]:
x = F.relu(hid(x))
return self.hidden[-1](x)

nn = Net(n_features, hidden_layer_sizes, 1)
if device0 == 'cpu':
nn.cpu()
else:
nn.cuda(device=device)
print("n_parameters=%d, n_layers=%d" % (
len(list(nn.parameters())), len(nn.hidden)))
for i, p in enumerate(nn.parameters()):
print(" p[%d].shape=%r" % (i, p.shape))

optimizer = torch.optim.SGD(nn.parameters(), lr=learning_rate_init)
criterion = torch.nn.MSELoss(reduction='sum')
batch_no = len(X_train) // batch_size

# training

def train_torch():
for epoch in range(max_iter):
running_loss = 0.0
x, y = shuffle(X_train, y_train)
for i in range(batch_no):
start = i * batch_size
end = start + batch_size
inputs = torch.tensor(
x[start:end], requires_grad=True, device=device)
labels = torch.tensor(
y[start:end], requires_grad=True, device=device)

def step_torch():
optimizer.zero_grad()
outputs = nn(inputs)
loss = criterion(outputs, torch.unsqueeze(labels, dim=1))
loss.backward()
optimizer.step()
return loss

loss = step_torch()
running_loss += loss.item()
return running_loss

begin = time.perf_counter()
if run_torch:
running_loss = train_torch()
dur_torch = time.perf_counter() - begin

if run_torch:
print("time_torch=%r, running_loss=%r" % (dur_torch, running_loss))
running_loss0 = running_loss
else:
running_loss0 = -1

# ORTModule
nn = Net(n_features, hidden_layer_sizes, 1)
if device0 == 'cpu':
nn.cpu()
else:
nn.cuda(device=device)

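# ORTModule wraps the torch module; forward and backward then run through onnxruntime-training.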
nn_ort = ORTModule(nn)
optimizer = torch.optim.SGD(nn_ort.parameters(), lr=learning_rate_init)
criterion = torch.nn.MSELoss(reduction='sum')

def train_ort():
for epoch in range(max_iter):
running_loss = 0.0
x, y = shuffle(X_train, y_train)
for i in range(batch_no):
start = i * batch_size
end = start + batch_size
inputs = torch.tensor(
x[start:end], requires_grad=True, device=device)
labels = torch.tensor(
y[start:end], requires_grad=True, device=device)

def step_ort():
optimizer.zero_grad()
outputs = nn_ort(inputs)
loss = criterion(outputs, torch.unsqueeze(labels, dim=1))
loss.backward()
optimizer.step()
return loss

loss = step_ort()
running_loss += loss.item()
return running_loss

begin = time.perf_counter()
running_loss = train_ort()
dur_ort = time.perf_counter() - begin

print("time_torch=%r, running_loss=%r" % (dur_torch, running_loss0))
print("time_ort=%r, last_trained_error=%r" % (dur_ort, running_loss))


if __name__ == "__main__":
import fire
fire.Fire(benchmark)
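
Since ``fire.Fire(benchmark)`` exposes every parameter of ``benchmark`` as a command-line flag, the script can also be launched directly; the values below are only illustrative:

::

    python bench_ortmodule_nn_gpu.py --N 1000 --n_features 100 --hidden_layer_sizes "30,30" --max_iter 200 --device cuda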
9 changes: 5 additions & 4 deletions _doc/bench/bench_orttraining_nn_gpu.py
@@ -66,7 +66,7 @@ def benchmark(N=1000, n_features=20, hidden_layer_sizes="25,25", max_iter=1000,

print("N=%d" % N)
print("n_features=%d" % n_features)
print("hidden_layer_sizes=%s" % hidden_layer_sizes)
print("hidden_layer_sizes=%r" % (hidden_layer_sizes, ))
print("max_iter=%d" % max_iter)
print("learning_rate_init=%f" % learning_rate_init)
print("batch_size=%d" % batch_size)
@@ -75,7 +75,8 @@ def benchmark(N=1000, n_features=20, hidden_layer_sizes="25,25", max_iter=1000,
print("device=%r" % device)
print('------------------')

hidden_layer_sizes = tuple(map(int, hidden_layer_sizes.split(",")))
if not isinstance(hidden_layer_sizes, tuple):
hidden_layer_sizes = tuple(map(int, hidden_layer_sizes.split(",")))
X, y = make_regression(N, n_features=n_features, bias=2)
X = X.astype(numpy.float32)
y = y.astype(numpy.float32)
@@ -92,7 +93,7 @@ def benchmark(N=1000, n_features=20, hidden_layer_sizes="25,25", max_iter=1000,
nn.fit(X_train, y_train)
dur_skl = time.perf_counter() - begin

print("time_kl=%r, mean_squared_error=%r" % (
print("time_skl=%r, mean_squared_error=%r" % (
dur_skl, mean_squared_error(y_train, nn.predict(X_train))))

# conversion to ONNX
@@ -119,7 +120,7 @@ def benchmark(N=1000, n_features=20, hidden_layer_sizes="25,25", max_iter=1000,
begin = time.perf_counter()
train_session.fit(X, y)
dur_ort = time.perf_counter() - begin
print("time_kl=%r, mean_squared_error=%r" % (
print("time_skl=%r, mean_squared_error=%r" % (
dur_skl, mean_squared_error(y_train, nn.predict(X_train))))
print("time_ort=%r, last_trained_error=%r" % (
dur_ort, train_session.train_losses_[-1]))
24 changes: 0 additions & 24 deletions _doc/sphinxdoc/source/api.rst

This file was deleted.

5 changes: 5 additions & 0 deletions _doc/sphinxdoc/source/api/data_loader.rst
@@ -0,0 +1,5 @@

DataLoader
==========

.. autoclass:: onnxcustom.training.data_loader.OrtDataLoader
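
A minimal sketch of how the loader might be used; the ``batch_size`` argument and the iteration protocol shown here are assumptions, check the class documentation above for the actual signature:

::

    import numpy
    from onnxcustom.training.data_loader import OrtDataLoader

    X = numpy.random.randn(1000, 20).astype(numpy.float32)
    y = numpy.random.randn(1000).astype(numpy.float32)

    # batch_size is an assumed parameter name
    data = OrtDataLoader(X, y, batch_size=100)

    # iteration protocol assumed: one pair of batches per optimization step
    for batch_x, batch_y in data:
        pass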
10 changes: 10 additions & 0 deletions _doc/sphinxdoc/source/api/index.rst
@@ -0,0 +1,10 @@

===
API
===

.. toctree::

utils
data_loader
training
23 changes: 23 additions & 0 deletions _doc/sphinxdoc/source/api/training.rst
@@ -0,0 +1,23 @@

Training
========

.. contents::
:local:

BaseEstimator
+++++++++++++

.. autoclass:: onnxcustom.training.optimizers.BaseEstimator

OrtGradientOptimizer
++++++++++++++++++++

.. autoclass:: onnxcustom.training.optimizers.OrtGradientOptimizer

Helpers
+++++++

.. autofunction:: onnxcustom.training.orttraining.add_loss_output

.. autofunction:: onnxcustom.training.orttraining.get_train_initializer
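
The pieces above fit together roughly as follows, assuming ``onx`` is a trained ONNX model and ``X``, ``y`` are numpy training data. This is a sketch only: ``fit`` and ``train_losses_`` appear in the benchmark scripts of this repository, but the argument names passed to ``add_loss_output``, ``get_train_initializer`` and the ``OrtGradientOptimizer`` constructor are assumptions:

::

    from onnxcustom.training.orttraining import add_loss_output, get_train_initializer
    from onnxcustom.training.optimizers import OrtGradientOptimizer

    # add a loss node to the ONNX graph (argument name assumed)
    onx_loss = add_loss_output(onx)

    # select the initializers (weights) to train (argument name assumed)
    weights = list(get_train_initializer(onx))

    # constructor arguments assumed, based on the benchmark scripts
    train_session = OrtGradientOptimizer(onx_loss, weights, learning_rate=1e-4)
    train_session.fit(X, y)
    print(train_session.train_losses_[-1])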
5 changes: 5 additions & 0 deletions _doc/sphinxdoc/source/api/utils.rst
@@ -0,0 +1,5 @@

Utils
=====

.. autofunction:: onnxcustom.utils.measure_time
6 changes: 6 additions & 0 deletions _doc/sphinxdoc/source/conf.py
@@ -19,6 +19,10 @@
issue=('https://github.com/sdpython/onnxcustom/issues/%s', 'issue')),
title="onnxcustom", book=True)

extensions.append("sphinxcontrib.blockdiag")

blog_root = "http://www.xavierdupre.fr/app/onnxcustom/helpsphinx/"

html_css_files = ['my-styles.css']
@@ -88,6 +92,7 @@
'onnxruntime-training':
'https://github.com/microsoft/onnxruntime/tree/master/orttraining',
'openmp': 'https://en.wikipedia.org/wiki/OpenMP',
'py-spy': 'https://github.com/benfred/py-spy',
'pyinstrument': 'https://github.com/joerick/pyinstrument',
'python': 'https://www.python.org/',
'pytorch': 'https://pytorch.org/',
@@ -97,6 +102,7 @@
'sphinx-gallery': 'https://github.com/sphinx-gallery/sphinx-gallery',
'Stochastic Gradient Descent':
'https://en.wikipedia.org/wiki/Stochastic_gradient_descent',
'tqdm': 'https://github.com/tqdm/tqdm',
'TreeEnsembleRegressor':
'https://github.com/onnx/onnx/blob/master/docs/Operators-ml.md'
'#ai.onnx.ml.TreeEnsembleRegressor',
4 changes: 2 additions & 2 deletions _doc/sphinxdoc/source/index.rst
@@ -51,9 +51,9 @@ operators.
.. toctree::
:maxdepth: 1

tutorial
tutorial/index
doc
api
api/index
auto_examples/index
dev
versions
@@ -11,5 +11,5 @@ model are part of a pipeline.
.. toctree::
:maxdepth: 1

auto_examples/plot_gexternal_lightgbm
auto_examples/plot_gexternal_xgboost
../gyexamples/plot_gexternal_lightgbm
../gyexamples/plot_gexternal_xgboost