From 331569ec86f1a91d09258ae75c300ea848bc5053 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Fri, 19 Jun 2020 13:05:13 -0700 Subject: [PATCH 1/3] maintain single id tensor for each tensor list and tensor array --- .../python/tensorflowjs/op_list/control.json | 30 ++--- tfjs-converter/src/executor/tensor_array.ts | 9 +- tfjs-converter/src/executor/tensor_list.ts | 25 ++-- .../operations/executors/control_executor.ts | 81 ++++++------ .../executors/control_executor_test.ts | 115 +++++++++--------- .../src/operations/op_list/control.ts | 30 ++--- 6 files changed, 144 insertions(+), 146 deletions(-) diff --git a/tfjs-converter/python/tensorflowjs/op_list/control.json b/tfjs-converter/python/tensorflowjs/op_list/control.json index 9fb66b84e9c..524a846bf9d 100644 --- a/tfjs-converter/python/tensorflowjs/op_list/control.json +++ b/tfjs-converter/python/tensorflowjs/op_list/control.json @@ -155,7 +155,7 @@ { "start": 0, "name": "tensorArrayId", - "type": "number" + "type": "tensor" }, { "start": 1, @@ -189,7 +189,7 @@ { "start": 0, "name": "tensorArrayId", - "type": "number" + "type": "tensor" }, { "start": 1, @@ -218,7 +218,7 @@ { "start": 0, "name": "tensorArrayId", - "type": "number" + "type": "tensor" }, { "start": 1, @@ -251,7 +251,7 @@ { "start": 0, "name": "tensorArrayId", - "type": "number" + "type": "tensor" }, { "start": 1, @@ -284,7 +284,7 @@ { "start": 0, "name": "tensorArrayId", - "type": "number" + "type": "tensor" }, { "start": 1, @@ -313,7 +313,7 @@ { "start": 0, "name": "tensorArrayId", - "type": "number" + "type": "tensor" }, { "start": 1, @@ -346,7 +346,7 @@ { "start": 0, "name": "tensorArrayId", - "type": "number" + "type": "tensor" }, { "start": 1, @@ -362,7 +362,7 @@ { "start": 0, "name": "tensorArrayId", - "type": "number" + "type": "tensor" } ] }, @@ -540,7 +540,7 @@ { "start": 0, "name": "tensorListId", - "type": "number" + "type": "tensor" }, { "start": 1, @@ -568,7 +568,7 @@ { "start": 0, "name": "tensorListId", - "type": "number" + "type": "tensor" }, { "start": 1, @@ -596,7 +596,7 @@ { "start": 0, "name": "tensorListId", - "type": "number" + "type": "tensor" }, { "start": 1, @@ -670,7 +670,7 @@ { "start": 0, "name": "tensorListId", - "type": "number" + "type": "tensor" }, { "start": 1, @@ -726,7 +726,7 @@ { "start": 0, "name": "tensorListId", - "type": "number" + "type": "tensor" } ], "attrs": [ @@ -749,7 +749,7 @@ { "start": 0, "name": "tensorListId", - "type": "number" + "type": "tensor" }, { "start": 1, @@ -772,7 +772,7 @@ { "start": 0, "name": "tensorListId", - "type": "number" + "type": "tensor" }, { "start": 1, diff --git a/tfjs-converter/src/executor/tensor_array.ts b/tfjs-converter/src/executor/tensor_array.ts index f9b65343c73..8b09681c27c 100644 --- a/tfjs-converter/src/executor/tensor_array.ts +++ b/tfjs-converter/src/executor/tensor_array.ts @@ -30,20 +30,21 @@ export interface TensorWithState { * allows reading from the array and writing to the array. */ export class TensorArray { - private static nextId = 0; private tensors: TensorWithState[] = []; private closed_ = false; - readonly id: number; readonly idTensor: Tensor; constructor( readonly name: string, readonly dtype: DataType, private maxSize: number, private elementShape: number[], readonly identicalElementShapes: boolean, readonly dynamicSize: boolean, readonly clearAfterRead: boolean) { - this.id = TensorArray.nextId++; - this.idTensor = scalar(this.id); + this.idTensor = scalar(0); keep(this.idTensor); } + get id() { + return this.idTensor.id; + } + get closed() { return this.closed_; } diff --git a/tfjs-converter/src/executor/tensor_list.ts b/tfjs-converter/src/executor/tensor_list.ts index 8e6abcc6ee3..f3b4900d821 100644 --- a/tfjs-converter/src/executor/tensor_list.ts +++ b/tfjs-converter/src/executor/tensor_list.ts @@ -35,11 +35,12 @@ import {assertShapesMatchAllowUndefinedSize} from './tensor_utils'; */ export class TensorList { - private static nextId = 0; - readonly id: number; readonly idTensor: Tensor; maxNumElements: number; + get id() { + return this.idTensor.id; + } /** * * @param tensors list of tensors @@ -52,8 +53,7 @@ export class TensorList { readonly tensors: Tensor[], readonly elementShape: number[], readonly elementDtype: DataType, maxNumElements = -1) { tensors.forEach(tensor => keep(tensor)); - this.id = TensorList.nextId++; - this.idTensor = scalar(this.id); + this.idTensor = scalar(0); this.maxNumElements = maxNumElements; keep(this.idTensor); } @@ -101,11 +101,11 @@ export class TensorList { } assertShapesMatchAllowUndefinedSize( elementShape, this.elementShape, 'TensorList shape mismatch: '); - return tidy(() => { - const reshapedTensors = - this.tensors.map(tensor => tensor.reshape(elementShape)); - return stack(reshapedTensors, 0); - }); + // return tidy(() => { + // const reshapedTensors = + // this.tensors.map(tensor => tensor.reshape(elementShape)); + return stack(this.tensors, 0); + // }); } /** @@ -293,12 +293,7 @@ export function fromTensor( assertShapesMatchAllowUndefinedSize( outputShape, elementShape, 'TensorList shape mismatch: '); - const tensorList: Tensor[] = []; - for (let i = 0; i < tensor.shape[0]; ++i) { - const tmp = tensor.slice(i, 1); - tensorList.push(tmp.reshape(outputShape)); - tmp.dispose(); - } + const tensorList: Tensor[] = tensor.unstack(); return new TensorList(tensorList, elementShape, dtype); } diff --git a/tfjs-converter/src/operations/executors/control_executor.ts b/tfjs-converter/src/operations/executors/control_executor.ts index c44981c77de..7effe6f6547 100644 --- a/tfjs-converter/src/operations/executors/control_executor.ts +++ b/tfjs-converter/src/operations/executors/control_executor.ts @@ -166,89 +166,94 @@ export const executeOp: InternalOpAsyncExecutor = async( return [tensorArray.idTensor, scalar(1.0)]; } case 'TensorArrayWriteV3': { - const id = - getParamValue('tensorArrayId', node, tensorMap, context) as number; + const id = getParamValue('tensorArrayId', node, tensorMap, context) as + tfc.Tensor; const index = getParamValue('index', node, tensorMap, context) as number; const writeTensor = getParamValue('tensor', node, tensorMap, context) as tfc.Tensor; - const writeTensorArray = context.getTensorArray(id); + const writeTensorArray = context.getTensorArray(id.id); writeTensorArray.write(index, writeTensor); return [writeTensorArray.idTensor]; } case 'TensorArrayReadV3': { - const readId = - getParamValue('tensorArrayId', node, tensorMap, context) as number; + const readId = getParamValue('tensorArrayId', node, tensorMap, context) as + tfc.Tensor; const readIndex = getParamValue('index', node, tensorMap, context) as number; - const readTensorArray = context.getTensorArray(readId); + const readTensorArray = context.getTensorArray(readId.id); return [readTensorArray.read(readIndex)]; } case 'TensorArrayGatherV3': { const gatherId = - getParamValue('tensorArrayId', node, tensorMap, context) as number; + getParamValue('tensorArrayId', node, tensorMap, context) as + tfc.Tensor; const gatherIndices = getParamValue('indices', node, tensorMap, context) as number[]; const gatherDtype = getParamValue('dtype', node, tensorMap, context) as tfc.DataType; - const gatherTensorArray = context.getTensorArray(gatherId); + const gatherTensorArray = context.getTensorArray(gatherId.id); return [gatherTensorArray.gather(gatherIndices, gatherDtype)]; } case 'TensorArrayScatterV3': { const scatterId = - getParamValue('tensorArrayId', node, tensorMap, context) as number; + getParamValue('tensorArrayId', node, tensorMap, context) as + tfc.Tensor; const scatterIndices = getParamValue('indices', node, tensorMap, context) as number[]; const scatterTensor = getParamValue('tensor', node, tensorMap, context) as tfc.Tensor; - const scatterTensorArray = context.getTensorArray(scatterId); + const scatterTensorArray = context.getTensorArray(scatterId.id); scatterTensorArray.scatter(scatterIndices, scatterTensor); return [scatterTensorArray.idTensor]; } case 'TensorArrayConcatV3': { const concatId = - getParamValue('tensorArrayId', node, tensorMap, context) as number; - const concatTensorArray = context.getTensorArray(concatId); + getParamValue('tensorArrayId', node, tensorMap, context) as + tfc.Tensor; + const concatTensorArray = context.getTensorArray(concatId.id); const concatDtype = getParamValue('dtype', node, tensorMap, context) as tfc.DataType; return [concatTensorArray.concat(concatDtype)]; } case 'TensorArraySplitV3': { const splitId = - getParamValue('tensorArrayId', node, tensorMap, context) as number; + getParamValue('tensorArrayId', node, tensorMap, context) as + tfc.Tensor; const splitTensor = getParamValue('tensor', node, tensorMap, context) as tfc.Tensor; const lengths = getParamValue('lengths', node, tensorMap, context) as number[]; - const splitTensorArray = context.getTensorArray(splitId); + const splitTensorArray = context.getTensorArray(splitId.id); splitTensorArray.split(lengths, splitTensor); return [splitTensorArray.idTensor]; } case 'TensorArraySizeV3': { - const sizeId = - getParamValue('tensorArrayId', node, tensorMap, context) as number; - const sizeTensorArray = context.getTensorArray(sizeId); + const sizeId = getParamValue('tensorArrayId', node, tensorMap, context) as + tfc.Tensor; + const sizeTensorArray = context.getTensorArray(sizeId.id); return [scalar(sizeTensorArray.size(), 'int32')]; } case 'TensorArrayCloseV3': { const closeId = - getParamValue('tensorArrayId', node, tensorMap, context) as number; - const closeTensorArray = context.getTensorArray(closeId); + getParamValue('tensorArrayId', node, tensorMap, context) as + tfc.Tensor; + const closeTensorArray = context.getTensorArray(closeId.id); closeTensorArray.clearAndClose(); return [closeTensorArray.idTensor]; } case 'TensorListSetItem': { - const id = - getParamValue('tensorListId', node, tensorMap, context) as number; + const idTensor = + getParamValue('tensorListId', node, tensorMap, context) as tfc.Tensor; const index = getParamValue('index', node, tensorMap, context) as number; const writeTensor = getParamValue('tensor', node, tensorMap, context) as tfc.Tensor; - const tensorList = context.getTensorList(id); + const tensorList = context.getTensorList(idTensor.id); tensorList.setItem(index, writeTensor); return [tensorList.idTensor]; } case 'TensorListGetItem': { - const readId = - getParamValue('tensorListId', node, tensorMap, context) as number; + const idTensor = + getParamValue('tensorListId', node, tensorMap, context) as tfc.Tensor; const readIndex = getParamValue('index', node, tensorMap, context) as number; const elementShape = @@ -257,7 +262,7 @@ export const executeOp: InternalOpAsyncExecutor = async( const elementDType = getParamValue('elementDType', node, tensorMap, context) as tfc.DataType; - const tensorList = context.getTensorList(readId); + const tensorList = context.getTensorList(idTensor.id); return [tensorList.getItem(readIndex, elementShape, elementDType)]; } case 'TensorListScatterV2': @@ -289,7 +294,7 @@ export const executeOp: InternalOpAsyncExecutor = async( } case 'TensorListGather': { const gatherId = - getParamValue('tensorListId', node, tensorMap, context) as number; + getParamValue('tensorListId', node, tensorMap, context) as tfc.Tensor; const gatherIndices = getParamValue('indices', node, tensorMap, context) as number[]; const elementShape = @@ -297,12 +302,12 @@ export const executeOp: InternalOpAsyncExecutor = async( const elementDtype = getParamValue('elementDType', node, tensorMap, context) as tfc.DataType; - const tensorList = context.getTensorList(gatherId); + const tensorList = context.getTensorList(gatherId.id); return [tensorList.gather(gatherIndices, elementDtype, elementShape)]; } case 'TensorListStack': { - const gatherId = - getParamValue('tensorListId', node, tensorMap, context) as number; + const idTensor = + getParamValue('tensorListId', node, tensorMap, context) as tfc.Tensor; const elementShape = getParamValue('elementShape', node, tensorMap, context) as number[]; const elementDtype = @@ -310,7 +315,7 @@ export const executeOp: InternalOpAsyncExecutor = async( tfc.DataType; const numElements = getParamValue('numElements', node, tensorMap, context) as number; - const tensorList = context.getTensorList(gatherId); + const tensorList = context.getTensorList(idTensor.id); return [tensorList.stack(elementShape, elementDtype, numElements)]; } case 'TensorListFromTensor': { @@ -327,8 +332,8 @@ export const executeOp: InternalOpAsyncExecutor = async( } case 'TensorListConcat': { const concatId = - getParamValue('tensorListId', node, tensorMap, context) as number; - const tensorList = context.getTensorList(concatId); + getParamValue('tensorListId', node, tensorMap, context) as tfc.Tensor; + const tensorList = context.getTensorList(concatId.id); const concatDtype = getParamValue('dtype', node, tensorMap, context) as tfc.DataType; const elementShape = @@ -336,23 +341,23 @@ export const executeOp: InternalOpAsyncExecutor = async( return [tensorList.concat(concatDtype, elementShape)]; } case 'TensorListPushBack': { - const id = - getParamValue('tensorListId', node, tensorMap, context) as number; + const idTensor = + getParamValue('tensorListId', node, tensorMap, context) as tfc.Tensor; const writeTensor = getParamValue('tensor', node, tensorMap, context) as tfc.Tensor; - const tensorList = context.getTensorList(id); + const tensorList = context.getTensorList(idTensor.id); tensorList.pushBack(writeTensor); return [tensorList.idTensor]; } case 'TensorListPopBack': { - const readId = - getParamValue('tensorListId', node, tensorMap, context) as number; + const idTensor = + getParamValue('tensorListId', node, tensorMap, context) as tfc.Tensor; const elementShape = getParamValue('elementShape', node, tensorMap, context) as number[]; const elementDType = getParamValue('elementDType', node, tensorMap, context) as tfc.DataType; - const tensorList = context.getTensorList(readId); + const tensorList = context.getTensorList(idTensor.id); return [tensorList.popBack(elementShape, elementDType)]; } case 'TensorListSplit': { diff --git a/tfjs-converter/src/operations/executors/control_executor_test.ts b/tfjs-converter/src/operations/executors/control_executor_test.ts index 0c27eda354f..711ecc5c83b 100644 --- a/tfjs-converter/src/operations/executors/control_executor_test.ts +++ b/tfjs-converter/src/operations/executors/control_executor_test.ts @@ -185,9 +185,8 @@ describe('control', () => { node.attrParams['identicalElementShapes'] = createBoolAttr(true); node.inputNames = ['input1']; - const tensorId = - (await executeOp(node, {input1}, context))[0].dataSync()[0]; - expect(context.getTensorArray(tensorId)).toBeDefined(); + const tensorId = (await executeOp(node, {input1}, context))[0]; + expect(context.getTensorArray(tensorId.id)).toBeDefined(); }); it('should match json def', () => { node.op = 'TensorArrayV3'; @@ -209,11 +208,11 @@ describe('control', () => { new TensorArray('', 'int32', 5, [], true, false, true); context.addTensorArray(tensorArray); node.op = 'TensorArrayWriteV3'; - node.inputParams['tensorArrayId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorArrayId'] = createTensorAttr(0); node.inputParams['index'] = createNumberAttrFromIndex(1); node.inputParams['tensor'] = createTensorAttr(2); node.inputNames = ['input2', 'input3', 'input1']; - const input2 = [scalar(tensorArray.id)]; + const input2 = [tensorArray.idTensor]; const input3 = [scalar(0)]; await executeOp(node, {input1, input2, input3}, context); @@ -221,7 +220,7 @@ describe('control', () => { }); it('should match json def', () => { node.op = 'TensorArrayWriteV3'; - node.inputParams['tensorArrayId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorArrayId'] = createTensorAttr(0); node.inputParams['index'] = createNumberAttrFromIndex(1); node.inputParams['tensor'] = createTensorAttr(2); @@ -237,10 +236,10 @@ describe('control', () => { tensorArray.write(0, input4); context.addTensorArray(tensorArray); node.op = 'TensorArrayReadV3'; - node.inputParams['tensorArrayId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorArrayId'] = createTensorAttr(0); node.inputParams['index'] = createNumberAttrFromIndex(1); node.inputNames = ['input2', 'input3']; - const input2 = [scalar(tensorArray.id)]; + const input2 = [tensorArray.idTensor]; const input3 = [scalar(0)]; const read = await executeOp(node, {input1, input2, input3}, context); @@ -249,7 +248,7 @@ describe('control', () => { }); it('should match json def', () => { node.op = 'TensorArrayReadV3'; - node.inputParams['tensorArrayId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorArrayId'] = createTensorAttr(0); node.inputParams['index'] = createNumberAttrFromIndex(1); expect(validateParam(node, control.json)).toBeTruthy(); @@ -265,11 +264,11 @@ describe('control', () => { tensorArray.writeMany([0, 1], [input4, input5]); context.addTensorArray(tensorArray); node.op = 'TensorArrayGatherV3'; - node.inputParams['tensorArrayId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorArrayId'] = createTensorAttr(0); node.inputParams['indices'] = createNumericArrayAttrFromIndex(1); node.attrParams['dtype'] = createDtypeAttr('int32'); node.inputNames = ['input2', 'input3']; - const input2 = [scalar(tensorArray.id)]; + const input2 = [tensorArray.idTensor]; const input3 = [tensor1d([0, 1])]; const gather = await executeOp(node, {input2, input3}, context); expect(gather.length).toEqual(1); @@ -279,7 +278,7 @@ describe('control', () => { }); it('should match json def', () => { node.op = 'TensorArrayGatherV3'; - node.inputParams['tensorArrayId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorArrayId'] = createTensorAttr(0); node.inputParams['indices'] = createNumericArrayAttrFromIndex(1); node.attrParams['dtype'] = createDtypeAttr('int32'); @@ -294,11 +293,11 @@ describe('control', () => { const input4 = [tensor2d([0, 0, 0, 1, 1, 1], [2, 3], 'int32')]; context.addTensorArray(tensorArray); node.op = 'TensorArrayScatterV3'; - node.inputParams['tensorArrayId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorArrayId'] = createTensorAttr(0); node.inputParams['indices'] = createNumericArrayAttrFromIndex(1); node.inputParams['tensor'] = createTensorAttr(2); node.inputNames = ['input2', 'input3', 'input4']; - const input2 = [scalar(tensorArray.id)]; + const input2 = [tensorArray.idTensor]; const input3 = [tensor1d([0, 1], 'int32')]; await executeOp(node, {input2, input3, input4}, context); @@ -307,7 +306,7 @@ describe('control', () => { it('should match json def', () => { node.op = 'TensorArrayScatterV3'; - node.inputParams['tensorArrayId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorArrayId'] = createTensorAttr(0); node.inputParams['indices'] = createNumericArrayAttrFromIndex(1); node.inputParams['tensor'] = createTensorAttr(2); @@ -322,11 +321,11 @@ describe('control', () => { const input4 = [tensor2d([0, 0, 0, 1, 1, 1], [2, 3], 'int32')]; context.addTensorArray(tensorArray); node.op = 'TensorArraySplitV3'; - node.inputParams['tensorArrayId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorArrayId'] = createTensorAttr(0); node.inputParams['tensor'] = createTensorAttr(1); node.inputParams['lengths'] = createNumericArrayAttrFromIndex(2); node.inputNames = ['input2', 'input4', 'input3']; - const input2 = [scalar(tensorArray.id)]; + const input2 = [tensorArray.idTensor]; const input3 = [tensor1d([1, 1], 'int32')]; await executeOp(node, {input2, input3, input4}, context); @@ -334,7 +333,7 @@ describe('control', () => { }); it('should match json def', () => { node.op = 'TensorArraySplitV3'; - node.inputParams['tensorArrayId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorArrayId'] = createTensorAttr(0); node.inputParams['tensor'] = createTensorAttr(1); node.inputParams['lengths'] = createNumericArrayAttrFromIndex(2); @@ -351,10 +350,10 @@ describe('control', () => { tensorArray.writeMany([0, 1], [input4, input5]); context.addTensorArray(tensorArray); node.op = 'TensorArrayConcatV3'; - node.inputParams['tensorArrayId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorArrayId'] = createTensorAttr(0); node.attrParams['dtype'] = createDtypeAttr('int32'); node.inputNames = ['input2']; - const input2 = [scalar(tensorArray.id)]; + const input2 = [tensorArray.idTensor]; const concat = await executeOp(node, {input2}, context); expect(concat.length).toEqual(1); expect(concat[0].shape).toEqual([6]); @@ -363,7 +362,7 @@ describe('control', () => { }); it('should match json def', () => { node.op = 'TensorArrayConcatV3'; - node.inputParams['tensorArrayId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorArrayId'] = createTensorAttr(0); node.attrParams['dtype'] = createDtypeAttr('int32'); expect(validateParam(node, control.json)).toBeTruthy(); @@ -379,9 +378,9 @@ describe('control', () => { tensorArray.writeMany([0, 1], [input4, input5]); context.addTensorArray(tensorArray); node.op = 'TensorArraySizeV3'; - node.inputParams['tensorArrayId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorArrayId'] = createTensorAttr(0); node.inputNames = ['input2']; - const input2 = [scalar(tensorArray.id)]; + const input2 = [tensorArray.idTensor]; const size = await executeOp(node, {input2}, context); expect(size.length).toEqual(1); expect(size[0].shape).toEqual([]); @@ -389,7 +388,7 @@ describe('control', () => { }); it('should match json def', () => { node.op = 'TensorArraySizeV3'; - node.inputParams['tensorArrayId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorArrayId'] = createTensorAttr(0); expect(validateParam(node, control.json)).toBeTruthy(); }); @@ -404,15 +403,15 @@ describe('control', () => { tensorArray.writeMany([0, 1], [input4, input5]); context.addTensorArray(tensorArray); node.op = 'TensorArrayCloseV3'; - node.inputParams['tensorArrayId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorArrayId'] = createTensorAttr(0); node.inputNames = ['input2']; - const input2 = [scalar(tensorArray.id)]; + const input2 = [tensorArray.idTensor]; await executeOp(node, {input2}, context); expect(tensorArray.closed).toBeTruthy(); }); it('should match json def', () => { node.op = 'TensorArrayCloseV3'; - node.inputParams['tensorArrayId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorArrayId'] = createTensorAttr(0); expect(validateParam(node, control.json)).toBeTruthy(); }); @@ -675,8 +674,8 @@ describe('control', () => { node.inputNames = ['input4', 'input1']; const input4 = [tensor1d([10, 10], 'int32')]; const tensorListId = - (await executeOp(node, {input1, input4}, context))[0].dataSync()[0]; - const tensorList = context.getTensorList(tensorListId); + (await executeOp(node, {input1, input4}, context))[0]; + const tensorList = context.getTensorList(tensorListId.id); expect(tensorList.elementDtype).toEqual('int32'); expect(tensorList.elementShape).toEqual([10, 10]); expect(tensorList.maxNumElements).toEqual(1); @@ -698,11 +697,11 @@ describe('control', () => { const tensorList = new TensorList([input4, input5], [3], 'int32', 5); context.addTensorList(tensorList); node.op = 'TensorListConcat'; - node.inputParams['tensorListId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorListId'] = createTensorAttr(0); node.attrParams['elementDType'] = createDtypeAttr('int32'); node.attrParams['elementShape'] = createTensorShapeAttr([3]); node.inputNames = ['input2']; - const input2 = [scalar(tensorList.id)]; + const input2 = [tensorList.idTensor]; const concat = await executeOp(node, {input2}, context); expect(concat.length).toEqual(1); expect(concat[0].shape).toEqual([6]); @@ -711,7 +710,7 @@ describe('control', () => { }); it('should match json def', () => { node.op = 'TensorListConcat'; - node.inputParams['tensorListId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorListId'] = createTensorAttr(0); node.attrParams['elementDType'] = createDtypeAttr('int32'); node.attrParams['elementShape'] = createTensorShapeAttr([3]); @@ -728,10 +727,9 @@ describe('control', () => { node.inputNames = ['input4', 'input2', 'input3']; const input2 = [tensor1d([0, 1], 'int32')]; const input3 = [tensor1d([3], 'int32')]; - const tensorListId = (await executeOp( - node, {input2, input3, input4}, - context))[0].dataSync()[0]; - const tensorList = context.getTensorList(tensorListId); + const tensorListId = + (await executeOp(node, {input2, input3, input4}, context))[0]; + const tensorList = context.getTensorList(tensorListId.id); expect(tensorList.size()).toEqual(2); }); @@ -758,9 +756,8 @@ describe('control', () => { const input3 = [tensor1d([3], 'int32')]; const input5 = [tensor1d([2], 'int32')]; const tensorListId = (await executeOp( - node, {input2, input3, input4, input5}, - context))[0].dataSync()[0]; - const tensorList = context.getTensorList(tensorListId); + node, {input2, input3, input4, input5}, context))[0]; + const tensorList = context.getTensorList(tensorListId.id); expect(tensorList.size()).toEqual(2); }); @@ -779,12 +776,12 @@ describe('control', () => { const tensorList = new TensorList([], [], 'int32', 5); context.addTensorList(tensorList); node.op = 'TensorListSetItem'; - node.inputParams['tensorListId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorListId'] = createTensorAttr(0); node.inputParams['index'] = createNumberAttrFromIndex(1); node.inputParams['tensor'] = createTensorAttr(2); node.attrParams['elementDType'] = createDtypeAttr('int32'); node.inputNames = ['input2', 'input3', 'input1']; - const input2 = [scalar(tensorList.id)]; + const input2 = [tensorList.idTensor]; const input3 = [scalar(0)]; await executeOp(node, {input1, input2, input3}, context); @@ -792,7 +789,7 @@ describe('control', () => { }); it('should match json def', () => { node.op = 'TensorListSetItem'; - node.inputParams['tensorListId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorListId'] = createTensorAttr(0); node.inputParams['index'] = createNumberAttrFromIndex(1); node.inputParams['tensor'] = createTensorAttr(2); node.attrParams['elementDType'] = createDtypeAttr('int32'); @@ -808,12 +805,12 @@ describe('control', () => { tensorList.setItem(0, input4); context.addTensorList(tensorList); node.op = 'TensorListGetItem'; - node.inputParams['tensorListId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorListId'] = createTensorAttr(0); node.inputParams['index'] = createNumberAttrFromIndex(1); node.inputParams['elementShape'] = createShapeAttrFromIndex(2); node.inputNames = ['input2', 'input3', 'input5']; node.attrParams['elementDType'] = createDtypeAttr('int32'); - const input2 = [scalar(tensorList.id)]; + const input2 = [tensorList.idTensor]; const input3 = [scalar(0)]; const input5 = [tensor1d([3], 'int32')]; const read = await executeOp(node, {input5, input2, input3}, context); @@ -823,7 +820,7 @@ describe('control', () => { }); it('should match json def', () => { node.op = 'TensorListGetItem'; - node.inputParams['tensorListId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorListId'] = createTensorAttr(0); node.inputParams['index'] = createNumberAttrFromIndex(1); node.inputParams['elementShape'] = createShapeAttrFromIndex(2); node.attrParams['elementDType'] = createDtypeAttr('int32'); @@ -836,18 +833,18 @@ describe('control', () => { const tensorList = new TensorList([], [], 'int32', 5); context.addTensorList(tensorList); node.op = 'TensorListPushBack'; - node.inputParams['tensorListId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorListId'] = createTensorAttr(0); node.inputParams['tensor'] = createTensorAttr(1); node.attrParams['elementDType'] = createDtypeAttr('int32'); node.inputNames = ['input2', 'input1']; - const input2 = [scalar(tensorList.id)]; + const input2 = [tensorList.idTensor]; await executeOp(node, {input1, input2}, context); expect(tensorList.size()).toEqual(1); }); it('should match json def', () => { node.op = 'TensorListPushBack'; - node.inputParams['tensorListId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorListId'] = createTensorAttr(0); node.inputParams['tensor'] = createTensorAttr(1); node.attrParams['elementDType'] = createDtypeAttr('int32'); @@ -862,11 +859,11 @@ describe('control', () => { tensorList.setItem(0, input4); context.addTensorList(tensorList); node.op = 'TensorListPopBack'; - node.inputParams['tensorListId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorListId'] = createTensorAttr(0); node.inputParams['elementShape'] = createShapeAttrFromIndex(1); node.inputNames = ['input2', 'input5']; node.attrParams['elementDType'] = createDtypeAttr('int32'); - const input2 = [scalar(tensorList.id)]; + const input2 = [tensorList.idTensor]; const input5 = [tensor1d([3], 'int32')]; const read = await executeOp(node, {input5, input2}, context); @@ -875,7 +872,7 @@ describe('control', () => { }); it('should match json def', () => { node.op = 'TensorListPopBack'; - node.inputParams['tensorListId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorListId'] = createTensorAttr(0); node.inputParams['elementShape'] = createShapeAttrFromIndex(1); node.attrParams['elementDType'] = createDtypeAttr('int32'); @@ -889,11 +886,11 @@ describe('control', () => { tensorList.setItem(0, input4); context.addTensorList(tensorList); node.op = 'TensorListStack'; - node.inputParams['tensorListId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorListId'] = createTensorAttr(0); node.inputParams['elementShape'] = createShapeAttrFromIndex(1); node.inputNames = ['input2', 'input5']; node.attrParams['elementDType'] = createDtypeAttr('int32'); - const input2 = [scalar(tensorList.id)]; + const input2 = [tensorList.idTensor]; const input5 = [tensor1d([3], 'int32')]; const read = await executeOp(node, {input5, input2}, context); @@ -902,7 +899,7 @@ describe('control', () => { }); it('should match json def', () => { node.op = 'TensorListStack'; - node.inputParams['tensorListId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorListId'] = createTensorAttr(0); node.inputParams['elementShape'] = createShapeAttrFromIndex(1); node.attrParams['elementDType'] = createDtypeAttr('int32'); @@ -917,12 +914,12 @@ describe('control', () => { tensorList.setItem(1, input6); context.addTensorList(tensorList); node.op = 'TensorListGather'; - node.inputParams['tensorListId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorListId'] = createTensorAttr(0); node.inputParams['indices'] = createNumericArrayAttrFromIndex(1); node.inputParams['elementShape'] = createShapeAttrFromIndex(2); node.inputNames = ['input2', 'input3', 'input5']; node.attrParams['elementDType'] = createDtypeAttr('int32'); - const input2 = [scalar(tensorList.id)]; + const input2 = [tensorList.idTensor]; const input3 = [tensor1d([0, 1], 'int32')]; const input5 = [tensor1d([3], 'int32')]; @@ -934,7 +931,7 @@ describe('control', () => { }); it('should match json def', () => { node.op = 'TensorListGather'; - node.inputParams['tensorListId'] = createNumberAttrFromIndex(0); + node.inputParams['tensorListId'] = createTensorAttr(0); node.inputParams['indices'] = createNumericArrayAttrFromIndex(1); node.inputParams['elementShape'] = createShapeAttrFromIndex(2); node.attrParams['elementDType'] = createDtypeAttr('int32'); @@ -956,8 +953,8 @@ describe('control', () => { const input2 = [tensor1d([3], 'int32')]; const input3 = [tensor1d([1, 1], 'int32')]; const idTensor = - await executeOp(node, {input2, input3, input4}, context); - const tensorList = context.getTensorList(idTensor[0].dataSync()[0]); + (await executeOp(node, {input2, input3, input4}, context))[0]; + const tensorList = context.getTensorList(idTensor.id); expect(tensorList.size()).toEqual(2); }); diff --git a/tfjs-converter/src/operations/op_list/control.ts b/tfjs-converter/src/operations/op_list/control.ts index 040a0ed6f0c..f1c3d3dfb71 100644 --- a/tfjs-converter/src/operations/op_list/control.ts +++ b/tfjs-converter/src/operations/op_list/control.ts @@ -91,7 +91,7 @@ export const json: OpMapper[] = [ 'tfOpName': 'TensorArrayWriteV3', 'category': 'control', 'inputs': [ - {'start': 0, 'name': 'tensorArrayId', 'type': 'number'}, + {'start': 0, 'name': 'tensorArrayId', 'type': 'tensor'}, {'start': 1, 'name': 'index', 'type': 'number'}, {'start': 2, 'name': 'tensor', 'type': 'tensor'}, {'start': 3, 'name': 'flowIn', 'type': 'number'}, @@ -104,7 +104,7 @@ export const json: OpMapper[] = [ 'tfOpName': 'TensorArrayReadV3', 'category': 'control', 'inputs': [ - {'start': 0, 'name': 'tensorArrayId', 'type': 'number'}, + {'start': 0, 'name': 'tensorArrayId', 'type': 'tensor'}, {'start': 1, 'name': 'index', 'type': 'number'}, {'start': 2, 'name': 'flowIn', 'type': 'number'}, ], @@ -119,7 +119,7 @@ export const json: OpMapper[] = [ 'tfOpName': 'TensorArrayGatherV3', 'category': 'control', 'inputs': [ - {'start': 0, 'name': 'tensorArrayId', 'type': 'number'}, + {'start': 0, 'name': 'tensorArrayId', 'type': 'tensor'}, {'start': 1, 'name': 'indices', 'type': 'number[]'}, {'start': 2, 'name': 'flowIn', 'type': 'number'}, ], @@ -132,7 +132,7 @@ export const json: OpMapper[] = [ 'tfOpName': 'TensorArrayScatterV3', 'category': 'control', 'inputs': [ - {'start': 0, 'name': 'tensorArrayId', 'type': 'number'}, + {'start': 0, 'name': 'tensorArrayId', 'type': 'tensor'}, {'start': 1, 'name': 'indices', 'type': 'number[]'}, {'start': 2, 'name': 'tensor', 'type': 'tensor'}, {'start': 3, 'name': 'flowIn', 'type': 'number'}, @@ -143,7 +143,7 @@ export const json: OpMapper[] = [ 'tfOpName': 'TensorArrayConcatV3', 'category': 'control', 'inputs': [ - {'start': 0, 'name': 'tensorArrayId', 'type': 'number'}, + {'start': 0, 'name': 'tensorArrayId', 'type': 'tensor'}, {'start': 1, 'name': 'flowIn', 'type': 'number'}, ], 'attrs': [ @@ -159,7 +159,7 @@ export const json: OpMapper[] = [ 'tfOpName': 'TensorArraySplitV3', 'category': 'control', 'inputs': [ - {'start': 0, 'name': 'tensorArrayId', 'type': 'number'}, + {'start': 0, 'name': 'tensorArrayId', 'type': 'tensor'}, {'start': 1, 'name': 'tensor', 'type': 'tensor'}, {'start': 2, 'name': 'lengths', 'type': 'number[]'}, {'start': 3, 'name': 'flowIn', 'type': 'number'}, @@ -170,14 +170,14 @@ export const json: OpMapper[] = [ 'tfOpName': 'TensorArraySizeV3', 'category': 'control', 'inputs': [ - {'start': 0, 'name': 'tensorArrayId', 'type': 'number'}, + {'start': 0, 'name': 'tensorArrayId', 'type': 'tensor'}, {'start': 1, 'name': 'flowIn', 'type': 'number'} ] }, { 'tfOpName': 'TensorArrayCloseV3', 'category': 'control', - 'inputs': [{'start': 0, 'name': 'tensorArrayId', 'type': 'number'}] + 'inputs': [{'start': 0, 'name': 'tensorArrayId', 'type': 'tensor'}] }, { 'tfOpName': 'StatelessIf', @@ -252,7 +252,7 @@ export const json: OpMapper[] = [ 'tfOpName': 'TensorListGather', 'category': 'control', 'inputs': [ - {'start': 0, 'name': 'tensorListId', 'type': 'number'}, + {'start': 0, 'name': 'tensorListId', 'type': 'tensor'}, {'start': 1, 'name': 'indices', 'type': 'number[]'}, {'start': 2, 'name': 'elementShape', 'type': 'shape'}, ], @@ -263,7 +263,7 @@ export const json: OpMapper[] = [ 'tfOpName': 'TensorListGetItem', 'category': 'control', 'inputs': [ - {'start': 0, 'name': 'tensorListId', 'type': 'number'}, + {'start': 0, 'name': 'tensorListId', 'type': 'tensor'}, {'start': 1, 'name': 'index', 'type': 'number'}, {'start': 2, 'name': 'elementShape', 'type': 'shape'}, ], @@ -274,7 +274,7 @@ export const json: OpMapper[] = [ 'tfOpName': 'TensorListSetItem', 'category': 'control', 'inputs': [ - {'start': 0, 'name': 'tensorListId', 'type': 'number'}, + {'start': 0, 'name': 'tensorListId', 'type': 'tensor'}, {'start': 1, 'name': 'index', 'type': 'number'}, {'start': 2, 'name': 'tensor', 'type': 'tensor'}, ], @@ -305,7 +305,7 @@ export const json: OpMapper[] = [ 'tfOpName': 'TensorListStack', 'category': 'control', 'inputs': [ - {'start': 0, 'name': 'tensorListId', 'type': 'number'}, + {'start': 0, 'name': 'tensorListId', 'type': 'tensor'}, {'start': 1, 'name': 'elementShape', 'type': 'shape'}, ], 'attrs': [ @@ -328,7 +328,7 @@ export const json: OpMapper[] = [ 'tfOpName': 'TensorListConcat', 'category': 'control', 'inputs': [ - {'start': 0, 'name': 'tensorListId', 'type': 'number'}, + {'start': 0, 'name': 'tensorListId', 'type': 'tensor'}, ], 'attrs': [ {'tfName': 'element_shape', 'name': 'elementShape', 'type': 'shape'}, @@ -339,7 +339,7 @@ export const json: OpMapper[] = [ 'tfOpName': 'TensorListPopBack', 'category': 'control', 'inputs': [ - {'start': 0, 'name': 'tensorListId', 'type': 'number'}, + {'start': 0, 'name': 'tensorListId', 'type': 'tensor'}, {'start': 1, 'name': 'elementShape', 'type': 'shape'}, ], 'attrs': @@ -349,7 +349,7 @@ export const json: OpMapper[] = [ 'tfOpName': 'TensorListPushBack', 'category': 'control', 'inputs': [ - {'start': 0, 'name': 'tensorListId', 'type': 'number'}, + {'start': 0, 'name': 'tensorListId', 'type': 'tensor'}, {'start': 1, 'name': 'tensor', 'type': 'tensor'}, ], 'attrs': From 3c4566f3d321b5be5cdffdb33df8043da369da7f Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Fri, 19 Jun 2020 14:53:08 -0700 Subject: [PATCH 2/3] revert the change --- tfjs-converter/src/executor/tensor_list.ts | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tfjs-converter/src/executor/tensor_list.ts b/tfjs-converter/src/executor/tensor_list.ts index f3b4900d821..eca03429f89 100644 --- a/tfjs-converter/src/executor/tensor_list.ts +++ b/tfjs-converter/src/executor/tensor_list.ts @@ -101,11 +101,11 @@ export class TensorList { } assertShapesMatchAllowUndefinedSize( elementShape, this.elementShape, 'TensorList shape mismatch: '); - // return tidy(() => { - // const reshapedTensors = - // this.tensors.map(tensor => tensor.reshape(elementShape)); - return stack(this.tensors, 0); - // }); + return tidy(() => { + const reshapedTensors = + this.tensors.map(tensor => tensor.reshape(elementShape)); + return stack(reshapedTensors, 0); + }); } /** From 3355079753e4a6ec59479a31051ced9bc2b37087 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Fri, 19 Jun 2020 15:42:24 -0700 Subject: [PATCH 3/3] add dtype and shape check for constructor --- tfjs-converter/src/executor/tensor_list.ts | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tfjs-converter/src/executor/tensor_list.ts b/tfjs-converter/src/executor/tensor_list.ts index eca03429f89..ea3e26e37d7 100644 --- a/tfjs-converter/src/executor/tensor_list.ts +++ b/tfjs-converter/src/executor/tensor_list.ts @@ -52,7 +52,18 @@ export class TensorList { constructor( readonly tensors: Tensor[], readonly elementShape: number[], readonly elementDtype: DataType, maxNumElements = -1) { - tensors.forEach(tensor => keep(tensor)); + if (tensors != null) { + tensors.forEach(tensor => { + if (elementDtype !== tensor.dtype) { + throw new Error(`Invalid data types; op elements ${ + elementDtype}, but list elements ${tensor.dtype}`); + } + assertShapesMatchAllowUndefinedSize( + elementShape, tensor.shape, 'TensorList shape mismatch: '); + + keep(tensor); + }); + } this.idTensor = scalar(0); this.maxNumElements = maxNumElements; keep(this.idTensor);