Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion onnx2kerastl/elementwise_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,8 +381,11 @@ def convert_where(node, params, layers, lambda_func, node_name, keras_name):
layers[node.input[1]],
tf_name=f"{params['cleaned_name']}_where_1")
else:
layers[node_name] = tf_where(casted, layers[node.input[1]], layers[node.input[2]],
try:
layers[node_name] = tf_where(casted, layers[node.input[1]], layers[node.input[2]],
tf_name=f"{params['cleaned_name']}_where_2")
except Exception as e:
print(1)


def convert_scatter_nd(node, params, layers, lambda_func, node_name, keras_name):
Expand Down
3 changes: 2 additions & 1 deletion onnx2kerastl/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
convert_tile, convert_gather_elements
from .constant_layers import convert_constant, convert_constant_of_shape, convert_one_hot
from .normalization_layers import convert_batchnorm, convert_instancenorm, convert_dropout, convert_lrn, convert_layernorm
from .pooling_layers import convert_avgpool, convert_maxpool, convert_global_avg_pool, convert_topk, convert_roi_align
from .pooling_layers import convert_avgpool, convert_global_max_pool, convert_maxpool, convert_global_avg_pool, convert_topk, convert_roi_align
from .padding_layers import convert_padding
from .upsampling_layers import convert_upsample
from .caffe2_layers import convert_alias_with_name, convert_resize_nearest
Expand Down Expand Up @@ -81,6 +81,7 @@
'Dropout': convert_dropout,
'LRN': convert_lrn,
'MaxPool': convert_maxpool,
'GlobalMaxPool': convert_global_max_pool,
'AveragePool': convert_avgpool,
'GlobalAveragePool': convert_global_avg_pool,
'Shape': convert_shape,
Expand Down
33 changes: 33 additions & 0 deletions onnx2kerastl/pooling_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,39 @@ def convert_maxpool(node, params, layers, lambda_func, node_name, keras_name):
input_0 = layers[node_name + "_pre_" + rand_string]
layers[node_name] = pooling(input_0)

def convert_global_max_pool(node, params, layers, lambda_func, node_name, keras_name):
"""
Convert GlobalMaxPool layer
:param node: current operation node
:param params: operation attributes
:param layers: available keras layers
:param lambda_func: function for keras Lambda layer
:param node_name: internal converter name
:param keras_name: resulting layer name
:return: None
"""
input_0 = ensure_tf_type(layers[node.input[0]], name="%s_const" % keras_name)
tensor_dim = len(input_0.shape)
if tensor_dim == 3:
global_pool = keras.layers.GlobalMaxPooling1D(data_format='channels_first',
name=f"{params['cleaned_name']}_global_max_pool_3")
elif tensor_dim == 4:
global_pool = keras.layers.GlobalMaxPooling2D(data_format='channels_first',
name=f"{params['cleaned_name']}_global_max_pool_4")
elif tensor_dim == 5:
global_pool = keras.layers.GlobalMaxPooling3D(data_format='channels_first',
name=f"{params['cleaned_name']}_global_max_pool_5")
else:
raise NotImplementedError("Global max pooling of dims < 3 or dims > 5 is not supported")
input_0 = global_pool(input_0)
new_shape = input_0.shape.as_list()
new_shape = new_shape[1:]
new_shape.extend([1] * (tensor_dim - 2))
reshape_layer = keras.layers.Reshape(new_shape, name=f"{params['cleaned_name']}_global_max_pool_reshape")
input_0 = reshape_layer(input_0)

layers[node_name] = input_0


def convert_avgpool(node, params, layers, lambda_func, node_name, keras_name):
"""
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "onnx2kerastl"
version = "0.0.157"
version = "0.0.158"
description = ""
authors = ["dorhar <doron.harnoy@tensorleap.ai>"]
license = "MIT"
Expand Down
Loading