diff --git a/onnx2keras/converter.py b/onnx2keras/converter.py index 616295cf..b057e2ba 100644 --- a/onnx2keras/converter.py +++ b/onnx2keras/converter.py @@ -17,6 +17,7 @@ def onnx_node_attributes_to_dict(args): :param args: ONNX attributes object :return: Python dictionary """ + def onnx_attribute_to_dict(onnx_attr): """ Parse ONNX attribute @@ -33,6 +34,7 @@ def onnx_attribute_to_dict(onnx_attr): for attr_type in ['floats', 'ints', 'strings']: if getattr(onnx_attr, attr_type): return list(getattr(onnx_attr, attr_type)) + return {arg.name: onnx_attribute_to_dict(arg) for arg in args} @@ -89,8 +91,8 @@ def onnx_to_keras(onnx_model, input_names, weights[onnx_extracted_weights_name] = numpy_helper.to_array(onnx_w) logger.debug('Found weight {0} with shape {1}.'.format( - onnx_extracted_weights_name, - weights[onnx_extracted_weights_name].shape)) + onnx_extracted_weights_name, + weights[onnx_extracted_weights_name].shape)) layers = dict() lambda_funcs = dict() @@ -137,6 +139,7 @@ def onnx_to_keras(onnx_model, input_names, postfix = node_index if len(node.output) == 1 else "%s_%s" % (node_index, output_index) keras_names.append('LAYER_%s' % postfix) else: + output = output.replace(":", "_") keras_names.append(output) if len(node.output) != 1: @@ -225,9 +228,9 @@ def onnx_to_keras(onnx_model, input_names, if len(list(layer['config']['target_shape'][1:][:])) > 0: layer['config']['target_shape'] = \ tuple(np.reshape(np.array( - list(layer['config']['target_shape'][1:]) + - [layer['config']['target_shape'][0]] - ), -1),) + list(layer['config']['target_shape'][1:]) + + [layer['config']['target_shape'][0]] + ), -1), ) if layer['config'] and 'data_format' in layer['config']: layer['config']['data_format'] = 'channels_last' diff --git a/onnx2keras/main.py b/onnx2keras/main.py new file mode 100644 index 00000000..1ca66409 --- /dev/null +++ b/onnx2keras/main.py @@ -0,0 +1,16 @@ +import onnx +from onnx2keras import onnx_to_keras + + +def main(): + onnx_file_path = "onnx_models/model.onnx" + + # Load ONNX model + onnx_model = onnx.load(onnx_file_path) + + # Call the converter (input - is the main model input name, can be different for your model) + k_model = onnx_to_keras(onnx_model, ['input_1']) + + +if __name__ == '__main__': + main()