allow execute and executeAsync to feed the model with intermediate node #204
Conversation
…des, restrict predict method to only allow feeding the input nodes.
Thanks for tackling this! Nice work. Minor comments, thus LGTM
Reviewed 7 of 7 files at r1.
Reviewable status: 0 of 1 approvals obtained (waiting on @pyu10055 and @dsmilkov)
src/executor/frozen_model.ts, line 129 at r1 (raw file):
* * You can also feed any intermediate nodes using the NamedTensorMap as the * input type. For example a graph as following:
You can say all this in a single sentence. How about:
"For example, given the graph InputNode => Intermediate => OutputNode, you can execute the subgraph Intermediate => OutputNode by calling frozenModel.execute('IntermediateNode' : tf.tensor(...))
"
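The suggested doc sentence can be illustrated with a self-contained sketch. The names (`InputNode`, `Intermediate`, `OutputNode`) and the toy ops (double, add one) are purely illustrative, not the real tfjs-converter executor; the point is that feeding an intermediate node lets execution skip its upstream dependencies entirely.

```typescript
// Toy graph: InputNode -> Intermediate (x * 2) -> OutputNode (x + 1).
// Feeding 'Intermediate' directly executes only the tail of the graph,
// so 'InputNode' never has to be provided.
type Feeds = {[name: string]: number};

function execute(feeds: Feeds, outputName: string): number {
  const values: Feeds = {...feeds};
  // Evaluate nodes in topological order, skipping any node already fed.
  if (!('Intermediate' in values)) {
    values['Intermediate'] = values['InputNode'] * 2;
  }
  if (!('OutputNode' in values)) {
    values['OutputNode'] = values['Intermediate'] + 1;
  }
  return values[outputName];
}

// Full graph: InputNode=3 -> Intermediate=6 -> OutputNode=7.
console.log(execute({InputNode: 3}, 'OutputNode'));
// Subgraph only: feed Intermediate directly.
console.log(execute({Intermediate: 10}, 'OutputNode'));
```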
src/executor/graph_executor.ts, line 112 at r1 (raw file):
// do nothing is the compiled graph cache contains the input. if (this.compiledMap.get(sortedNodeNames.join(this.SEPERATOR))) return;
nit about style guide:
always use {} with if statements: if (...) { return; }
src/executor/graph_executor.ts, line 129 at r1 (raw file):
}); } this.compiledMap.set(sortedNodeNames.join(this.SEPERATOR), compiledOrder);
put sortedNodeNames.join(this.SEPARATOR) in a separate variable so you can reuse it in the two places (above when you terminate early and here).
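The refactor being suggested can be sketched as follows. The class shape and method names here are hypothetical (the real `GraphExecutor` differs); the point is computing the join key once and reusing it for both the early-exit cache check and the cache write. The `SEPERATOR` spelling matches the diff under review.

```typescript
// Sketch of the suggested refactor: one key variable, two uses.
class Compiler {
  private readonly SEPERATOR = ',';
  private compiledMap = new Map<string, string[]>();

  compile(sortedNodeNames: string[], order: () => string[]): string[] {
    const compilationKey = sortedNodeNames.join(this.SEPERATOR);
    // Do nothing if the compiled graph cache already contains this input set.
    const cached = this.compiledMap.get(compilationKey);
    if (cached != null) {
      return cached;
    }
    const compiledOrder = order();
    this.compiledMap.set(compilationKey, compiledOrder);
    return compiledOrder;
  }
}

const c = new Compiler();
const first = c.compile(['a', 'b'], () => ['a', 'b', 'out']);
const second = c.compile(['a', 'b'], () => ['never', 'called']);
console.log(first === second);  // true: the second call hits the cache
```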
src/executor/graph_executor.ts, line 164 at r1 (raw file):
executeOp(node, tensorMap, context) as Tensor[]; } // stop the exuection if all outputs are found.
typo: execution
src/executor/graph_executor.ts, line 165 at r1 (raw file):
} // stop the exuection if all outputs are found. if (outputNames.every(name => !!tensorMap[name])) break;
style guide nit: wrap break in {}
src/executor/graph_executor.ts, line 333 at r1 (raw file):
this.placeholders.forEach(node => { const inputTensors = inputs[node.name]; // do nothing if no strick input check and input tensors is not for the
typo: strict. Also placeholder.
src/executor/graph_executor.ts, line 335 at r1 (raw file):
// do nothing if no strick input check and input tensors is not for the // placehloder. if (!strictInputCheck && !inputTensors) return;
wrap return in {}
src/operations/operation_mapper.ts, line 84 at r1 (raw file):
let withDynamicShape = false; const placeholders: Node[] = []; const weights: Node[] = [];
curious, was not having the weights a bug before this change? Or did the intermediate inputs change made this required?
thanks for the review.
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @pyu10055)
src/executor/frozen_model.ts, line 129 at r1 (raw file):
Previously, dsmilkov (Daniel Smilkov) wrote…
You can say all this in a single sentence. How about:
"For example, given the graph InputNode => Intermediate => OutputNode, you can execute the subgraph Intermediate => OutputNode by calling frozenModel.execute('IntermediateNode' : tf.tensor(...))
"
done
src/operations/operation_mapper.ts, line 84 at r1 (raw file):
Previously, dsmilkov (Daniel Smilkov) wrote…
curious, was not having the weights a bug before this change? Or did the intermediate inputs change made this required?
The weights were previously part of the input nodes, since they do not depend on other nodes.
Once we need to be able to feed intermediate nodes, the weights have to be added explicitly, which is why they are singled out here.
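The distinction the author describes can be sketched as a classification pass over the graph nodes. The `Node` shape and op names here are simplified stand-ins for the real `operation_mapper.ts` types; the idea is that `Const` (weight) nodes, which have no dependencies, are now collected separately from `Placeholder` nodes instead of being lumped in with the feedable inputs.

```typescript
// Illustrative sketch: split dependency-free nodes into placeholders
// (feedable inputs) and weights (constants, no longer treated as feeds).
interface Node {
  name: string;
  op: string;
  inputNames: string[];
}

function classify(nodes: Node[]): {placeholders: Node[], weights: Node[]} {
  const placeholders: Node[] = [];
  const weights: Node[] = [];
  for (const node of nodes) {
    if (node.op === 'Placeholder') {
      placeholders.push(node);
    } else if (node.op === 'Const') {
      // Weights have no dependencies, but once arbitrary intermediate
      // nodes may be fed they must be tracked explicitly by name.
      weights.push(node);
    }
  }
  return {placeholders, weights};
}

const {placeholders, weights} = classify([
  {name: 'x', op: 'Placeholder', inputNames: []},
  {name: 'w', op: 'Const', inputNames: []},
  {name: 'y', op: 'MatMul', inputNames: ['x', 'w']},
]);
console.log(placeholders.map(n => n.name), weights.map(n => n.name));
```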
LGTM
Reviewed 2 of 2 files at r2.
Reviewable status: complete! 1 of 1 approvals obtained
Allow feeding the frozen model with intermediate nodes for prediction, which is typical for dynamic_rnn models. Fixes tensorflow/tfjs#532
The execute and executeAsync methods will allow feeding intermediate nodes, while the predict method stays consistent with the Layers API, which only allows feeding input nodes.
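The strict-versus-lenient behavior described above (and visible in the quoted `strictInputCheck` diff lines) can be sketched like this. The function name and error message are hypothetical; the real check lives in `graph_executor.ts`.

```typescript
// Sketch: predict uses a strict check (every placeholder must be fed),
// while execute/executeAsync are lenient, since a fed intermediate node
// can make upstream placeholders unnecessary.
function checkInputs(
    placeholderNames: string[], feeds: {[name: string]: unknown},
    strictInputCheck: boolean): void {
  for (const name of placeholderNames) {
    // Lenient mode: a missing placeholder is fine; it may be bypassed
    // by a fed intermediate node downstream.
    if (!strictInputCheck && feeds[name] == null) {
      continue;
    }
    if (feeds[name] == null) {
      throw new Error(`The model placeholder ${name} was not fed.`);
    }
  }
}

// execute-style (lenient): missing 'input' is allowed.
checkInputs(['input'], {intermediate: 1}, false);
// predict-style (strict): missing 'input' throws.
let threw = false;
try {
  checkInputs(['input'], {intermediate: 1}, true);
} catch (e) {
  threw = true;
}
console.log(threw);  // true
```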