From 27511f59323c0aa3c06b14a9e53386d54c4b12b8 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Mon, 27 Apr 2020 15:55:32 -0700 Subject: [PATCH] added support select v2 op --- .../python/tensorflowjs/op_list/logical.json | 29 +++++++++++++++++++ .../operations/executors/logical_executor.ts | 9 +++--- .../executors/logical_executor_test.ts | 13 +++++++++ .../src/operations/op_list/logical.ts | 12 ++++++++ 4 files changed, 59 insertions(+), 4 deletions(-) diff --git a/tfjs-converter/python/tensorflowjs/op_list/logical.json b/tfjs-converter/python/tensorflowjs/op_list/logical.json index 669de34a42d..90c1ff6c6c0 100644 --- a/tfjs-converter/python/tensorflowjs/op_list/logical.json +++ b/tfjs-converter/python/tensorflowjs/op_list/logical.json @@ -238,5 +238,34 @@ "notSupported": true } ] + }, + { + "tfOpName": "SelectV2", + "category": "logical", + "inputs": [ + { + "start": 0, + "name": "condition", + "type": "tensor" + }, + { + "start": 1, + "name": "a", + "type": "tensor" + }, + { + "start": 2, + "name": "b", + "type": "tensor" + } + ], + "attrs": [ + { + "tfName": "T", + "name": "dtype", + "type": "dtype", + "notSupported": true + } + ] } ] \ No newline at end of file diff --git a/tfjs-converter/src/operations/executors/logical_executor.ts b/tfjs-converter/src/operations/executors/logical_executor.ts index d56321373cb..c089b806d82 100644 --- a/tfjs-converter/src/operations/executors/logical_executor.ts +++ b/tfjs-converter/src/operations/executors/logical_executor.ts @@ -24,9 +24,9 @@ import {InternalOpExecutor, Node} from '../types'; import {getParamValue} from './utils'; export const executeOp: InternalOpExecutor = (node: Node, - tensorMap: NamedTensorsMap, - context: ExecutionContext): - tfc.Tensor[] => { + tensorMap: NamedTensorsMap, + context: ExecutionContext): + tfc.Tensor[] => { switch (node.op) { case 'Equal': { return [tfc.equal( @@ -72,7 +72,8 @@ export const executeOp: InternalOpExecutor = (node: Node, getParamValue('a', node, tensorMap, context) as tfc.Tensor, getParamValue('b', node, tensorMap, context) as tfc.Tensor)]; } - case 'Select': { + case 'Select': + case 'SelectV2': { return [tfc.where( getParamValue('condition', node, tensorMap, context) as tfc.Tensor, getParamValue('a', node, tensorMap, context) as tfc.Tensor, diff --git a/tfjs-converter/src/operations/executors/logical_executor_test.ts b/tfjs-converter/src/operations/executors/logical_executor_test.ts index 6260d558786..1f641666cd2 100644 --- a/tfjs-converter/src/operations/executors/logical_executor_test.ts +++ b/tfjs-converter/src/operations/executors/logical_executor_test.ts @@ -76,5 +76,18 @@ describe('logical', () => { expect(tfc.where).toHaveBeenCalledWith(input3[0], input1[0], input2[0]); }); }); + + describe('SelectV2', () => { + it('should call tfc.where', () => { + spyOn(tfc, 'where'); + node.op = 'SelectV2'; + node.inputNames = ['input1', 'input2', 'input3']; + node.inputParams.condition = createTensorAttr(2); + const input3 = [tfc.scalar(1)]; + executeOp(node, {input1, input2, input3}, context); + + expect(tfc.where).toHaveBeenCalledWith(input3[0], input1[0], input2[0]); + }); + }); }); }); diff --git a/tfjs-converter/src/operations/op_list/logical.ts b/tfjs-converter/src/operations/op_list/logical.ts index 5b65f810519..dda89b7bebf 100644 --- a/tfjs-converter/src/operations/op_list/logical.ts +++ b/tfjs-converter/src/operations/op_list/logical.ts @@ -124,6 +124,18 @@ export const json: OpMapper[] = [ {'start': 1, 'name': 'a', 'type': 'tensor'}, {'start': 2, 'name': 'b', 'type': 'tensor'}, ], + 'attrs': [ + {'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true} + ] + }, + { + 'tfOpName': 'SelectV2', + 'category': 'logical', + 'inputs': [ + {'start': 0, 'name': 'condition', 'type': 'tensor'}, + {'start': 1, 'name': 'a', 'type': 'tensor'}, + {'start': 2, 'name': 'b', 'type': 'tensor'}, + ], 'attrs': [{ 'tfName': 'T', 'name': 'dtype',