Skip to content

Commit

Permalink
[Onnx] Add momentum (apache#9000)
Browse files Browse the repository at this point in the history
* add momentum

* make tests pass for momentum

* blacking

* lint

Co-authored-by: Andrew Zhao Luo <andrewzhaoluo@system76-pc.localdomain>
  • Loading branch information
2 people authored and ylc committed Jan 13, 2022
1 parent 4dc7318 commit bf413c2
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 3 deletions.
54 changes: 54 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3707,6 +3707,59 @@ def _impl_v1(cls, inputs, attr, params):
return _expr.TupleWrapper(_expr.Tuple(result), len(result))


class Momentum(OnnxOpConverter):
    """Operator converter for Momentum op.

    Builds the stochastic-gradient-with-momentum update for one or more
    tensors, in either "standard" or "nesterov" mode.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        """Convert an ONNX Momentum node into Relay expressions.

        Parameters
        ----------
        inputs : sequence of relay expressions
            Laid out as ``[R, T, x_1..x_n, x_1_gradient..x_n_gradient,
            x_1_momentum..x_n_momentum]``, where ``R`` multiplies the update
            step and ``T`` is compared against 0 to pick the gradient scale.
        attr : dict
            Must contain ``alpha``, ``beta``, ``norm_coefficient`` and ``mode``
            (bytes, decoded as UTF-8; either ``"standard"`` or ``"nesterov"``).
        params : dict
            Unused by this converter.

        Returns
        -------
        _expr.TupleWrapper
            The ``n`` updated tensors followed by the ``n`` updated momentums.

        Raises
        ------
        AssertionError
            If ``mode`` is unknown, if the inputs after ``R`` and ``T`` do not
            form triplets, or if no triplet is supplied at all.
        """
        alpha = attr["alpha"]
        beta = attr["beta"]
        mode = attr["mode"].decode("utf-8")
        norm_coefficient = attr["norm_coefficient"]

        assert mode in ["nesterov", "standard"], f"Unknown momentum mode {mode}"
        R = inputs[0]
        T = inputs[1]

        assert (
            len(inputs) - 2
        ) % 3 == 0, f"Expect triplets for remaining inputs, found {len(inputs) - 2}"
        # Remaining inputs are:
        # [x_1, x_2 ..., x_1_gradient, x_2_gradient, ... x_1_momentum, x_2_momentum...]
        num_input_tensors = (len(inputs) - 2) // 3
        # Zero triplets satisfies the modulo check above, but would otherwise
        # surface as an opaque IndexError when the dtype is inferred below.
        assert (
            num_input_tensors > 0
        ), "Expect at least one (tensor, gradient, momentum) triplet, found none"

        # convert attributes to constants
        # With at least one triplet present, inputs[3] is always a tensor or
        # gradient input; its dtype is used for every constant.
        dtype_inputs = infer_type(inputs[3]).checked_type.dtype
        alpha = relay.const(alpha, dtype=dtype_inputs)
        beta = relay.const(beta, dtype=dtype_inputs)
        norm_coefficient = relay.const(norm_coefficient, dtype=dtype_inputs)
        default_beta = relay.const(1.0, dtype=dtype_inputs)

        # The gradient scale depends only on T, so build it once and share the
        # node across all tensors: beta when T > 0, otherwise 1.
        beta_adjusted = relay.If(T > relay.const(0, dtype="int64"), beta, default_beta)

        # Calculate updated values for every input
        output_tensors = []
        output_momentums = []
        for i in range(num_input_tensors):
            x = inputs[i + 2]
            gradient = inputs[i + 2 + num_input_tensors]
            momentum = inputs[i + 2 + 2 * num_input_tensors]
            g_regularized = norm_coefficient * x + gradient
            new_momentum = alpha * momentum + beta_adjusted * g_regularized

            if mode == "standard":
                x_output = x - R * new_momentum
            else:
                # mode == 'nesterov'
                x_output = x - R * (g_regularized + alpha * new_momentum)

            output_tensors.append(x_output)
            output_momentums.append(new_momentum)

        # append lists together, momentums come after result tensors
        result = output_tensors + output_momentums
        return _expr.TupleWrapper(_expr.Tuple(result), len(result))


# Compatible operators that do NOT require any conversion.
# The list is empty as defined here.
# NOTE(review): presumably holds ONNX operator names that can be passed through
# unchanged -- confirm against the code that consumes this list.
_identity_list = []

Expand Down Expand Up @@ -3895,6 +3948,7 @@ def _get_convert_map(opset):
"NegativeLogLikelihoodLoss": NegativeLogLikelihoodLoss.get_converter(opset),
"Adagrad": Adagrad.get_converter(opset),
"Adam": Adam.get_converter(opset),
"Momentum": Momentum.get_converter(opset),
}


Expand Down
3 changes: 0 additions & 3 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4755,10 +4755,7 @@ def verify_eyelike(indata):
"test_maxpool_with_argmax_2d_precomputed_pads",
"test_maxpool_with_argmax_2d_precomputed_strides",
"test_maxunpool_export_with_output_shape",
"test_momentum",
"test_momentum_multiple",
"test_mvn",
"test_nesterov_momentum",
# When unsqueeze is fully supported, remaining nllloss tests should work:
"test_nllloss_NC_expanded",
"test_nllloss_NCd1_expanded",
Expand Down

0 comments on commit bf413c2

Please sign in to comment.