From 5c8f92c5d865bb8d8a90740506e4a7ef9f62398c Mon Sep 17 00:00:00 2001
From: "Yuchi.Wen"
Date: Thu, 25 Apr 2019 13:11:04 +0800
Subject: [PATCH] Enable CKPT format for tf importer

---
 doc/python/command_line_interface.rst         | 11 ++-
 python/src/nnabla/utils/cli/convert.py        | 11 ++-
 python/src/nnabla/utils/converter/commands.py |  9 ++-
 .../utils/converter/tensorflow/importer.py    | 80 ++++++++++---------
 4 files changed, 70 insertions(+), 41 deletions(-)

diff --git a/doc/python/command_line_interface.rst b/doc/python/command_line_interface.rst
index 139326035..37efc844c 100644
--- a/doc/python/command_line_interface.rst
+++ b/doc/python/command_line_interface.rst
@@ -324,10 +324,13 @@ File format converter
                            [--nnp-parameter-h5] [--nnp-parameter-nntxt]
                            [--nnp-exclude-parameter] [-T DEFAULT_VARIABLE_TYPE]
                            [-s SETTINGS] [-c CONFIG] [-d DEFINE_VERSION]
-                           FILE [FILE ...]
+                           [--outputs OUTPUTS] [--inputs INPUTS] FILE [FILE ...]
 
     positional arguments:
       FILE                  File or directory name(s) to convert.
+                            (When converting a TensorFlow checkpoint,
+                            pass the `.ckpt` file if the checkpoint is V1,
+                            otherwise pass the `.meta` file.)
 
     optional arguments:
       -h, --help            show this help message and exit
@@ -335,6 +338,12 @@ File format converter
                             [import] import format. (one of [NNP,ONNX])
       --nnp-no-expand-network
                             [import][NNP] expand network with repeat or recurrent.
+      --outputs OUTPUTS
+                            [import][tensorflow] The name(s) of the output nodes, comma separated.
+                            Only needed when converting from the CKPT format.
+      --inputs INPUTS
+                            [import][tensorflow] The name(s) of the input nodes, comma separated.
+                            Only needed when converting from the CKPT format.
       -O EXPORT_FORMAT, --export-format EXPORT_FORMAT
                             [export] export format. (one of [NNP,NNB,CSRC,ONNX])
       -f, --force           [export] overwrite output file.
diff --git a/python/src/nnabla/utils/cli/convert.py b/python/src/nnabla/utils/cli/convert.py
index 455b258bf..ecf355c12 100644
--- a/python/src/nnabla/utils/cli/convert.py
+++ b/python/src/nnabla/utils/cli/convert.py
@@ -115,8 +115,17 @@ def add_import_arg(parser):
     # Converter
     subparser = subparsers.add_parser('convert', help='File format converter.')
     subparser.add_argument('files', metavar='FILE', type=str, nargs='+',
-                           help='File or directory name(s) to convert.')
+                           help='File or directory name(s) to convert. \
+                           (When converting a TensorFlow checkpoint, \
+                           pass the `.ckpt` file if the checkpoint is V1, \
+                           otherwise pass the `.meta` file.)')
     # import option
+    subparser.add_argument('--outputs', type=str, default=None,
+                           help='[import][tensorflow] The name(s) of the output nodes, comma separated. \
+                           Only needed when converting from the CKPT format.')
+    subparser.add_argument('--inputs', type=str, default=None,
+                           help='[import][tensorflow] The name(s) of the input nodes, comma separated. \
+                           Only needed when converting from the CKPT format.')
     add_import_arg(subparser)
 
     # export option
diff --git a/python/src/nnabla/utils/converter/commands.py b/python/src/nnabla/utils/converter/commands.py
index 1f07f1947..216c012f2 100644
--- a/python/src/nnabla/utils/converter/commands.py
+++ b/python/src/nnabla/utils/converter/commands.py
@@ -35,7 +35,9 @@ def _import_file(args, ifiles):
     elif ext == '.pb':
         args.import_format = "TF_PB"
     elif ext == '.ckpt':
-        args.import_format = "TF_CKPT"
+        args.import_format = "TF_CKPT_V1"
+    elif ext == '.meta':
+        args.import_format = "TF_CKPT_V2"
 
     if args.import_format == 'NNP':
         # Input file that has unsupported extension store into output nnp
@@ -49,9 +51,10 @@ def _import_file(args, ifiles):
             return OnnxImporter(*ifiles).execute()
 
     elif args.import_format == 'TF_PB' or \
-            args.import_format == 'TF_CKPT':
+            args.import_format == 'TF_CKPT_V1' or \
+            args.import_format == "TF_CKPT_V2":
         from .tensorflow import TensorflowImporter
-        return TensorflowImporter(*ifiles, tf_format=args.import_format).execute()
+        return TensorflowImporter(*ifiles, tf_format=args.import_format, outputs=args.outputs, inputs=args.inputs).execute()
     return None
 
 
diff --git a/python/src/nnabla/utils/converter/tensorflow/importer.py b/python/src/nnabla/utils/converter/tensorflow/importer.py
index e779b7678..483456d41 100644
--- a/python/src/nnabla/utils/converter/tensorflow/importer.py
+++ b/python/src/nnabla/utils/converter/tensorflow/importer.py
@@ -16,10 +16,11 @@
 from ..onnx import OnnxImporter
 import tensorflow as tf
 import tf2onnx
+import collections
+from tf2onnx import constants, loader
 from tf2onnx.graph import GraphUtil
 from tensorflow.core.framework import graph_pb2
 from tensorflow.python.tools import freeze_graph
-# import pdb
 
 
 def _strip_node_name(name):
@@ -33,19 +34,23 @@ def _find_out_terminal_node(graph_def, **kwargs):
     def add_postfix(names):
         return ["{}:0".format(n) for n in names]
 
-    unlike_output_types = ["Const", "Assign", "Noop", "Placeholder"]
+    unlike_output_types = ["Const", "Assign", "NoOp", "Placeholder"]
     terminal_inputs = []
-    inputs = set()
-    outputs = set()
+    terminal_outputs = []
+    input_cnt = collections.Counter()
     need_add_postfix = kwargs.get("postfix", False)
     for node in graph_def.node:
-        strip_name = _strip_node_name(node.name)
+        for input in node.input:
+            input = _strip_node_name(input)
+            input_cnt[input] += 1
         if node.op == 'Placeholder':
-            terminal_inputs.append(_strip_node_name(node.name))
-        outputs.add(strip_name)
-        inputs.update(set(node.input))
-    terminal_outputs = list(filter(lambda x: x not in unlike_output_types,
-                                   outputs - inputs))
+            strip_name = _strip_node_name(node.name)
+            terminal_inputs.append(strip_name)
+
+    for node in graph_def.node:
+        if input_cnt[node.name] == 0 and node.op not in unlike_output_types:
+            terminal_outputs.append(node.name)
+
     if need_add_postfix:
         terminal_inputs = add_postfix(terminal_inputs)
         terminal_outputs = add_postfix(terminal_outputs)
@@ -59,12 +64,11 @@ class TensorflowImporter:
 
     def __init__(self, *args, **kwargs):
         self._tf_file = args[0]
-        self._tf_format = kwargs.get("tf_format", "TF_PB")
+        self._tf_format = kwargs.get("tf_format")
+        self._outputs = kwargs.get("outputs")
+        self._inputs = kwargs.get("inputs")
 
-    def _import_from_tf_pb(self, graph_def):
-        inputs, outputs = _find_out_terminal_node(graph_def, postfix=True)
-        print("inputs:{}".format(inputs))
-        print("outputs:{}".format(outputs))
+    def convert_to_onnx(self, graph_def, inputs, outputs):
         # FIXME: folding const = False
         graph_def = tf2onnx.tfonnx.tf_optimize(
@@ -76,30 +80,20 @@ def _import_from_tf_pb(self, graph_def):
                                                      continue_on_error=False,
                                                      verbose=False,
                                                      target=",".join(
-                                                         tf2onnx.tfonnx.DEFAULT_TARGET),
+                                                         constants.DEFAULT_TARGET),
                                                      opset=6,
                                                      input_names=inputs,
                                                      output_names=outputs,
                                                      inputs_as_nchw=None)
-        model_proto = onnx_graph.make_model("tf_model")
-        new_model_proto = GraphUtil.opt_transposes_with_graph(onnx_graph,
-                                                              'tf_model',
-                                                              optimize=True)
+        model_proto = onnx_graph.make_model(
+            "converted from {}".format(self._tf_file))
+        new_model_proto = GraphUtil.optimize_model_proto(model_proto)
         if new_model_proto:
             model_proto = new_model_proto
         return model_proto
 
-    def import_from_tf_pb(self):
-        graph_def = graph_pb2.GraphDef()
-        with tf.gfile.GFile(self._tf_file, 'rb') as f:
-            graph_def.ParseFromString(f.read())
-        return self._import_from_tf_pb(graph_def)
-
-    def import_from_tf_ckpt(self):
+    def load_checkpoint_v1(self):
         ckpt_path = os.path.dirname(self._tf_file)
-        if not ckpt_path:
-            raise ValueError(
-                "check point file should be in a special directory.")
         latest_ckpt = tf.train.latest_checkpoint(ckpt_path)
         saver = tf.train.import_meta_graph(latest_ckpt + ".meta")
         with tf.Session() as session:
@@ -115,21 +109,35 @@ def import_from_tf_ckpt(self):
             input_graph_def=graph_def,
             input_saver_def=None,
             input_checkpoint=latest_ckpt,
-            output_node_names="biases",
+            output_node_names=self._outputs,
             restore_op_name="",
             filename_tensor_name="",
             output_graph=None,
             clear_devices=True,
             initializer_nodes=""
         )
-        onnx_model = self._import_from_tf_pb(frozen_graph)
-        return onnx_model
+        return frozen_graph
 
     def execute(self):
         if self._tf_format == 'TF_PB':
-            onnx_model = self.import_from_tf_pb()
-        elif self._tf_format == 'TF_CKPT':
-            onnx_model = self.import_from_tf_ckpt()
+            graph_def = graph_pb2.GraphDef()
+            with tf.gfile.GFile(self._tf_file, 'rb') as f:
+                graph_def.ParseFromString(f.read())
+            inputs, outputs = _find_out_terminal_node(graph_def, postfix=True)
+        else:
+            if self._outputs is None:
+                raise ImportError("Missing '--outputs' parameter.")
+            if self._inputs is None:
+                raise ImportError("Missing '--inputs' parameter.")
+
+            inputs = [i + ":0" for i in self._inputs.split(",")]
+            outputs = [i + ":0" for i in self._outputs.split(",")]
+            if self._tf_format == 'TF_CKPT_V1':
+                graph_def = self.load_checkpoint_v1()
+            elif self._tf_format == 'TF_CKPT_V2':
+                graph_def, inputs, outputs = loader.from_checkpoint(
+                    self._tf_file, inputs, outputs)
+        onnx_model = self.convert_to_onnx(graph_def, inputs, outputs)
         onnx_importer = OnnxImporter()
         onnx_importer.import_from_onnx_model(onnx_model)
         return onnx_importer.execute()
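
A minimal usage sketch of the checkpoint path added by this patch (not part of the patch itself). The node names `x` and `y`, the file names, and the `nnabla_cli` command line are placeholders; only `TensorflowImporter`, its `tf_format`/`inputs`/`outputs` keywords, and the `TF_CKPT_V2` value come from the changes above:

    # Hypothetical CLI call: a V2 checkpoint is selected by passing the .meta file,
    # so --inputs/--outputs must name the graph's input and output nodes:
    #   nnabla_cli convert --inputs x --outputs y model.meta output.nnp

    from nnabla.utils.converter.tensorflow import TensorflowImporter

    importer = TensorflowImporter("model.meta",            # .meta file => checkpoint V2
                                  tf_format="TF_CKPT_V2",  # set from the extension in commands.py
                                  inputs="x",              # comma-separated input node names
                                  outputs="y")             # comma-separated output node names
    nnp = importer.execute()                               # TF graph -> ONNX -> NNP importer result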