
Adds support for Momentum for python runtime (#423)
sdpython committed Apr 15, 2022
1 parent 7dada54 commit fe8e700
Showing 5 changed files with 198 additions and 4 deletions.
42 changes: 40 additions & 2 deletions _unittests/ut_onnxrt/test_onnxrt_python_runtime_.py
@@ -54,7 +54,8 @@
OnnxDropout, OnnxDropout_7,
OnnxEinsum, OnnxElu, OnnxEqual, OnnxErf, OnnxExp, OnnxExpand, OnnxEyeLike,
OnnxFlatten, OnnxFloor,
OnnxGreater, OnnxGreaterOrEqual, OnnxGemm, OnnxGlobalAveragePool,
OnnxGreater, OnnxGreaterOrEqual, OnnxGemm,
OnnxGlobalAveragePool, OnnxGlobalMaxPool,
OnnxHardmax, OnnxHardSigmoid, OnnxHardSwish,
OnnxIdentity, OnnxIsInf, OnnxIsNaN,
OnnxLeakyRelu, OnnxLess, OnnxLessOrEqual,
@@ -106,7 +107,8 @@
_batchnorm_test_mode, _batchnorm_training_mode)
from mlprodict.onnxrt.ops_cpu.op_average_pool import (
_get_output_shape, _pool, _get_pad_shape)
from mlprodict.onnxrt.ops_cpu.op_global_average_pool import _global_average_pool
from mlprodict.onnxrt.ops_cpu.op_global_average_pool import (
_global_average_pool, _global_max_pool)
from mlprodict.onnxrt.ops_cpu._op_onnx_numpy import ( # pylint: disable=E0611,E0401
topk_element_min_double, topk_element_max_double,
topk_element_fetch_double,
@@ -2890,6 +2892,42 @@ def test_onnxt_runtime_global_average_pool(self):

python_tested.append(OnnxGlobalAveragePool)

@wraplog()
def test_onnxt_runtime_global_max_pool(self):
        x = numpy.random.randn(1, 3, 5, 5).astype(numpy.float32)
y = _global_max_pool(x).astype(numpy.float32)

onx = OnnxGlobalMaxPool(
'X', output_names=['Y'],
op_version=TARGET_OPSET)
model_def = onx.to_onnx({'X': x.astype(numpy.float32)},
target_opset=TARGET_OPSET)
self._check_shape_inference(OnnxGlobalMaxPool, model_def)
oinf = OnnxInference(model_def)
got = oinf.run({'X': x})
self.assertEqual(list(sorted(got)), ['Y'])
self.assertEqualArray(y, got['Y'])
self.common_expected_shapes_types(
oinf, {'X': x}, got, OnnxGlobalMaxPool, model_def)

x = numpy.array([[[
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
]]]).astype(numpy.float32)
y = numpy.array([[[[9]]]]).astype(numpy.float32)
onx = OnnxGlobalMaxPool(
'X', output_names=['Y'],
op_version=TARGET_OPSET)
model_def = onx.to_onnx({'X': x.astype(numpy.float32)},
target_opset=TARGET_OPSET)
oinf = OnnxInference(model_def)
got = oinf.run({'X': x})
self.assertEqual(list(sorted(got)), ['Y'])
self.assertEqualArray(y, got['Y'])

python_tested.append(OnnxGlobalMaxPool)

def test_onnxt_runtime_greater(self):
self.common_test_onnxt_runtime_binary(OnnxGreater, numpy.greater)

70 changes: 69 additions & 1 deletion _unittests/ut_onnxrt/test_onnxrt_python_runtime_training.py
@@ -6,10 +6,11 @@
import numpy
from pyquickhelper.pycode import ExtTestCase
from skl2onnx.algebra.onnx_ops import ( # pylint: disable=E0611
OnnxAdagrad, OnnxAdam)
OnnxAdagrad, OnnxAdam, OnnxMomentum)
from skl2onnx import __version__ as skl2onnx_version
from onnx.backend.test.case.node.adagrad import apply_adagrad
from onnx.backend.test.case.node.adam import apply_adam
from onnx.backend.test.case.node.momentum import apply_momentum
from mlprodict.onnxrt import OnnxInference
from mlprodict import __max_supported_opset__ as TARGET_OPSET

@@ -168,6 +169,73 @@ def test_onnxt_runtime_adam_multiple(self):
self.assertEqualArray(v2_new, got['V2_new'], decimal=4)
self.assertEqualArray(h2_new, got['H2_new'], decimal=4)

def test_onnxt_runtime_momentum(self):
norm_coefficient = 0.001
alpha = 0.95
beta = 0.1

r = numpy.array(0.1, dtype=numpy.float32)
t = numpy.array(0, dtype=numpy.int64) # scalar
x = numpy.array([1.2, 2.8], dtype=numpy.float32)
g = numpy.array([-0.94, -2.5], dtype=numpy.float32)
v = numpy.array([1.7, 3.6], dtype=numpy.float32)

node = OnnxMomentum(
'R', 'T', 'X', 'G', 'V',
output_names=['X_new', 'V_new'],
norm_coefficient=norm_coefficient,
alpha=alpha, beta=beta,
domain="ai.onnx.preview.training",
op_version=1)

onx = node.to_onnx({'R': r, 'T': t, 'X': x, 'G': g, 'V': v},
target_opset=TARGET_OPSET)
oinf = OnnxInference(onx)
got = oinf.run({'R': r, 'T': t, 'X': x, 'G': g, 'V': v})

x_new, v_new = apply_momentum(
r, t, x, g, v, norm_coefficient, alpha, beta)
self.assertEqualArray(x_new, got['X_new'])
self.assertEqualArray(v_new, got['V_new'])

def test_onnxt_runtime_momentum_multiple(self):
norm_coefficient = 0.001
alpha = 0.95
beta = 0.85
r = numpy.array(0.1, dtype=numpy.float32) # scalar
t = numpy.array(0, dtype=numpy.int64) # scalar
x1 = numpy.array([1.0], dtype=numpy.float32)
g1 = numpy.array([-1.0], dtype=numpy.float32)
v1 = numpy.array([2.0], dtype=numpy.float32)
x2 = numpy.array([1.0, 2.0], dtype=numpy.float32)
g2 = numpy.array([-1.0, -3.0], dtype=numpy.float32)
v2 = numpy.array([4.0, 1.0], dtype=numpy.float32)

node = OnnxMomentum(
'R', 'T', 'X1', 'X2', 'G1', 'G2', 'V1', 'V2',
output_names=['X1_new', 'X2_new', 'V1_new', 'V2_new'],
norm_coefficient=norm_coefficient,
alpha=alpha, beta=beta,
domain="ai.onnx.preview.training",
op_version=1)

onx = node.to_onnx({'R': r, 'T': t,
'X1': x1, 'G1': g1, 'V1': v1,
'X2': x2, 'G2': g2, 'V2': v2},
target_opset=TARGET_OPSET)
oinf = OnnxInference(onx)
got = oinf.run({'R': r, 'T': t,
'X1': x1, 'G1': g1, 'V1': v1,
'X2': x2, 'G2': g2, 'V2': v2})
x1_new, v1_new = apply_momentum(r, t, x1, g1, v1,
norm_coefficient, alpha, beta)
x2_new, v2_new = apply_momentum(r, t, x2, g2, v2,
norm_coefficient, alpha, beta)
self.assertEqualArray(x1_new, got['X1_new'])
self.assertEqualArray(v1_new, got['V1_new'])
self.assertEqualArray(x2_new, got['X2_new'])
self.assertEqualArray(v2_new, got['V2_new'], decimal=4)


if __name__ == "__main__":
unittest.main()
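
The multiple-tensor test above relies on Momentum's flat input layout: R, T, X1..Xn, G1..Gn, V1..Vn, with outputs X1_new..Xn_new, V1_new..Vn_new. A minimal sketch of that slicing (the helper split_momentum_inputs is hypothetical, not part of the runtime):

# Hypothetical helper showing how the flat argument tuple
# (R, T, X1..Xn, G1..Gn, V1..Vn) splits into its parts.
def split_momentum_inputs(data):
    n = (len(data) - 2) // 3        # number of optimized tensors
    r, t = data[0], data[1]         # learning rate, iteration count
    xs = data[2:2 + n]              # parameters
    gs = data[2 + n:2 + 2 * n]      # gradients
    vs = data[2 + 2 * n:]           # momentum states
    return r, t, xs, gs, vs

r, t, xs, gs, vs = split_momentum_inputs(
    ('R', 'T', 'X1', 'X2', 'G1', 'G2', 'V1', 'V2'))
assert xs == ('X1', 'X2') and gs == ('G1', 'G2') and vs == ('V1', 'V2')
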
3 changes: 2 additions & 1 deletion mlprodict/onnxrt/ops_cpu/_op_list.py
@@ -65,7 +65,7 @@
from .op_gathernd import GatherND
from .op_gather_elements import GatherElements
from .op_gemm import Gemm
from .op_global_average_pool import GlobalAveragePool
from .op_global_average_pool import GlobalAveragePool, GlobalMaxPool
from .op_greater import Greater, GreaterOrEqual
from .op_hardmax import Hardmax
from .op_hard_sigmoid import HardSigmoid
@@ -90,6 +90,7 @@
from .op_mean import Mean
from .op_min import Min
from .op_mod import Mod
from .op_momentum import Momentum
from .op_mul import Mul
from .op_neg import Neg
from .op_negative_log_likelihood_loss import NegativeLogLikelihoodLoss
32 changes: 32 additions & 0 deletions mlprodict/onnxrt/ops_cpu/op_global_average_pool.py
@@ -18,6 +18,14 @@ def _global_average_pool(x):
return y


def _global_max_pool(x):
    # Take the maximum over every spatial axis (axes 2 .. ndim - 1),
    # then restore those axes as size-1 dimensions.
    spatial_shape = numpy.ndim(x) - 2
    y = x.max(axis=tuple(range(2, spatial_shape + 2)))
    for _ in range(spatial_shape):
        y = numpy.expand_dims(y, -1)
    return y


class GlobalAveragePool(OpRun):

def __init__(self, onnx_node, desc=None, **options):
@@ -40,3 +48,27 @@ def _infer_types(self, x): # pylint: disable=W0221
def _infer_sizes(self, *args): # pylint: disable=W0221
res = self.run(*args)
return (dict(temp=0), ) + res


class GlobalMaxPool(OpRun):

def __init__(self, onnx_node, desc=None, **options):
OpRun.__init__(self, onnx_node, desc=desc,
**options)

def _run(self, x): # pylint: disable=W0221
res = _global_max_pool(x)
return (res, )

def _infer_shapes(self, x): # pylint: disable=W0221
if x.shape is None:
return (ShapeObject(None, dtype=x.dtype), )
shape = x.shape[:2] + (1, ) * (len(x.shape) - 2)
return (ShapeObject(shape, dtype=x.dtype), )

def _infer_types(self, x): # pylint: disable=W0221
return (x, )

def _infer_sizes(self, *args): # pylint: disable=W0221
res = self.run(*args)
return (dict(temp=0), ) + res
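
For reference, _global_max_pool above is equivalent to a keepdims max over the spatial axes; a minimal sanity check, assuming a 4D N x C x H x W input:

import numpy
from mlprodict.onnxrt.ops_cpu.op_global_average_pool import _global_max_pool

# Global max pooling equals a max over the spatial axes with keepdims=True.
x = numpy.random.randn(1, 3, 5, 5).astype(numpy.float32)
expected = x.max(axis=(2, 3), keepdims=True)  # shape (1, 3, 1, 1)
assert numpy.allclose(_global_max_pool(x), expected)
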
55 changes: 55 additions & 0 deletions mlprodict/onnxrt/ops_cpu/op_momentum.py
@@ -0,0 +1,55 @@
# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
@file
@brief Runtime operator.
"""
from ..shape_object import ShapeObject
from ._op import OpRun


def _apply_momentum(r, t, x, g, v, norm_coefficient, alpha, beta):
# Add gradient of regularization term.
g_regularized = norm_coefficient * x + g
# Coefficient of gradient should be 1 at the first iteration.
beta_adjusted = beta if t > 0 else 1
# Update momentum.
v_new = alpha * v + beta_adjusted * g_regularized
    # Apply the SGD-with-momentum update rule.
x_new = x - r * v_new
return x_new, v_new


class Momentum(OpRun):

atts = {'alpha': 0,
'beta': 0,
'mode': b'standard',
'norm_coefficient': 0.}

def __init__(self, onnx_node, desc=None, **options):
OpRun.__init__(self, onnx_node, desc=desc,
expected_attributes=Momentum.atts,
**options)

    def _run(self, *data):  # pylint: disable=W0221
        if len(data) == 5:
            # Single optimized tensor: R, T, X, G, V.
            return self._run1(*data)
        # Multiple tensors, laid out as R, T, X1..Xn, G1..Gn, V1..Vn;
        # outputs are X1_new..Xn_new, V1_new..Vn_new.
        n = (len(data) - 2) // 3
        xs = []
        vs = []
        for i in range(n):
            a, b = self._run1(*data[:2], data[2 + i],
                              data[2 + n + i], data[2 + n * 2 + i])
            xs.append(a)
            vs.append(b)
        return tuple(xs + vs)

def _run1(self, r, t, x, g, v): # pylint: disable=W0221
x_new, v_new = _apply_momentum(
r, t, x, g, v, self.norm_coefficient, self.alpha, self.beta)
return x_new, v_new

def _infer_shapes(self, i, *data): # pylint: disable=W0221
n = (len(data) - 1) // 3
return (ShapeObject(None, i.dtype), ShapeObject(None, i.dtype)) * n
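
As a worked check of _apply_momentum, the numbers from test_onnxt_runtime_momentum can be reproduced by hand (values copied from that test; the commented results are rounded):

import numpy

# First iteration (t == 0), so the gradient coefficient is 1, not beta.
norm_coefficient, alpha, beta, r = 0.001, 0.95, 0.1, 0.1
x = numpy.array([1.2, 2.8], dtype=numpy.float32)
g = numpy.array([-0.94, -2.5], dtype=numpy.float32)
v = numpy.array([1.7, 3.6], dtype=numpy.float32)

g_regularized = norm_coefficient * x + g    # [-0.9388, -2.4972]
v_new = alpha * v + 1 * g_regularized       # [0.6762, 0.9228]
x_new = x - r * v_new                       # [1.13238, 2.70772]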
