Skip to content

Commit

Permalink
Merge pull request #432 from sony/feature/20190425-Enable-CKPT-format…
Browse files Browse the repository at this point in the history
…-for-tf-importer

Enable CKPT format for tf importer
  • Loading branch information
YukioOobuchi committed May 15, 2019
2 parents a66f447 + 5c8f92c commit b4dee59
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 41 deletions.
11 changes: 10 additions & 1 deletion doc/python/command_line_interface.rst
Expand Up @@ -324,17 +324,26 @@ File format converter
[--nnp-parameter-h5] [--nnp-parameter-nntxt]
[--nnp-exclude-parameter] [-T DEFAULT_VARIABLE_TYPE]
[-s SETTINGS] [-c CONFIG] [-d DEFINE_VERSION]
FILE [FILE ...]
[--outputs OUTPUTS] [--inputs INPUTS] FILE [FILE ...]
positional arguments:
FILE File or directory name(s) to convert.
(When convert ckpt format of the tensorflow model,
If the version of the checkpoint is V1, need to enter the `.ckpt` file,
otherwise need to enter the `.meta` file.)
optional arguments:
-h, --help show this help message and exit
-I IMPORT_FORMAT, --import-format IMPORT_FORMAT
[import] import format. (one of [NNP,ONNX])
--nnp-no-expand-network
[import][NNP] expand network with repeat or recurrent.
--outputs OUTPUTS
[import][tensorflow] The name(s) of the output nodes, comma separated.
Only needed when convert CKPT format.
--inputs INPUTS
[import][tensorflow] The name(s) of the input nodes, comma separated.
Only needed when convert CKPT format.
-O EXPORT_FORMAT, --export-format EXPORT_FORMAT
[export] export format. (one of [NNP,NNB,CSRC,ONNX])
-f, --force [export] overwrite output file.
Expand Down
11 changes: 10 additions & 1 deletion python/src/nnabla/utils/cli/convert.py
Expand Up @@ -115,8 +115,17 @@ def add_import_arg(parser):
# Converter
subparser = subparsers.add_parser('convert', help='File format converter.')
subparser.add_argument('files', metavar='FILE', type=str, nargs='+',
help='File or directory name(s) to convert.')
help='File or directory name(s) to convert. \
(When convert ckpt format of the tensorflow model, \
If the version of the checkpoint is V1, need to enter the `.ckpt` file, \
otherwise need to enter the `.meta` file.)')
# import option
subparser.add_argument('--outputs', type=str, default=None,
help='[import][tensorflow] The name(s) of the output nodes, comma separated. \
Only needed when convert CKPT format.')
subparser.add_argument('--inputs', type=str, default=None,
help='[import][tensorflow] The name(s) of the input nodes, comma separated. \
Only needed when convert CKPT format.')
add_import_arg(subparser)

# export option
Expand Down
9 changes: 6 additions & 3 deletions python/src/nnabla/utils/converter/commands.py
Expand Up @@ -35,7 +35,9 @@ def _import_file(args, ifiles):
elif ext == '.pb':
args.import_format = "TF_PB"
elif ext == '.ckpt':
args.import_format = "TF_CKPT"
args.import_format = "TF_CKPT_V1"
elif ext == '.meta':
args.import_format = "TF_CKPT_V2"

if args.import_format == 'NNP':
# Input file that has unsupported extension store into output nnp
Expand All @@ -49,9 +51,10 @@ def _import_file(args, ifiles):
return OnnxImporter(*ifiles).execute()

elif args.import_format == 'TF_PB' or \
args.import_format == 'TF_CKPT':
args.import_format == 'TF_CKPT_V1' or \
args.import_format == "TF_CKPT_V2":
from .tensorflow import TensorflowImporter
return TensorflowImporter(*ifiles, tf_format=args.import_format).execute()
return TensorflowImporter(*ifiles, tf_format=args.import_format, outputs=args.outputs, inputs=args.inputs).execute()
return None


Expand Down
80 changes: 44 additions & 36 deletions python/src/nnabla/utils/converter/tensorflow/importer.py
Expand Up @@ -16,10 +16,11 @@
from ..onnx import OnnxImporter
import tensorflow as tf
import tf2onnx
import collections
from tf2onnx import constants, loader
from tf2onnx.graph import GraphUtil
from tensorflow.core.framework import graph_pb2
from tensorflow.python.tools import freeze_graph
# import pdb


def _strip_node_name(name):
Expand All @@ -33,19 +34,23 @@ def _find_out_terminal_node(graph_def, **kwargs):
def add_postfix(names):
return ["{}:0".format(n) for n in names]

