From e95251e4cf3b39b01cb88a4c66ca12c04df81d23 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Fri, 19 Jun 2020 17:11:18 -0700 Subject: [PATCH 1/4] support negative split value --- .../executors/slice_join_executor.ts | 39 ++++++++++++------- .../executors/slice_join_executor_test.ts | 15 ++++++- 2 files changed, 40 insertions(+), 14 deletions(-) diff --git a/tfjs-converter/src/operations/executors/slice_join_executor.ts b/tfjs-converter/src/operations/executors/slice_join_executor.ts index 6187d1d9326..4c398bcab8e 100644 --- a/tfjs-converter/src/operations/executors/slice_join_executor.ts +++ b/tfjs-converter/src/operations/executors/slice_join_executor.ts @@ -24,9 +24,9 @@ import {InternalOpExecutor, Node} from '../types'; import {getParamValue} from './utils'; export const executeOp: InternalOpExecutor = (node: Node, - tensorMap: NamedTensorsMap, - context: ExecutionContext): - tfc.Tensor[] => { + tensorMap: NamedTensorsMap, + context: ExecutionContext): + tfc.Tensor[] => { switch (node.op) { case 'ConcatV2': case 'Concat': { @@ -66,9 +66,9 @@ export const executeOp: InternalOpExecutor = (node: Node, const end = getParamValue('end', node, tensorMap, context) as number[]; const strides = getParamValue('strides', node, tensorMap, context) as number[]; - const beginMask = + let beginMask = getParamValue('beginMask', node, tensorMap, context) as number; - const endMask = + let endMask = getParamValue('endMask', node, tensorMap, context) as number; const ellipsisMask = getParamValue('ellipsisMask', node, tensorMap, context) as number; @@ -77,13 +77,17 @@ export const executeOp: InternalOpExecutor = (node: Node, const shrinkAxisMask = getParamValue('shrinkAxisMask', node, tensorMap, context) as number; const tensor = getParamValue('x', node, tensorMap, context) as tfc.Tensor; - if (begin.length === 1 && tensor.shape.length > 1) { - for (let i = 1; i < tensor.shape.length; i++) { - begin.push(0); - end.push(tensor.shape[i]); - strides.push(strides[0]); + if (begin.length < tensor.rank) { + begin.unshift(0); + end.unshift(tensor.shape[0]); + if (end[1] === 0) { + end[1] = tensor.shape[1]; } + strides.unshift(1); + beginMask *= 2; + endMask *= 2; } + return [tfc.stridedSlice( tensor, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask)]; @@ -126,9 +130,18 @@ export const executeOp: InternalOpExecutor = (node: Node, const numOrSizeSplits = getParamValue('numOrSizeSplits', node, tensorMap, context) as number | number[]; - return tfc.split( - getParamValue('x', node, tensorMap, context) as tfc.Tensor, - numOrSizeSplits, axis); + const tensor = getParamValue('x', node, tensorMap, context) as tfc.Tensor; + + // Allow the last number of split array to be -1, which indicates the rest + // of dimension is allocated to the last split. + if (Array.isArray(numOrSizeSplits)) { + if (numOrSizeSplits[numOrSizeSplits.length - 1] === -1) { + const total = numOrSizeSplits.reduce((a, b) => b > 0 ? a + b : a); + numOrSizeSplits[numOrSizeSplits.length - 1] = + tensor.shape[axis] - total; + } + } + return tfc.split(tensor, numOrSizeSplits, axis); } case 'ScatterNd': { const indices = diff --git a/tfjs-converter/src/operations/executors/slice_join_executor_test.ts b/tfjs-converter/src/operations/executors/slice_join_executor_test.ts index 16e71cf5d5b..b9e5c4d3449 100644 --- a/tfjs-converter/src/operations/executors/slice_join_executor_test.ts +++ b/tfjs-converter/src/operations/executors/slice_join_executor_test.ts @@ -21,7 +21,7 @@ import * as slice_join from '../op_list/slice_join'; import {Node} from '../types'; import {executeOp} from './slice_join_executor'; -import {createNumberAttr, createNumberAttrFromIndex, createNumericArrayAttrFromIndex, createTensorAttr, createTensorsAttr, validateParam} from './test_helper'; +import {createNumberAttr, createNumberAttrFromIndex, createNumericArrayAttr, createNumericArrayAttrFromIndex, createTensorAttr, createTensorsAttr, validateParam} from './test_helper'; describe('slice join', () => { let node: Node; @@ -337,6 +337,19 @@ describe('slice join', () => { expect(tfc.split).toHaveBeenCalledWith(input2[0], 2, 1); }); + it('should support -1 split', () => { + spyOn(tfc, 'split'); + node.op = 'Split'; + node.inputParams.axis = createNumberAttrFromIndex(0); + node.inputParams.x = createTensorAttr(1); + node.attrParams.numOrSizeSplits = createNumericArrayAttr([1, 1, -1]); + node.inputNames = ['input4', 'input3']; + const input3 = [tfc.tensor1d([1, 2, 3, 4, 5])]; + const input4 = [tfc.scalar(0)]; + executeOp(node, {input3, input4}, context); + + expect(tfc.split).toHaveBeenCalledWith(input3[0], [1, 1, 3], 0); + }); it('should match json def for split', () => { node.op = 'Split'; node.inputParams.axis = createNumberAttrFromIndex(0); From ca104ac5747c16485aab02c5e3c80faf2a17b5c9 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Fri, 19 Jun 2020 17:39:45 -0700 Subject: [PATCH 2/4] revert changes --- .../executors/slice_join_executor.ts | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/tfjs-converter/src/operations/executors/slice_join_executor.ts b/tfjs-converter/src/operations/executors/slice_join_executor.ts index 4c398bcab8e..97867fd5cc0 100644 --- a/tfjs-converter/src/operations/executors/slice_join_executor.ts +++ b/tfjs-converter/src/operations/executors/slice_join_executor.ts @@ -66,9 +66,9 @@ export const executeOp: InternalOpExecutor = (node: Node, const end = getParamValue('end', node, tensorMap, context) as number[]; const strides = getParamValue('strides', node, tensorMap, context) as number[]; - let beginMask = + const beginMask = getParamValue('beginMask', node, tensorMap, context) as number; - let endMask = + const endMask = getParamValue('endMask', node, tensorMap, context) as number; const ellipsisMask = getParamValue('ellipsisMask', node, tensorMap, context) as number; @@ -77,17 +77,13 @@ export const executeOp: InternalOpExecutor = (node: Node, const shrinkAxisMask = getParamValue('shrinkAxisMask', node, tensorMap, context) as number; const tensor = getParamValue('x', node, tensorMap, context) as tfc.Tensor; - if (begin.length < tensor.rank) { - begin.unshift(0); - end.unshift(tensor.shape[0]); - if (end[1] === 0) { - end[1] = tensor.shape[1]; + if (begin.length === 1 && tensor.shape.length > 1) { + for (let i = 1; i < tensor.shape.length; i++) { + begin.push(0); + end.push(tensor.shape[i]); + strides.push(strides[0]); } - strides.unshift(1); - beginMask *= 2; - endMask *= 2; } - return [tfc.stridedSlice( tensor, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask)]; From 2b728403eb50ad733dcfd21821d5fbab01f25e67 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Sun, 21 Jun 2020 16:21:20 -0700 Subject: [PATCH 3/4] allow one neg value for split param --- .../src/operations/executors/slice_join_executor.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tfjs-converter/src/operations/executors/slice_join_executor.ts b/tfjs-converter/src/operations/executors/slice_join_executor.ts index 97867fd5cc0..83348945534 100644 --- a/tfjs-converter/src/operations/executors/slice_join_executor.ts +++ b/tfjs-converter/src/operations/executors/slice_join_executor.ts @@ -131,10 +131,10 @@ export const executeOp: InternalOpExecutor = (node: Node, // Allow the last number of split array to be -1, which indicates the rest // of dimension is allocated to the last split. if (Array.isArray(numOrSizeSplits)) { - if (numOrSizeSplits[numOrSizeSplits.length - 1] === -1) { + const negIndex = numOrSizeSplits.indexOf(-1); + if (negIndex === -1) { const total = numOrSizeSplits.reduce((a, b) => b > 0 ? a + b : a); - numOrSizeSplits[numOrSizeSplits.length - 1] = - tensor.shape[axis] - total; + numOrSizeSplits[negIndex] = tensor.shape[axis] - total; } } return tfc.split(tensor, numOrSizeSplits, axis); From 15c227b9063b7427eb9b7c8a6209e2bf88a698d6 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Sun, 21 Jun 2020 16:28:55 -0700 Subject: [PATCH 4/4] fix the typo --- tfjs-converter/src/operations/executors/slice_join_executor.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfjs-converter/src/operations/executors/slice_join_executor.ts b/tfjs-converter/src/operations/executors/slice_join_executor.ts index 83348945534..f2da33ea870 100644 --- a/tfjs-converter/src/operations/executors/slice_join_executor.ts +++ b/tfjs-converter/src/operations/executors/slice_join_executor.ts @@ -132,7 +132,7 @@ export const executeOp: InternalOpExecutor = (node: Node, // of dimension is allocated to the last split. if (Array.isArray(numOrSizeSplits)) { const negIndex = numOrSizeSplits.indexOf(-1); - if (negIndex === -1) { + if (negIndex !== -1) { const total = numOrSizeSplits.reduce((a, b) => b > 0 ? a + b : a); numOrSizeSplits[negIndex] = tensor.shape[axis] - total; }