Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions onnx2keras/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def onnx_node_attributes_to_dict(args):
:param args: ONNX attributes object
:return: Python dictionary
"""

def onnx_attribute_to_dict(onnx_attr):
"""
Parse ONNX attribute
Expand All @@ -33,6 +34,7 @@ def onnx_attribute_to_dict(onnx_attr):
for attr_type in ['floats', 'ints', 'strings']:
if getattr(onnx_attr, attr_type):
return list(getattr(onnx_attr, attr_type))

return {arg.name: onnx_attribute_to_dict(arg) for arg in args}


Expand Down Expand Up @@ -89,8 +91,8 @@ def onnx_to_keras(onnx_model, input_names,
weights[onnx_extracted_weights_name] = numpy_helper.to_array(onnx_w)

logger.debug('Found weight {0} with shape {1}.'.format(
onnx_extracted_weights_name,
weights[onnx_extracted_weights_name].shape))
onnx_extracted_weights_name,
weights[onnx_extracted_weights_name].shape))

layers = dict()
lambda_funcs = dict()
Expand Down Expand Up @@ -137,6 +139,7 @@ def onnx_to_keras(onnx_model, input_names,
postfix = node_index if len(node.output) == 1 else "%s_%s" % (node_index, output_index)
keras_names.append('LAYER_%s' % postfix)
else:
output = output.replace(":", "_")
keras_names.append(output)

if len(node.output) != 1:
Expand Down Expand Up @@ -225,9 +228,9 @@ def onnx_to_keras(onnx_model, input_names,
if len(list(layer['config']['target_shape'][1:][:])) > 0:
layer['config']['target_shape'] = \
tuple(np.reshape(np.array(
list(layer['config']['target_shape'][1:]) +
[layer['config']['target_shape'][0]]
), -1),)
list(layer['config']['target_shape'][1:]) +
[layer['config']['target_shape'][0]]
), -1), )

if layer['config'] and 'data_format' in layer['config']:
layer['config']['data_format'] = 'channels_last'
Expand Down
16 changes: 16 additions & 0 deletions onnx2keras/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import onnx
from onnx2keras import onnx_to_keras


def main():
onnx_file_path = "onnx_models/model.onnx"

# Load ONNX model
onnx_model = onnx.load(onnx_file_path)

# Call the converter (input - is the main model input name, can be different for your model)
k_model = onnx_to_keras(onnx_model, ['input_1'])


if __name__ == '__main__':
main()