Merged
2 changes: 2 additions & 0 deletions .gitignore
@@ -6,10 +6,12 @@ __pycache__/
*.h5
*.npy
*.pth
*.log
dist/
*.DS_Store
.python-version
test/models/custom_conversion_tests

.vscode/
gallery_models
.DS_Store
2 changes: 1 addition & 1 deletion onnx2kerastl/customonnxlayer/__init__.py
@@ -2,7 +2,7 @@
from onnx2kerastl.customonnxlayer.onnxlstm import OnnxLSTM

onnx_custom_objects_map = {
"OnnxLSTM": OnnxLSTM
"OnnxLSTM": OnnxLSTM,
}

onnx_custom_layers = {
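A minimal usage sketch of this map (the file name below is hypothetical): when a converted model containing custom layers such as `OnnxLSTM` is reloaded, the map is passed to Keras so it can resolve those classes.

```python
import keras

from onnx2kerastl.customonnxlayer import onnx_custom_objects_map

# Hypothetical path: assumes a model converted by onnx2kerastl was saved earlier.
model = keras.models.load_model(
    "converted_model.h5",
    custom_objects=onnx_custom_objects_map,
)
```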
163 changes: 141 additions & 22 deletions onnx2kerastl/elementwise_layers.py
@@ -1,6 +1,7 @@
import numpy as np
import keras
import logging

from .utils import is_numpy, ensure_tf_type
from .tfops_funcs import tf_tensor_scatter_nd_update, tf_maximum, tf_minimum, tf_cast, tf_expand_dims, tf_repeat,\
tf_equal, tf_where, tf_round, tf_sign, tf_abs, tf_math_mod, tf_bitwise_left_shift, tf_bitwise_right_shift,\
@@ -71,33 +72,71 @@ def convert_elementwise_add(node, params, layers, lambda_func, node_name, keras_name):
logger = logging.getLogger('onnx2keras.add')

if len(node.input) != 2:
raise AttributeError('Number of inputs is not equal 2 for element-wise layer')
raise AttributeError('Number of inputs is not equal to 2 for element-wise layer')

input_0 = layers[node.input[0]]
input_1 = layers[node.input[1]]

input_0_is_non_keras = is_numpy(input_0) or isinstance(input_0, EagerTensor)
input_1_is_non_keras = is_numpy(input_1) or isinstance(input_1, EagerTensor)
input_0_is_constant = is_numpy(input_0) or isinstance(input_0, EagerTensor)
input_1_is_constant = is_numpy(input_1) or isinstance(input_1, EagerTensor)

try:
if not input_0_is_non_keras and not input_1_is_non_keras:
to_add = input_1
# We probably need to separate two possibilities here. Currently we only handle the second option:
# [Batch] + [Batch,1] -> [Batch,1]
# [Not-Batch] + [Not-Batch,1] -> [Not-Batch, Not-Batch]
if not input_0_is_constant and not input_1_is_constant:
# Both inputs are variables
if len(input_0.shape) != len(input_1.shape):
layers[node_name] = tf_add(input_0, to_add, tf_name=f"{params['cleaned_name']}_add")
# Use TensorFlow add to handle shape differences
layers[node_name] = tf_add(input_0, input_1, tf_name=f"{params['cleaned_name']}_add")
else:
layers[node_name] = keras.layers.Add(name=f"{params['cleaned_name']}_add")([input_0, to_add])
# Use Keras Add layer
layers[node_name] = keras.layers.Add(name=f"{params['cleaned_name']}_add")([input_0, input_1])
else:
raise ValueError('Operands are different.')
except (IndexError, ValueError):
logger.warning('Failed to use keras.layers.Add. Fallback to TF lambda.')
if input_0_is_non_keras:
layers[node_name] = input_1 + input_0
logger.warning('Failed to use keras.layers.Add. Fallback to Lambda layer.')

if input_0_is_constant and not input_1_is_constant:
# input_0 is constant, input_1 is variable
constant_value = np.asarray(tf.cast(input_0, dtype=input_1.dtype))
variable_input = input_1

Collaborator: Can we end the if condition here and then take the next part and combine it to avoid code duplication?
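One way to realize this suggestion, sketched below (a minimal illustration, not the repository's code; the `add_with_constant` helper is hypothetical): normalize the operands so the constant always sits on one side, then emit a single Lambda branch. Since addition is commutative, both orderings can share the helper; `convert_elementwise_sub` would still need its two ordered variants.

```python
import numpy as np
import tensorflow as tf
import keras

def add_with_constant(variable_input, constant, name):
    """Shared fallback branch: variable + constant (order-agnostic for add/mul)."""
    constant_value = np.asarray(tf.cast(constant, dtype=variable_input.dtype))
    if np.all(constant_value == constant_value.flat[0]):
        # Uniform constant: close over a plain scalar, so the Lambda does not
        # embed the whole tensor in the serialized model.
        const_val = constant_value.flat[0]
        return keras.layers.Lambda(lambda x: x + const_val, name=name)(variable_input)
    # Non-uniform constant: embedding the tensor in the closure is unavoidable.
    return keras.layers.Lambda(lambda x: x + constant_value, name=name)(variable_input)
```

The two mirrored branches then collapse to `add_with_constant(input_1, input_0, keras_name)` and `add_with_constant(input_0, input_1, keras_name)`.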

if np.all(constant_value == constant_value.flat[0]):
# Constant tensor has the same value throughout
const_val = constant_value.flat[0]
layers[node_name] = keras.layers.Lambda(
lambda x: x + const_val,
name=keras_name
)(variable_input)
else:
# Embedding the constant tensor
layers[node_name] = keras.layers.Lambda(
lambda x: x + constant_value,
name=keras_name
)(variable_input)

elif not input_0_is_constant and input_1_is_constant:
# input_0 is variable, input_1 is constant
constant_value = np.asarray(tf.cast(input_1, dtype=input_0.dtype))
variable_input = input_0

if np.all(constant_value == constant_value.flat[0]):
# Constant tensor has the same value throughout
const_val = constant_value.flat[0]
layers[node_name] = keras.layers.Lambda(
lambda x: x + const_val,
name=keras_name
)(variable_input)
else:
# Embedding the constant tensor
layers[node_name] = keras.layers.Lambda(
lambda x: x + constant_value,
name=keras_name
)(variable_input)
else:
# Both inputs are constants; compute the result now
layers[node_name] = input_0 + input_1



def convert_elementwise_mul(node, params, layers, lambda_func, node_name, keras_name):
"""
Convert element-wise mul.
@@ -112,13 +151,14 @@ def convert_elementwise_mul(node, params, layers, lambda_func, node_name, keras_name):
logger = logging.getLogger('onnx2keras.mul')

if len(node.input) != 2:
raise AttributeError('Number of inputs is not equal 2 for element-wise layer')
raise AttributeError('Number of inputs is not equal to 2 for element-wise layer')

input_0 = layers[node.input[0]]
input_1 = layers[node.input[1]]

input_0_is_constant = is_numpy(input_0) or isinstance(input_0, EagerTensor)
input_1_is_constant = is_numpy(input_1) or isinstance(input_1, EagerTensor)

try:
if not input_0_is_constant and not input_1_is_constant:
mul = keras.layers.Multiply(name=f"{params['cleaned_name']}_mul")
@@ -127,8 +167,48 @@ def convert_elementwise_mul(node, params, layers, lambda_func, node_name, keras_name):
raise ValueError('Operands are different.')

except (IndexError, ValueError):
logger.warning('Failed to use keras.layers.Multiply. Fallback to TF lambda.')
layers[node_name] = input_0 * input_1
logger.warning('Failed to use keras.layers.Multiply. Fallback to Lambda layer.')

if input_0_is_constant and not input_1_is_constant:
# input_0 is constant, input_1 is variable
constant_value = np.asarray(tf.cast(input_0, dtype=input_1.dtype))
variable_input = input_1

Collaborator: Can we end the if here?

if np.all(constant_value == constant_value.flat[0]):
# Constant tensor has the same value throughout
const_val = constant_value.flat[0]
layers[node_name] = keras.layers.Lambda(
lambda x: x * const_val,
name=keras_name
)(variable_input)
else:
# Cannot avoid embedding the constant tensor
layers[node_name] = keras.layers.Lambda(
lambda x: x * constant_value,
name=keras_name
)(variable_input)

elif not input_0_is_constant and input_1_is_constant:
# input_0 is variable, input_1 is constant
constant_value = np.asarray(tf.cast(input_1, dtype=input_0.dtype))
variable_input = input_0

if np.all(constant_value == constant_value.flat[0]):
# Constant tensor has the same value throughout
const_val = constant_value.flat[0]
layers[node_name] = keras.layers.Lambda(
lambda x: x * const_val,
name=keras_name
)(variable_input)
else:
# Cannot avoid embedding the constant tensor
layers[node_name] = keras.layers.Lambda(
lambda x: x * constant_value,
name=keras_name
)(variable_input)
else:
# Both inputs are constants; compute the result now
layers[node_name] = input_0 * input_1


def convert_elementwise_sub(node, params, layers, lambda_func, node_name, keras_name):
@@ -149,24 +229,63 @@ def convert_elementwise_sub(node, params, layers, lambda_func, node_name, keras_name):

input_0 = layers[node.input[0]]
input_1 = layers[node.input[1]]
input_0_is_np = is_numpy(input_0) or isinstance(input_0, EagerTensor)
input_1_is_np = is_numpy(input_1) or isinstance(input_1, EagerTensor)

input_0_is_constant = is_numpy(input_0) or isinstance(input_0, EagerTensor)
input_1_is_constant = is_numpy(input_1) or isinstance(input_1, EagerTensor)

try:
if not input_0_is_np and not input_1_is_np:
if not input_0_is_constant and not input_1_is_constant:
sub = keras.layers.Subtract(name=f"{params['cleaned_name']}_sub")
layers[node_name] = sub([input_0, input_1])
else:
raise ValueError('Operands are different.')

except (IndexError, ValueError):
logger.warning('Failed to use keras.layers.Subtract. Fallback to TF lambda.')
if input_0_is_np and not input_1_is_np: # constant - tensor does not parse well
layers[node_name] = - (input_1 - input_0)
logger.warning('Failed to use keras.layers.Subtract. Fallback to Lambda layer.')

if input_0_is_constant and not input_1_is_constant:
# input_0 is constant, input_1 is variable: constant - variable
constant_value = np.asarray(tf.cast(input_0, dtype=input_1.dtype))
variable_input = input_1

if np.all(constant_value == constant_value.flat[0]):
# Constant tensor has the same value throughout
const_val = constant_value.flat[0]
layers[node_name] = keras.layers.Lambda(
lambda x: const_val - x,
name=keras_name
)(variable_input)
else:
# Cannot avoid embedding the constant tensor
layers[node_name] = keras.layers.Lambda(
lambda x: constant_value - x,
name=keras_name
)(variable_input)

elif not input_0_is_constant and input_1_is_constant:
# input_0 is variable, input_1 is constant: variable - constant
constant_value = np.asarray(tf.cast(input_1, dtype=input_0.dtype))
variable_input = input_0

if np.all(constant_value == constant_value.flat[0]):
# Constant tensor has the same value throughout
const_val = constant_value.flat[0]
layers[node_name] = keras.layers.Lambda(
lambda x: x - const_val,
name=keras_name
)(variable_input)
else:
# Cannot avoid embedding the constant tensor
layers[node_name] = keras.layers.Lambda(
lambda x: x - constant_value,
name=keras_name
)(variable_input)
else:
# Both inputs are constants; compute the result now
layers[node_name] = input_0 - input_1



def convert_min(node, params, layers, lambda_func, node_name, keras_name):
"""
Convert Min layer
26 changes: 15 additions & 11 deletions onnx2kerastl/padding_layers.py
@@ -80,18 +80,22 @@ def convert_padding(node, params, layers, lambda_func, node_name, keras_name):
)
layers[node_name] = padding_layer(input_0)
elif params['mode'] == 'reflect':
if pads.shape[0] == 6:
result = tf_pad(input_0, [[pads[0], pads[3]], [pads[1], pads[4]], [pads[2], pads[5]]], mode='REFLECT',
tf_name=f"{params['cleaned_name']}_reflect_pad")
layers[node_name] = result
else:
def target_layer(x, pads=pads):
if pads.shape[0] == 8:
layer = tf.pad(x, [[0, 0], [0, 0], [pads[2], pads[6]], [pads[3], pads[7]]], 'REFLECT')
else:
logger.warning("Caution - no test yet")
layer = tf.pad(x, [[0, 0], [0, 0], [pads[2], pads[7]], [pads[3], pads[8]], [pads[4], pads[9]]], 'REFLECT')
return layer

def target_layer(x, pads=pads):
if pads.shape[0] == 8:
layer = tf.pad(x, [[0, 0], [0, 0], [pads[2], pads[6]], [pads[3], pads[7]]], 'REFLECT')
else:
logger.warning("Caution - no test yet")
layer = tf.pad(x, [[0, 0], [0, 0], [pads[2], pads[7]], [pads[3], pads[8]], [pads[4], pads[9]]], 'REFLECT')
return layer

lambda_layer = keras.layers.Lambda(target_layer, name=f"{params['cleaned_name']}_pad_reflect")
layers[node_name] = lambda_layer(input_0)
lambda_func[keras_name] = target_layer
lambda_layer = keras.layers.Lambda(target_layer, name=f"{params['cleaned_name']}_pad_reflect")
layers[node_name] = lambda_layer(input_0)
lambda_func[keras_name] = target_layer
elif params['mode'] == 'edge':

def target_layer(x, pads=pads):
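For reference, ONNX `Pad` packs its pads vector as all begins followed by all ends (`[x1_begin, x2_begin, ..., x1_end, x2_end, ...]`), while `tf.pad` expects per-axis `[begin, end]` pairs; the hardcoded pairs above, e.g. `[pads[2], pads[6]]` and `[pads[3], pads[7]]` for 4D NCHW inputs, implement that reshuffle for the spatial axes. A general sketch of the mapping (the `onnx_pads_to_tf` helper is illustrative, not part of this codebase):

```python
import numpy as np
import tensorflow as tf

def onnx_pads_to_tf(pads):
    """Convert an ONNX pads vector [b1, b2, ..., e1, e2, ...] into the
    [[b1, e1], [b2, e2], ...] layout that tf.pad expects."""
    pads = np.asarray(pads)
    half = pads.shape[0] // 2
    return [[int(pads[i]), int(pads[i + half])] for i in range(half)]

# 4D NCHW example with spatial-only reflect padding, matching the
# [[0, 0], [0, 0], [pads[2], pads[6]], [pads[3], pads[7]]] case above.
pads = [0, 0, 1, 3, 0, 0, 2, 4]   # begins: N,C,H,W then ends: N,C,H,W
x = tf.ones((1, 3, 8, 8))
y = tf.pad(x, onnx_pads_to_tf(pads), mode='REFLECT')
print(y.shape)                    # (1, 3, 11, 15)
```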
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "onnx2kerastl"
version = "0.0.149"
version = "0.0.152"
description = ""
authors = ["dorhar <doron.harnoy@tensorleap.ai>"]
license = "MIT"