unlike_output_types = ["Const", "Assign", "Noop", "Placeholder"]
unlike_output_types = ["Const", "Assign", "NoOp", "Placeholder"]
terminal_inputs = []
inputs = set()
outputs = set()
terminal_outputs = []
input_cnt = collections.Counter()
need_add_postfix = kwargs.get("postfix", False)
for node in graph_def.node:
strip_name = _strip_node_name(node.name)
for input in node.input:
input = _strip_node_name(input)
input_cnt[input] += 1
if node.op == 'Placeholder':
terminal_inputs.append(_strip_node_name(node.name))
outputs.add(strip_name)
inputs.update(set(node.input))
terminal_outputs = list(filter(lambda x: x not in unlike_output_types,
outputs - inputs))
strip_name = _strip_node_name(node.name)
terminal_inputs.append(strip_name)

for node in graph_def.node:
if input_cnt[node.name] == 0 and node.op not in unlike_output_types:
terminal_outputs.append(node.name)

if need_add_postfix:
terminal_inputs = add_postfix(terminal_inputs)
terminal_outputs = add_postfix(terminal_outputs)
Expand All @@ -59,12 +64,11 @@ class TensorflowImporter:

def __init__(self, *args, **kwargs):
self._tf_file = args[0]
self._tf_format = kwargs.get("tf_format", "TF_PB")
self._tf_format = kwargs.get("tf_format")
self._outputs = kwargs.get("outputs")
self._inputs = kwargs.get("inputs")

def _import_from_tf_pb(self, graph_def):
inputs, outputs = _find_out_terminal_node(graph_def, postfix=True)
print("inputs:{}".format(inputs))
print("outputs:{}".format(outputs))
def convert_to_onnx(self, graph_def, inputs, outputs):

# FIXME: folding const = False
graph_def = tf2onnx.tfonnx.tf_optimize(
Expand All @@ -76,30 +80,20 @@ def _import_from_tf_pb(self, graph_def):
continue_on_error=False,
verbose=False,
target=",".join(
tf2onnx.tfonnx.DEFAULT_TARGET),
constants.DEFAULT_TARGET),
opset=6,
input_names=inputs,
output_names=outputs,
inputs_as_nchw=None)
model_proto = onnx_graph.make_model("tf_model")
new_model_proto = GraphUtil.opt_transposes_with_graph(onnx_graph,
'tf_model',
optimize=True)
model_proto = onnx_graph.make_model(
"converted from {}".format(self._tf_file))
new_model_proto = GraphUtil.optimize_model_proto(model_proto)
if new_model_proto:
model_proto = new_model_proto
return model_proto

def import_from_tf_pb(self):
graph_def = graph_pb2.GraphDef()
with tf.gfile.GFile(self._tf_file, 'rb') as f:
graph_def.ParseFromString(f.read())
return self._import_from_tf_pb(graph_def)

def import_from_tf_ckpt(self):
def load_checkpoint_v1(self):
ckpt_path = os.path.dirname(self._tf_file)
if not ckpt_path:
raise ValueError(
"check point file should be in a special directory.")
latest_ckpt = tf.train.latest_checkpoint(ckpt_path)
saver = tf.train.import_meta_graph(latest_ckpt + ".meta")
with tf.Session() as session:
Expand All @@ -115,21 +109,35 @@ def import_from_tf_ckpt(self):
input_graph_def=graph_def,
input_saver_def=None,
input_checkpoint=latest_ckpt,
output_node_names="biases",
output_node_names=self._outputs,
restore_op_name="",
filename_tensor_name="",
output_graph=None,
clear_devices=True,
initializer_nodes=""
)
onnx_model = self._import_from_tf_pb(frozen_graph)
return onnx_model
return frozen_graph

def execute(self):
if self._tf_format == 'TF_PB':
onnx_model = self.import_from_tf_pb()
elif self._tf_format == 'TF_CKPT':
onnx_model = self.import_from_tf_ckpt()
graph_def = graph_pb2.GraphDef()
with tf.gfile.GFile(self._tf_file, 'rb') as f:
graph_def.ParseFromString(f.read())
inputs, outputs = _find_out_terminal_node(graph_def, postfix=True)
else:
if self._outputs is None:
raise ImportError("Missing '--outputs' parameter.")
if self._inputs is None:
raise ImportError("Missing '--inputs' parameter.")

inputs = [i + ":0" for i in self._inputs.split(",")]
outputs = [i + ":0" for i in self._outputs.split(",")]
if self._tf_format == 'TF_CKPT_V1':
graph_def = self.load_checkpoint_v1()
elif self._tf_format == 'TF_CKPT_V2':
graph_def, inputs, outputs = loader.from_checkpoint(
self._tf_file, inputs, outputs)
onnx_model = self.convert_to_onnx(graph_def, inputs, outputs)
onnx_importer = OnnxImporter()
onnx_importer.import_from_onnx_model(onnx_model)
return onnx_importer.execute()

0 comments on commit b4dee59

Please sign in to comment.