-
Notifications
You must be signed in to change notification settings - Fork 143
Logical Ops: Map Where Class to tfc.where #174
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for working on this,
C++ where op is difference from python where op, https://www.tensorflow.org/api_docs/python/tf/where
https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/where
If you look at the python code, it actually calls c++ where or select
https://github.com/tensorflow/tensorflow/blob/r1.9/tensorflow/python/ops/array_ops.py#L2601
tfjs implementation supports only select, the direct mapping from TF c++ where op to tfjs where op will not work.
Also, besides the json file update, you need to change the executor to add the call to tfjs.where.
Reviewed 2 of 2 files at r1.
Reviewable status: 0 of 1 approvals obtained
@pyu10055 Thanks for sharing. Looks like tfjs-core's
This seems to be implemented already here. Please correct me if I'm wrong. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One possible solution is to update tfjs where op to allow input params as undefined, and it should return the same result as c++ where. This would require changes from tfjs-core and tfjs-node.
Reviewable status: 0 of 1 approvals obtained
With tensorflow/tfjs-core#1179 getting merged, we can now map |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's correct. Manraj, do you plan on making that update in this PR? cc @pyu10055
Reviewable status: 0 of 1 approvals obtained
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, didn't mean to approve yet (clicked the wrong button)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome, thank you Manraj! One small comment.
Reviewable status: 0 of 1 approvals obtained (waiting on @manrajgrover)
docs/supported_ops.md, line 155 at r2 (raw file):
|NotEqual|notEqual| |Select|where| |Where|whereAsync|
please add the Where Op to the list of dynamic shape ops.
https://github.com/tensorflow/tfjs-converter/pull/183/files#diff-724a90795232cc272e38c11d831f22e6L38
so the user would not accidentally call the non async predict for graph with Where op.
Sorry that you would need to wait for the PR#183 to be committed, thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewed 4 of 4 files at r2.
Reviewable status: 0 of 1 approvals obtained (waiting on @manrajgrover)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status: 0 of 1 approvals obtained (waiting on @manrajgrover)
src/operations/executors/logical_executor.ts, line 29 at r2 (raw file):
export let executeOp: OpExecutor = async( node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext): Promise<tfc.Tensor[]> => {
Hi , sorry. It turns out this is problematic since it will make the method always return a promise, even for non-dynamic ops. See discussion on #183. The solution is to move this to dynamic_executor.ts that just got added in #183
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status: 0 of 1 approvals obtained (waiting on @manrajgrover)
src/operations/executors/logical_executor.ts, line 81 at r2 (raw file):
getParamValue('b', node, tensorMap, context) as tfc.Tensor)]; } case 'whereAsync': {
please move this to dynamic_executor.ts, since this op returns Promise<Tensor[]>, which other ops return Tensor[], making the method async would cause all return value to be Promise<Tensor[]>.
src/operations/op_list/logical.json, line 233 at r2 (raw file):
}, { "tfOpName": "Where",
Aslo move this to dynamic.json, thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! LGTM!
Reviewed 6 of 7 files at r3.
Reviewable status: 0 of 1 approvals obtained (waiting on @manrajgrover)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @manrajgrover)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice work Manraj!
Reviewed 1 of 4 files at r2, 6 of 7 files at r3.
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @manrajgrover)
This PR maps TensorFlow's
Where
class totfc.where
This change is