From 12984ea3ec6ca723873b3bc4606b0600ba7bff4c Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Tue, 5 Nov 2019 17:18:32 -0800 Subject: [PATCH 1/2] remove the nodes that are skipped during batch norm folding for control node inputs --- .../python/tensorflowjs/converters/fold_batch_norms.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tfjs-converter/python/tensorflowjs/converters/fold_batch_norms.py b/tfjs-converter/python/tensorflowjs/converters/fold_batch_norms.py index 5f3b46759eb..ceff67bc1a8 100644 --- a/tfjs-converter/python/tensorflowjs/converters/fold_batch_norms.py +++ b/tfjs-converter/python/tensorflowjs/converters/fold_batch_norms.py @@ -227,6 +227,12 @@ def fold_batch_norms(input_graph_def): continue new_node = node_def_pb2.NodeDef() new_node.CopyFrom(node) + retained_input = [] + for input in new_node.input: + if not input.startswith('^') or input[1:] not in nodes_to_skip: + retained_input.append(input) + new_node.input[:] = retained_input + result_graph_def.node.extend([new_node]) result_graph_def.node.extend(new_ops) From 14f2db5e8d59ddc0199fdefd0da2cc58adddfcdc Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Wed, 6 Nov 2019 10:53:21 -0800 Subject: [PATCH 2/2] fixed pylint error --- .../python/tensorflowjs/converters/fold_batch_norms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tfjs-converter/python/tensorflowjs/converters/fold_batch_norms.py b/tfjs-converter/python/tensorflowjs/converters/fold_batch_norms.py index ceff67bc1a8..1fa08127a03 100644 --- a/tfjs-converter/python/tensorflowjs/converters/fold_batch_norms.py +++ b/tfjs-converter/python/tensorflowjs/converters/fold_batch_norms.py @@ -228,9 +228,9 @@ def fold_batch_norms(input_graph_def): new_node = node_def_pb2.NodeDef() new_node.CopyFrom(node) retained_input = [] - for input in new_node.input: - if not input.startswith('^') or input[1:] not in nodes_to_skip: - retained_input.append(input) + for input_node in new_node.input: + if not input_node.startswith('^') or input_node[1:] not in nodes_to_skip: + retained_input.append(input_node) new_node.input[:] = retained_input result_graph_def.node.extend([new_node])