diff --git a/tfjs-converter/src/operations/executors/control_executor.ts b/tfjs-converter/src/operations/executors/control_executor.ts index 7effe6f6547..5decb09c4c2 100644 --- a/tfjs-converter/src/operations/executors/control_executor.ts +++ b/tfjs-converter/src/operations/executors/control_executor.ts @@ -24,7 +24,7 @@ import {TensorArray} from '../../executor/tensor_array'; import {fromTensor, reserve, scatter, split} from '../../executor/tensor_list'; import {InternalOpAsyncExecutor, Node} from '../types'; -import {getParamValue, getTensor} from './utils'; +import {cloneTensor, getParamValue, getTensor} from './utils'; export const executeOp: InternalOpAsyncExecutor = async( node: Node, tensorMap: NamedTensorsMap, @@ -106,24 +106,28 @@ export const executeOp: InternalOpAsyncExecutor = async( return result; } case 'LoopCond': { - return [ - (getParamValue('pred', node, tensorMap, context) as tfc.Tensor).clone() - ]; + const pred = + getParamValue('pred', node, tensorMap, context) as tfc.Tensor; + return [cloneTensor(pred)]; } case 'Switch': { const pred = getParamValue('pred', node, tensorMap, context) as tfc.Tensor; - const data = - getParamValue('data', node, tensorMap, context) as tfc.Tensor; + let data = getParamValue('data', node, tensorMap, context) as tfc.Tensor; + if (!data.kept) { + data = cloneTensor(data); + } // Outputs nodes :0 => false, :1 => true - return (await pred.data())[0] ? [undefined, data.clone()] : - [data.clone(), undefined]; + return (await pred.data())[0] ? [undefined, data] : [data, undefined]; } case 'Merge': { const inputName = node.inputNames.find( name => getTensor(name, tensorMap, context) !== undefined); - return inputName ? [getTensor(inputName, tensorMap, context).clone()] : - undefined; + if (inputName) { + const data = getTensor(inputName, tensorMap, context); + return [cloneTensor(data)]; + } + return undefined; } case 'Enter': { const frameId = @@ -131,19 +135,19 @@ export const executeOp: InternalOpAsyncExecutor = async( const data = getParamValue('tensor', node, tensorMap, context) as tfc.Tensor; context.enterFrame(frameId); - return [data.clone()]; + return [cloneTensor(data)]; } case 'Exit': { - const tensor = + const data = getParamValue('tensor', node, tensorMap, context) as tfc.Tensor; context.exitFrame(); - return [tensor.clone()]; + return [cloneTensor(data)]; } case 'NextIteration': { - const input = + const data = getParamValue('tensor', node, tensorMap, context) as tfc.Tensor; context.nextIteration(); - return [input.clone()]; + return [cloneTensor(data)]; } case 'TensorArrayV3': { const size = getParamValue('size', node, tensorMap, context) as number; diff --git a/tfjs-converter/src/operations/executors/graph_executor.ts b/tfjs-converter/src/operations/executors/graph_executor.ts index c788577f89f..fc7b720fc60 100644 --- a/tfjs-converter/src/operations/executors/graph_executor.ts +++ b/tfjs-converter/src/operations/executors/graph_executor.ts @@ -21,12 +21,12 @@ import {NamedTensorsMap} from '../../data/types'; import {ExecutionContext} from '../../executor/execution_context'; import {InternalOpExecutor, Node} from '../types'; -import {getParamValue, getTensor} from './utils'; +import {cloneTensor, getParamValue, getTensor} from './utils'; export const executeOp: InternalOpExecutor = (node: Node, - tensorMap: NamedTensorsMap, - context: ExecutionContext): - tfc.Tensor[] => { + tensorMap: NamedTensorsMap, + context: ExecutionContext): + tfc.Tensor[] => { switch (node.op) { case 'Const': { return tensorMap[node.name]; @@ -39,17 +39,17 @@ export const executeOp: InternalOpExecutor = (node: Node, return [getTensor(node.name, tensorMap, context)]; case 'Identity': case 'StopGradient': - case 'FakeQuantWithMinMaxVars': // This op is currently ignored. - return [ - (getParamValue('x', node, tensorMap, context) as tfc.Tensor).clone() - ]; + case 'FakeQuantWithMinMaxVars': { // This op is currently ignored. + const data = getParamValue('x', node, tensorMap, context) as tfc.Tensor; + return [cloneTensor(data)]; + } case 'IdentityN': return (getParamValue('x', node, tensorMap, context) as tfc.Tensor[]) - .map((t: tfc.Tensor) => t.clone()); + .map((t: tfc.Tensor) => cloneTensor(t)); case 'Snapshot': const snapshot = (getParamValue('x', node, tensorMap, context) as tfc.Tensor); - return [snapshot.clone()]; + return [cloneTensor(snapshot)]; case 'Shape': return [tfc.tensor1d( (getParamValue('x', node, tensorMap, context) as tfc.Tensor).shape, diff --git a/tfjs-converter/src/operations/executors/utils.ts b/tfjs-converter/src/operations/executors/utils.ts index 55af5eafd52..97fae50f1b8 100644 --- a/tfjs-converter/src/operations/executors/utils.ts +++ b/tfjs-converter/src/operations/executors/utils.ts @@ -136,3 +136,16 @@ export function getPadding( } return pad; } + +/** + * Reuse the tensor if it is marked as keep, otherwise clone the tensor to + * avoid disposal. This is important for TensorArray and TensorList ops, since + * internally they use a tensor as the id for TensorArray and TensorList, and + * to simplify lookup, they also use Tensor.id as the key to the internal map. + * These id tensors have been marked as kept in the backend, we need avoid clone + * them in order to create new Tensor.id. + * @param tensor + */ +export function cloneTensor(tensor: tfc.Tensor): tfc.Tensor { + return tensor.kept ? tensor : tfc.clone(tensor); +}