diff --git a/tfjs-converter/src/operations/executors/control_executor.ts b/tfjs-converter/src/operations/executors/control_executor.ts index 163b5e6853f..ab44e85ea29 100644 --- a/tfjs-converter/src/operations/executors/control_executor.ts +++ b/tfjs-converter/src/operations/executors/control_executor.ts @@ -21,13 +21,13 @@ import {scalar} from '@tensorflow/tfjs-core'; import {NamedTensorsMap} from '../../data/types'; import {ExecutionContext} from '../../executor/execution_context'; import {TensorArray} from '../../executor/tensor_array'; -import {Node} from '../types'; +import {InternalOpAsyncExecutor, Node} from '../types'; import {getParamValue, getTensor} from './utils'; -export async function executeOp( +export const executeOp: InternalOpAsyncExecutor = async( node: Node, tensorMap: NamedTensorsMap, - context: ExecutionContext): Promise { + context: ExecutionContext): Promise => { switch (node.op) { case 'LoopCond': return [ @@ -161,6 +161,6 @@ export async function executeOp( default: throw TypeError(`Node type ${node.op} is not implemented`); } -} +}; export const CATEGORY = 'control'; diff --git a/tfjs-converter/src/operations/executors/dynamic_executor.ts b/tfjs-converter/src/operations/executors/dynamic_executor.ts index 57ab29cc7aa..e6c0ab7c4f3 100644 --- a/tfjs-converter/src/operations/executors/dynamic_executor.ts +++ b/tfjs-converter/src/operations/executors/dynamic_executor.ts @@ -19,12 +19,13 @@ import * as tfc from '@tensorflow/tfjs-core'; import {NamedTensorsMap} from '../../data/types'; import {ExecutionContext} from '../../executor/execution_context'; -import {Node} from '../types'; +import {InternalOpAsyncExecutor, Node} from '../types'; + import {getParamValue} from './utils'; -export async function executeOp( +export const executeOp: InternalOpAsyncExecutor = async( node: Node, tensorMap: NamedTensorsMap, - context: ExecutionContext): Promise { + context: ExecutionContext): Promise => { switch (node.op) { case 'NonMaxSuppressionV5': case 'NonMaxSuppressionV3': @@ -56,9 +57,12 @@ export async function executeOp( iouThreshold, scoreThreshold)]; } case 'Where': { - return [await tfc.whereAsync( + const condition = (getParamValue('condition', node, tensorMap, context) as tfc.Tensor) - .asType('bool'))]; + .asType('bool'); + const result = [await tfc.whereAsync(condition)]; + condition.dispose(); + return result; } case 'ListDiff': { return tfc.setdiff1dAsync( @@ -68,6 +72,6 @@ export async function executeOp( default: throw TypeError(`Node type ${node.op} is not implemented`); } -} +}; export const CATEGORY = 'dynamic'; diff --git a/tfjs-converter/src/operations/executors/dynamic_executor_test.ts b/tfjs-converter/src/operations/executors/dynamic_executor_test.ts index 4663423178a..fdf573b58f9 100644 --- a/tfjs-converter/src/operations/executors/dynamic_executor_test.ts +++ b/tfjs-converter/src/operations/executors/dynamic_executor_test.ts @@ -171,6 +171,17 @@ describe('dynamic', () => { expect(validateParam(node, dynamic.json)).toBeTruthy(); }); + it('should not have memory leak', async () => { + node.op = 'Where'; + node.inputParams = {'condition': createTensorAttr(0)}; + const input1 = [tfc.scalar(1)]; + spyOn(tfc, 'whereAsync').and.callThrough(); + + const prevCount = tfc.memory().numTensors; + await executeOp(node, {input1}, context); + const afterCount = tfc.memory().numTensors; + expect(afterCount).toEqual(prevCount + 1); + }); }); describe('ListDiff', () => { diff --git a/tfjs-converter/src/operations/operation_executor.ts b/tfjs-converter/src/operations/operation_executor.ts index a4698a8dc8b..3f18b79ec59 100644 --- a/tfjs-converter/src/operations/operation_executor.ts +++ b/tfjs-converter/src/operations/operation_executor.ts @@ -52,37 +52,45 @@ export function executeOp( ((node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext) => { switch (node.category) { case 'arithmetic': - return arithmetic.executeOp(node, tensorMap, context); + return tfc.tidy( + () => arithmetic.executeOp(node, tensorMap, context)); case 'basic_math': - return basicMath.executeOp(node, tensorMap, context); + return tfc.tidy( + () => basicMath.executeOp(node, tensorMap, context)); case 'control': return control.executeOp(node, tensorMap, context); case 'convolution': - return convolution.executeOp(node, tensorMap, context); + return tfc.tidy( + () => convolution.executeOp(node, tensorMap, context)); case 'creation': - return creation.executeOp(node, tensorMap, context); + return tfc.tidy(() => creation.executeOp(node, tensorMap, context)); case 'dynamic': return dynamic.executeOp(node, tensorMap, context); case 'evaluation': - return evaluation.executeOp(node, tensorMap, context); + return tfc.tidy( + () => evaluation.executeOp(node, tensorMap, context)); case 'image': - return image.executeOp(node, tensorMap, context); + return tfc.tidy(() => image.executeOp(node, tensorMap, context)); case 'graph': - return graph.executeOp(node, tensorMap, context); + return tfc.tidy(() => graph.executeOp(node, tensorMap, context)); case 'logical': - return logical.executeOp(node, tensorMap, context); + return tfc.tidy(() => logical.executeOp(node, tensorMap, context)); case 'matrices': - return matrices.executeOp(node, tensorMap, context); + return tfc.tidy(() => matrices.executeOp(node, tensorMap, context)); case 'normalization': - return normalization.executeOp(node, tensorMap, context); + return tfc.tidy( + () => normalization.executeOp(node, tensorMap, context)); case 'reduction': - return reduction.executeOp(node, tensorMap, context); + return tfc.tidy( + () => reduction.executeOp(node, tensorMap, context)); case 'slice_join': - return sliceJoin.executeOp(node, tensorMap, context); + return tfc.tidy( + () => sliceJoin.executeOp(node, tensorMap, context)); case 'spectral': - return spectral.executeOp(node, tensorMap, context); + return tfc.tidy(() => spectral.executeOp(node, tensorMap, context)); case 'transformation': - return transformation.executeOp(node, tensorMap, context); + return tfc.tidy( + () => transformation.executeOp(node, tensorMap, context)); case 'custom': const opMapper = getRegisteredOp(node.op); if (opMapper && opMapper.customExecutor) { diff --git a/tfjs-converter/src/operations/operation_executor_test.ts b/tfjs-converter/src/operations/operation_executor_test.ts index 21b0242129d..b3601893778 100644 --- a/tfjs-converter/src/operations/operation_executor_test.ts +++ b/tfjs-converter/src/operations/operation_executor_test.ts @@ -15,6 +15,7 @@ * ============================================================================= */ +import * as tfc from '@tensorflow/tfjs-core'; import {add, mul, scalar, Tensor, test_util} from '@tensorflow/tfjs-core'; import {ExecutionContext} from '../executor/execution_context'; @@ -22,6 +23,7 @@ import {ExecutionContext} from '../executor/execution_context'; import {deregisterOp, registerOp} from './custom_op/register'; import * as arithmetic from './executors/arithmetic_executor'; import * as basic_math from './executors/basic_math_executor'; +import * as control from './executors/control_executor'; import * as convolution from './executors/convolution_executor'; import * as creation from './executors/creation_executor'; import * as dynamic from './executors/dynamic_executor'; @@ -56,9 +58,9 @@ describe('OperationExecutor', () => { }); describe('executeOp', () => { - [arithmetic, basic_math, convolution, creation, dynamic, evaluation, image, - graph, logical, matrices, normalization, reduction, slice_join, spectral, - transformation] + [arithmetic, basic_math, convolution, control, creation, dynamic, + evaluation, image, graph, logical, matrices, normalization, reduction, + slice_join, spectral, transformation] .forEach(category => { it('should call ' + category.CATEGORY + ' executor', () => { spyOn(category, 'executeOp'); @@ -67,6 +69,17 @@ describe('OperationExecutor', () => { expect(category.executeOp).toHaveBeenCalledWith(node, {}, context); }); }); + [arithmetic, basic_math, convolution, creation, evaluation, image, graph, + logical, matrices, normalization, reduction, slice_join, spectral, + transformation] + .forEach(category => { + it('should call tidy around executor', () => { + spyOn(tfc, 'tidy'); + node.category = category.CATEGORY; + executeOp(node, {}, context); + expect(tfc.tidy).toHaveBeenCalled(); + }); + }); }); describe('custom op executeOp', () => { diff --git a/tfjs-converter/src/operations/types.ts b/tfjs-converter/src/operations/types.ts index 3a0bef0dfc4..cb5bf5ced69 100644 --- a/tfjs-converter/src/operations/types.ts +++ b/tfjs-converter/src/operations/types.ts @@ -75,12 +75,17 @@ export declare interface AttrParamMapper extends ParamMapper { export interface InternalOpExecutor { (node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext): Tensor - |Tensor[]|Promise; + |Tensor[]; +} + +export interface InternalOpAsyncExecutor { + (node: Node, tensorMap: NamedTensorsMap, + context: ExecutionContext): Promise; } export declare interface OpMapper { tfOpName: string; - category: Category; + category?: Category; inputs?: InputParamMapper[]; attrs?: AttrParamMapper[]; customExecutor?: OpExecutor;