Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 19 additions & 15 deletions tfjs-converter/src/operations/executors/control_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import {TensorArray} from '../../executor/tensor_array';
import {fromTensor, reserve, scatter, split} from '../../executor/tensor_list';
import {InternalOpAsyncExecutor, Node} from '../types';

import {getParamValue, getTensor} from './utils';
import {cloneTensor, getParamValue, getTensor} from './utils';

export const executeOp: InternalOpAsyncExecutor = async(
node: Node, tensorMap: NamedTensorsMap,
Expand Down Expand Up @@ -106,44 +106,48 @@ export const executeOp: InternalOpAsyncExecutor = async(
return result;
}
case 'LoopCond': {
return [
(getParamValue('pred', node, tensorMap, context) as tfc.Tensor).clone()
];
const pred =
getParamValue('pred', node, tensorMap, context) as tfc.Tensor;
return [cloneTensor(pred)];
}
case 'Switch': {
const pred =
getParamValue('pred', node, tensorMap, context) as tfc.Tensor;
const data =
getParamValue('data', node, tensorMap, context) as tfc.Tensor;
let data = getParamValue('data', node, tensorMap, context) as tfc.Tensor;
if (!data.kept) {
data = cloneTensor(data);
}
// Outputs nodes :0 => false, :1 => true
return (await pred.data())[0] ? [undefined, data.clone()] :
[data.clone(), undefined];
return (await pred.data())[0] ? [undefined, data] : [data, undefined];
}
case 'Merge': {
const inputName = node.inputNames.find(
name => getTensor(name, tensorMap, context) !== undefined);
return inputName ? [getTensor(inputName, tensorMap, context).clone()] :
undefined;
if (inputName) {
const data = getTensor(inputName, tensorMap, context);
return [cloneTensor(data)];
}
return undefined;
}
case 'Enter': {
const frameId =
getParamValue('frameName', node, tensorMap, context) as string;
const data =
getParamValue('tensor', node, tensorMap, context) as tfc.Tensor;
context.enterFrame(frameId);
return [data.clone()];
return [cloneTensor(data)];
}
case 'Exit': {
const tensor =
const data =
getParamValue('tensor', node, tensorMap, context) as tfc.Tensor;
context.exitFrame();
return [tensor.clone()];
return [cloneTensor(data)];
}
case 'NextIteration': {
const input =
const data =
getParamValue('tensor', node, tensorMap, context) as tfc.Tensor;
context.nextIteration();
return [input.clone()];
return [cloneTensor(data)];
}
case 'TensorArrayV3': {
const size = getParamValue('size', node, tensorMap, context) as number;
Expand Down
20 changes: 10 additions & 10 deletions tfjs-converter/src/operations/executors/graph_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ import {NamedTensorsMap} from '../../data/types';
import {ExecutionContext} from '../../executor/execution_context';
import {InternalOpExecutor, Node} from '../types';

import {getParamValue, getTensor} from './utils';
import {cloneTensor, getParamValue, getTensor} from './utils';

export const executeOp: InternalOpExecutor = (node: Node,
tensorMap: NamedTensorsMap,
context: ExecutionContext):
tfc.Tensor[] => {
tensorMap: NamedTensorsMap,
context: ExecutionContext):
tfc.Tensor[] => {
switch (node.op) {
case 'Const': {
return tensorMap[node.name];
Expand All @@ -39,17 +39,17 @@ export const executeOp: InternalOpExecutor = (node: Node,
return [getTensor(node.name, tensorMap, context)];
case 'Identity':
case 'StopGradient':
case 'FakeQuantWithMinMaxVars': // This op is currently ignored.
return [
(getParamValue('x', node, tensorMap, context) as tfc.Tensor).clone()
];
case 'FakeQuantWithMinMaxVars': { // This op is currently ignored.
const data = getParamValue('x', node, tensorMap, context) as tfc.Tensor;
return [cloneTensor(data)];
}
case 'IdentityN':
return (getParamValue('x', node, tensorMap, context) as tfc.Tensor[])
.map((t: tfc.Tensor) => t.clone());
.map((t: tfc.Tensor) => cloneTensor(t));
case 'Snapshot':
const snapshot =
(getParamValue('x', node, tensorMap, context) as tfc.Tensor);
return [snapshot.clone()];
return [cloneTensor(snapshot)];
case 'Shape':
return [tfc.tensor1d(
(getParamValue('x', node, tensorMap, context) as tfc.Tensor).shape,
Expand Down
13 changes: 13 additions & 0 deletions tfjs-converter/src/operations/executors/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,16 @@ export function getPadding(
}
return pad;
}

/**
* Reuse the tensor if it is marked as keep, otherwise clone the tensor to
* avoid disposal. This is important for TensorArray and TensorList ops, since
* internally they use a tensor as the id for TensorArray and TensorList, and
* to simplify lookup, they also use Tensor.id as the key to the internal map.
* These id tensors have been marked as kept in the backend, we need avoid clone
* them in order to create new Tensor.id.
* @param tensor
*/
export function cloneTensor(tensor: tfc.Tensor): tfc.Tensor {
return tensor.kept ? tensor : tfc.clone(tensor);
}