diff --git a/tfjs-converter/src/operations/executors/slice_join_executor.ts b/tfjs-converter/src/operations/executors/slice_join_executor.ts index 6187d1d9326..f2da33ea870 100644 --- a/tfjs-converter/src/operations/executors/slice_join_executor.ts +++ b/tfjs-converter/src/operations/executors/slice_join_executor.ts @@ -24,9 +24,9 @@ import {InternalOpExecutor, Node} from '../types'; import {getParamValue} from './utils'; export const executeOp: InternalOpExecutor = (node: Node, - tensorMap: NamedTensorsMap, - context: ExecutionContext): - tfc.Tensor[] => { + tensorMap: NamedTensorsMap, + context: ExecutionContext): + tfc.Tensor[] => { switch (node.op) { case 'ConcatV2': case 'Concat': { @@ -126,9 +126,18 @@ export const executeOp: InternalOpExecutor = (node: Node, const numOrSizeSplits = getParamValue('numOrSizeSplits', node, tensorMap, context) as number | number[]; - return tfc.split( - getParamValue('x', node, tensorMap, context) as tfc.Tensor, - numOrSizeSplits, axis); + const tensor = getParamValue('x', node, tensorMap, context) as tfc.Tensor; + + // Allow the last number of split array to be -1, which indicates the rest + // of dimension is allocated to the last split. + if (Array.isArray(numOrSizeSplits)) { + const negIndex = numOrSizeSplits.indexOf(-1); + if (negIndex !== -1) { + const total = numOrSizeSplits.reduce((a, b) => b > 0 ? a + b : a); + numOrSizeSplits[negIndex] = tensor.shape[axis] - total; + } + } + return tfc.split(tensor, numOrSizeSplits, axis); } case 'ScatterNd': { const indices = diff --git a/tfjs-converter/src/operations/executors/slice_join_executor_test.ts b/tfjs-converter/src/operations/executors/slice_join_executor_test.ts index 16e71cf5d5b..b9e5c4d3449 100644 --- a/tfjs-converter/src/operations/executors/slice_join_executor_test.ts +++ b/tfjs-converter/src/operations/executors/slice_join_executor_test.ts @@ -21,7 +21,7 @@ import * as slice_join from '../op_list/slice_join'; import {Node} from '../types'; import {executeOp} from './slice_join_executor'; -import {createNumberAttr, createNumberAttrFromIndex, createNumericArrayAttrFromIndex, createTensorAttr, createTensorsAttr, validateParam} from './test_helper'; +import {createNumberAttr, createNumberAttrFromIndex, createNumericArrayAttr, createNumericArrayAttrFromIndex, createTensorAttr, createTensorsAttr, validateParam} from './test_helper'; describe('slice join', () => { let node: Node; @@ -337,6 +337,19 @@ describe('slice join', () => { expect(tfc.split).toHaveBeenCalledWith(input2[0], 2, 1); }); + it('should support -1 split', () => { + spyOn(tfc, 'split'); + node.op = 'Split'; + node.inputParams.axis = createNumberAttrFromIndex(0); + node.inputParams.x = createTensorAttr(1); + node.attrParams.numOrSizeSplits = createNumericArrayAttr([1, 1, -1]); + node.inputNames = ['input4', 'input3']; + const input3 = [tfc.tensor1d([1, 2, 3, 4, 5])]; + const input4 = [tfc.scalar(0)]; + executeOp(node, {input3, input4}, context); + + expect(tfc.split).toHaveBeenCalledWith(input3[0], [1, 1, 3], 0); + }); it('should match json def for split', () => { node.op = 'Split'; node.inputParams.axis = createNumberAttrFromIndex(0);