Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions tfjs-converter/src/operations/executors/control_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ import {scalar} from '@tensorflow/tfjs-core';
import {NamedTensorsMap} from '../../data/types';
import {ExecutionContext} from '../../executor/execution_context';
import {TensorArray} from '../../executor/tensor_array';
import {Node} from '../types';
import {InternalOpAsyncExecutor, Node} from '../types';

import {getParamValue, getTensor} from './utils';

export async function executeOp(
export const executeOp: InternalOpAsyncExecutor = async(
node: Node, tensorMap: NamedTensorsMap,
context: ExecutionContext): Promise<tfc.Tensor[]> {
context: ExecutionContext): Promise<tfc.Tensor[]> => {
switch (node.op) {
case 'LoopCond':
return [
Expand Down Expand Up @@ -161,6 +161,6 @@ export async function executeOp(
default:
throw TypeError(`Node type ${node.op} is not implemented`);
}
}
};

export const CATEGORY = 'control';
16 changes: 10 additions & 6 deletions tfjs-converter/src/operations/executors/dynamic_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ import * as tfc from '@tensorflow/tfjs-core';

import {NamedTensorsMap} from '../../data/types';
import {ExecutionContext} from '../../executor/execution_context';
import {Node} from '../types';
import {InternalOpAsyncExecutor, Node} from '../types';

import {getParamValue} from './utils';

export async function executeOp(
export const executeOp: InternalOpAsyncExecutor = async(
node: Node, tensorMap: NamedTensorsMap,
context: ExecutionContext): Promise<tfc.Tensor[]> {
context: ExecutionContext): Promise<tfc.Tensor[]> => {
switch (node.op) {
case 'NonMaxSuppressionV5':
case 'NonMaxSuppressionV3':
Expand Down Expand Up @@ -56,9 +57,12 @@ export async function executeOp(
iouThreshold, scoreThreshold)];
}
case 'Where': {
return [await tfc.whereAsync(
const condition =
(getParamValue('condition', node, tensorMap, context) as tfc.Tensor)
.asType('bool'))];
.asType('bool');
const result = [await tfc.whereAsync(condition)];
condition.dispose();
return result;
}
case 'ListDiff': {
return tfc.setdiff1dAsync(
Expand All @@ -68,6 +72,6 @@ export async function executeOp(
default:
throw TypeError(`Node type ${node.op} is not implemented`);
}
}
};

export const CATEGORY = 'dynamic';
11 changes: 11 additions & 0 deletions tfjs-converter/src/operations/executors/dynamic_executor_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,17 @@ describe('dynamic', () => {

expect(validateParam(node, dynamic.json)).toBeTruthy();
});
it('should not have memory leak', async () => {
node.op = 'Where';
node.inputParams = {'condition': createTensorAttr(0)};
const input1 = [tfc.scalar(1)];
spyOn(tfc, 'whereAsync').and.callThrough();

const prevCount = tfc.memory().numTensors;
await executeOp(node, {input1}, context);
const afterCount = tfc.memory().numTensors;
expect(afterCount).toEqual(prevCount + 1);
});
});

describe('ListDiff', () => {
Expand Down
36 changes: 22 additions & 14 deletions tfjs-converter/src/operations/operation_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,37 +52,45 @@ export function executeOp(
((node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext) => {
switch (node.category) {
case 'arithmetic':
return arithmetic.executeOp(node, tensorMap, context);
return tfc.tidy(
() => arithmetic.executeOp(node, tensorMap, context));
case 'basic_math':
return basicMath.executeOp(node, tensorMap, context);
return tfc.tidy(
() => basicMath.executeOp(node, tensorMap, context));
case 'control':
return control.executeOp(node, tensorMap, context);
case 'convolution':
return convolution.executeOp(node, tensorMap, context);
return tfc.tidy(
() => convolution.executeOp(node, tensorMap, context));
case 'creation':
return creation.executeOp(node, tensorMap, context);
return tfc.tidy(() => creation.executeOp(node, tensorMap, context));
case 'dynamic':
return dynamic.executeOp(node, tensorMap, context);
case 'evaluation':
return evaluation.executeOp(node, tensorMap, context);
return tfc.tidy(
() => evaluation.executeOp(node, tensorMap, context));
case 'image':
return image.executeOp(node, tensorMap, context);
return tfc.tidy(() => image.executeOp(node, tensorMap, context));
case 'graph':
return graph.executeOp(node, tensorMap, context);
return tfc.tidy(() => graph.executeOp(node, tensorMap, context));
case 'logical':
return logical.executeOp(node, tensorMap, context);
return tfc.tidy(() => logical.executeOp(node, tensorMap, context));
case 'matrices':
return matrices.executeOp(node, tensorMap, context);
return tfc.tidy(() => matrices.executeOp(node, tensorMap, context));
case 'normalization':
return normalization.executeOp(node, tensorMap, context);
return tfc.tidy(
() => normalization.executeOp(node, tensorMap, context));
case 'reduction':
return reduction.executeOp(node, tensorMap, context);
return tfc.tidy(
() => reduction.executeOp(node, tensorMap, context));
case 'slice_join':
return sliceJoin.executeOp(node, tensorMap, context);
return tfc.tidy(
() => sliceJoin.executeOp(node, tensorMap, context));
case 'spectral':
return spectral.executeOp(node, tensorMap, context);
return tfc.tidy(() => spectral.executeOp(node, tensorMap, context));
case 'transformation':
return transformation.executeOp(node, tensorMap, context);
return tfc.tidy(
() => transformation.executeOp(node, tensorMap, context));
case 'custom':
const opMapper = getRegisteredOp(node.op);
if (opMapper && opMapper.customExecutor) {
Expand Down
19 changes: 16 additions & 3 deletions tfjs-converter/src/operations/operation_executor_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
* =============================================================================
*/

import * as tfc from '@tensorflow/tfjs-core';
import {add, mul, scalar, Tensor, test_util} from '@tensorflow/tfjs-core';

import {ExecutionContext} from '../executor/execution_context';

import {deregisterOp, registerOp} from './custom_op/register';
import * as arithmetic from './executors/arithmetic_executor';
import * as basic_math from './executors/basic_math_executor';
import * as control from './executors/control_executor';
import * as convolution from './executors/convolution_executor';
import * as creation from './executors/creation_executor';
import * as dynamic from './executors/dynamic_executor';
Expand Down Expand Up @@ -56,9 +58,9 @@ describe('OperationExecutor', () => {
});

describe('executeOp', () => {
[arithmetic, basic_math, convolution, creation, dynamic, evaluation, image,
graph, logical, matrices, normalization, reduction, slice_join, spectral,
transformation]
[arithmetic, basic_math, convolution, control, creation, dynamic,
evaluation, image, graph, logical, matrices, normalization, reduction,
slice_join, spectral, transformation]
.forEach(category => {
it('should call ' + category.CATEGORY + ' executor', () => {
spyOn(category, 'executeOp');
Expand All @@ -67,6 +69,17 @@ describe('OperationExecutor', () => {
expect(category.executeOp).toHaveBeenCalledWith(node, {}, context);
});
});
[arithmetic, basic_math, convolution, creation, evaluation, image, graph,
logical, matrices, normalization, reduction, slice_join, spectral,
transformation]
.forEach(category => {
it('should call tidy around executor', () => {
spyOn(tfc, 'tidy');
node.category = category.CATEGORY;
executeOp(node, {}, context);
expect(tfc.tidy).toHaveBeenCalled();
});
});
});

describe('custom op executeOp', () => {
Expand Down
9 changes: 7 additions & 2 deletions tfjs-converter/src/operations/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,17 @@ export declare interface AttrParamMapper extends ParamMapper {

export interface InternalOpExecutor {
(node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext): Tensor
|Tensor[]|Promise<Tensor|Tensor[]>;
|Tensor[];
}

export interface InternalOpAsyncExecutor {
(node: Node, tensorMap: NamedTensorsMap,
context: ExecutionContext): Promise<Tensor[]>;
}

export declare interface OpMapper {
tfOpName: string;
category: Category;
category?: Category;
inputs?: InputParamMapper[];
attrs?: AttrParamMapper[];
customExecutor?: OpExecutor;
Expand Down