Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions _doc/bench/bench_orttraining_nn_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@
from sklearn.neural_network import MLPRegressor
from sklearn.metrics import mean_squared_error
from mlprodict.onnx_conv import to_onnx
from onnxcustom.training import add_loss_output, get_train_initializer
from onnxcustom.training.optimizers import OrtGradientOptimizer
from onnxcustom.training import (
add_loss_output, get_train_initializer,
OrtGradientOptimizer)


def benchmark(N=1000, n_features=20, hidden_layer_sizes="25,25", max_iter=1000,
Expand Down
5 changes: 3 additions & 2 deletions _doc/examples/plot_orttraining_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPRegressor
from mlprodict.onnx_conv import to_onnx
from onnxcustom.training import add_loss_output, get_train_initializer
from onnxcustom.training.optimizers import OrtGradientOptimizer
from onnxcustom.training import (
add_loss_output, get_train_initializer,
OrtGradientOptimizer)


X, y = make_regression(2000, n_features=100, bias=2)
Expand Down
2 changes: 1 addition & 1 deletion _unittests/ut_training/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pyquickhelper.pycode import ExtTestCase
from sklearn.datasets import make_regression
from onnxruntime import OrtValue
from onnxcustom.training.data_loader import OrtDataLoader
from onnxcustom.training import OrtDataLoader


class TestDataLoadeer(ExtTestCase):
Expand Down
16 changes: 8 additions & 8 deletions _unittests/ut_training/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class TestOptimizers(ExtTestCase):

@unittest.skipIf(TrainingSession is None, reason="not training")
def test_ort_gradient_optimizers_use_numpy(self):
from onnxcustom.training.orttraining import add_loss_output
from onnxcustom.training.optimizers import OrtGradientOptimizer
from onnxcustom.training import (
add_loss_output, OrtGradientOptimizer)
X, y = make_regression( # pylint: disable=W0632
100, n_features=10, bias=2)
X = X.astype(numpy.float32)
Expand Down Expand Up @@ -78,8 +78,8 @@ def test_ort_gradient_optimizers_use_ort(self):

@unittest.skipIf(TrainingSession is None, reason="not training")
def test_ort_gradient_optimizers_optimal_use_numpy(self):
from onnxcustom.training.orttraining import add_loss_output
from onnxcustom.training.optimizers import OrtGradientOptimizer
from onnxcustom.training import (
add_loss_output, OrtGradientOptimizer)
X, y = make_regression( # pylint: disable=W0632
100, n_features=10, bias=2)
X = X.astype(numpy.float32)
Expand Down Expand Up @@ -134,8 +134,8 @@ def test_ort_gradient_optimizers_optimal_use_ort(self):

@unittest.skipIf(TrainingSession is None, reason="not training")
def test_ort_gradient_optimizers_evaluation_use_numpy(self):
from onnxcustom.training.orttraining import add_loss_output
from onnxcustom.training.optimizers import OrtGradientOptimizer
from onnxcustom.training import (
add_loss_output, OrtGradientOptimizer)
X, y = make_regression( # pylint: disable=W0632
100, n_features=10, bias=2)
X = X.astype(numpy.float32)
Expand Down Expand Up @@ -164,8 +164,8 @@ def test_ort_gradient_optimizers_evaluation_use_numpy(self):

@unittest.skipIf(TrainingSession is None, reason="not training")
def test_ort_gradient_optimizers_evaluation_use_ort(self):
from onnxcustom.training.orttraining import add_loss_output
from onnxcustom.training.optimizers import OrtGradientOptimizer
from onnxcustom.training import (
add_loss_output, OrtGradientOptimizer)
X, y = make_regression( # pylint: disable=W0632
100, n_features=10, bias=2)
X = X.astype(numpy.float32)
Expand Down
4 changes: 3 additions & 1 deletion onnxcustom/training/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
@file
@brief Shortcuts to *training*.
@brief Shortcuts to *orttraining*.
"""

from .data_loader import OrtDataLoader # noqa
from .optimizers import OrtGradientOptimizer # noqa
from .orttraining import add_loss_output, get_train_initializer # noqa
3 changes: 2 additions & 1 deletion onnxcustom/training/optimizers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
@file
@brief Helper for :epkg:`onnxruntime-training`.
@brief Train a machine learned model
with :epkg:`onnxruntime-training`.
"""
import inspect
import numpy
Expand Down
4 changes: 4 additions & 0 deletions onnxcustom/training/ortgradient.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""
@file
@brief Helpers for :epkg:`onnxruntime-training`.
"""
2 changes: 1 addition & 1 deletion onnxcustom/training/orttraining.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
@file
@brief Helper for :epkg:`onnxruntime-training`.
@brief Manipulate ONNX graph to train a model.
"""
from onnx.helper import (
make_node, make_graph, make_model, make_tensor_value_info,
Expand Down