119 changes: 119 additions & 0 deletions e2e/integration_tests/backends_test.ts
@@ -0,0 +1,119 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import '@tensorflow/tfjs-backend-cpu';
import '@tensorflow/tfjs-backend-webgl';

import * as tfc from '@tensorflow/tfjs-core';

import {SMOKE} from './constants';

/**
 * This file tests the backend switching scenario.
*/
describe(`${SMOKE} backends`, () => {
describe('switch', () => {
beforeAll(() => {
tfc.env().set('WEBGL_CPU_FORWARD', false);
});

it(`from webgl to cpu.`, async () => {
await tfc.setBackend('webgl');

const webglBefore = tfc.engine().backend.numDataIds();

const input = tfc.tensor2d([1, 1, 1, 1], [2, 2], 'float32');
// input is stored in webgl backend.

const inputReshaped = tfc.reshape(input, [2, 2]);

const webglAfter = tfc.engine().backend.numDataIds();

expect(webglAfter).toEqual(webglBefore + 1);

await tfc.setBackend('cpu');

const cpuBefore = tfc.engine().backend.numDataIds();

const inputReshaped2 = tfc.reshape(inputReshaped, [2, 2]);
// input moved to cpu.

// Because input is moved to cpu, data should be deleted from webgl.
expect(tfc.findBackend('webgl').numDataIds()).toEqual(webglAfter - 1);

const cpuAfter = tfc.engine().backend.numDataIds();

expect(cpuAfter).toEqual(cpuBefore + 1);

input.dispose();
expect(tfc.engine().backend.numDataIds()).toEqual(cpuAfter);

inputReshaped.dispose();

expect(tfc.engine().backend.numDataIds()).toEqual(cpuAfter);

inputReshaped2.dispose();

const after = tfc.engine().backend.numDataIds();

expect(after).toBe(cpuBefore);
});

it(`from cpu to webgl.`, async () => {
await tfc.setBackend('cpu');

const cpuBefore = tfc.engine().backend.numDataIds();

const input = tfc.tensor2d([1, 1, 1, 1], [2, 2], 'float32');
// input is stored in cpu backend.

const inputReshaped = tfc.reshape(input, [2, 2]);

const cpuAfter = tfc.engine().backend.numDataIds();

expect(cpuAfter).toEqual(cpuBefore + 1);

await tfc.setBackend('webgl');

const webglBefore = tfc.engine().backend.numDataIds();

const inputReshaped2 = tfc.reshape(inputReshaped, [2, 2]);
// input moved to webgl.

// Because input is moved to webgl, data should be deleted from cpu.
expect(tfc.findBackend('cpu').numDataIds()).toEqual(cpuAfter - 1);

const webglAfter = tfc.engine().backend.numDataIds();

expect(webglAfter).toEqual(webglBefore + 1);

input.dispose();

expect(tfc.engine().backend.numDataIds()).toEqual(webglAfter);

inputReshaped.dispose();

expect(tfc.engine().backend.numDataIds()).toEqual(webglAfter);

inputReshaped2.dispose();

const after = tfc.engine().backend.numDataIds();

expect(after).toBe(webglBefore);
});
});
});
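Note on what the test above exercises: tfjs-core migrates a tensor's data lazily, so after `setBackend` the values stay on the old backend until an op reads them on the new one. Below is a minimal stand-alone sketch of that behavior (not part of this PR; it assumes the cpu and webgl backend packages are installed):

import '@tensorflow/tfjs-backend-cpu';
import '@tensorflow/tfjs-backend-webgl';
import * as tfc from '@tensorflow/tfjs-core';

async function demoBackendSwitch() {
  await tfc.setBackend('webgl');
  const a = tfc.tensor2d([1, 1, 1, 1], [2, 2]);  // data is allocated on webgl

  await tfc.setBackend('cpu');
  // Reading `a` on the cpu backend moves its data off webgl.
  const b = tfc.reshape(a, [4]);

  console.log('webgl dataIds:', tfc.findBackend('webgl').numDataIds());
  console.log('cpu dataIds:', tfc.findBackend('cpu').numDataIds());

  a.dispose();
  b.dispose();
}

demoBackendSwitch();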
101 changes: 101 additions & 0 deletions e2e/integration_tests/memory_leak_test.ts
@@ -0,0 +1,101 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import '@tensorflow/tfjs-backend-cpu';

import * as tfconverter from '@tensorflow/tfjs-converter';
import * as tfc from '@tensorflow/tfjs-core';

import {SMOKE} from './constants';

const HOST = 'http://example.org';
const MODEL_URL = `${HOST}/model.json`;

const CUSTOM_OP_MODEL = {
node: [
{
name: 'Input',
op: 'Placeholder',
attr: {
dtype: {
type: 1, // DT_FLOAT
},
shape: {shape: {dim: [{size: 4}]}}
}
},
{name: 'CustomOp', op: 'CustomOp', input: ['Input'], attr: {}}
],
versions: {producer: 1.0, minConsumer: 3}
};

const weightsManifest: tfc.io.WeightsManifestEntry[] =
[{'name': 'Const', 'dtype': 'float32', 'shape': [1]}];

const bias = tfc.tensor1d([0], 'float32');

const CUSTOM_HTTP_MODEL_LOADER = {
load: async () => {
return {
modelTopology: CUSTOM_OP_MODEL,
weightSpecs: weightsManifest,
weightData: bias.dataSync(),
format: 'tfjs-graph-model',
generatedBy: '1.15',
convertedBy: '1.3.1'
};
}
};

describe(
`${SMOKE} A custom op that calls unmodularized kernels and modularized ` +
`kernels`,
() => {
it('should have no memory leak in a model run.', async () => {
const model = new tfconverter.GraphModel(MODEL_URL);

spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
CUSTOM_HTTP_MODEL_LOADER
]);

// A custom op that calls unmodularized kernels and modularized kernels.
tfconverter.registerOp('CustomOp', (nodeValue) => {
const x = nodeValue.inputs[0];
const softMax = tfc.softmax(x);
const clone = tfc.clone(softMax);
return [tfc.reshape(clone, [2, 2])];
});

await model.load();

const before = tfc.memory().numTensors;

const input = tfc.tensor1d([1, 2, 3, 4]);
const output = model.predict(input) as tfc.Tensor;

tfc.test_util.expectArraysClose(await output.data(), [
0.032058604061603546, 0.08714432269334793, 0.23688283562660217,
0.6439142823219299
]);

input.dispose();
output.dispose();

const after = tfc.memory().numTensors;

expect(after).toEqual(before);
});
});
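The leak check in this test follows the usual tfjs pattern: snapshot `tfc.memory().numTensors`, run the computation, dispose everything that was created, and compare the counts. A stand-alone sketch of the same pattern using `tfc.tidy` (illustrative only, not part of this PR):

import '@tensorflow/tfjs-backend-cpu';
import * as tfc from '@tensorflow/tfjs-core';

function checkNoLeak(run: () => void): void {
  const before = tfc.memory().numTensors;
  // tidy() disposes every tensor created inside the callback that is not returned.
  tfc.tidy(() => {
    run();
  });
  const after = tfc.memory().numTensors;
  if (after !== before) {
    throw new Error(`Leaked ${after - before} tensor(s)`);
  }
}

// Same op mix as the custom op in the test: softmax -> clone -> reshape.
checkNoLeak(() => {
  const x = tfc.tensor1d([1, 2, 3, 4]);
  tfc.reshape(tfc.clone(tfc.softmax(x)), [2, 2]);
});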
95 changes: 37 additions & 58 deletions tfjs-backend-cpu/src/backend_cpu.ts
@@ -16,10 +16,7 @@
*/

import * as tf from '@tensorflow/tfjs-core';
import {engine, env} from '@tensorflow/tfjs-core';
import {backend_util, buffer, slice_util, util} from '@tensorflow/tfjs-core';
import {BackendTimingInfo, DataStorage, DataType, DataValues, KernelBackend, max, NumericDataType, Rank, reshape, Scalar, ShapeMap, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, TypedArray, upcastType} from '@tensorflow/tfjs-core';
import {kernel_impls} from '@tensorflow/tfjs-core';
import {backend_util, BackendTimingInfo, buffer, DataStorage, DataType, DataValues, engine, env, kernel_impls, KernelBackend, max, NumericDataType, Rank, Scalar, ShapeMap, slice_util, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, TensorInfo, TypedArray, upcastType, util} from '@tensorflow/tfjs-core';

const nonMaxSuppressionV3Impl = kernel_impls.nonMaxSuppressionV3Impl;
const split = kernel_impls.split;
@@ -59,6 +56,9 @@ export interface TensorData<D extends DataType> {
// TODO(smilkov): Replace Tensor with TensorInfo when you modularize ops
// that work with complex tensors.
complexTensors?: {real: Tensor, imag: Tensor};
// refCount keeps track of how many tensors reference it. Used for memory
// management.
refCount: number;
}

export class MathBackendCPU extends KernelBackend {
@@ -91,14 +91,30 @@ export class MathBackendCPU extends KernelBackend {
}
}
const dataId = {};
this.data.set(dataId, {values, dtype});

this.data.set(dataId, {values, dtype, refCount: 1});

return dataId;
}

/** Increase refCount of a `TensorData`. */
incRef(dataId: DataId): void {
const tensorData = this.data.get(dataId);
tensorData.refCount++;
}

/** Decrease refCount of a `TensorData`. */
decRef(dataId: DataId): void {
if (this.data.has(dataId)) {
const tensorData = this.data.get(dataId);
tensorData.refCount--;
}
}

move(
dataId: DataId, values: backend_util.BackendValues, shape: number[],
dtype: DataType): void {
this.data.set(dataId, {values, dtype});
this.data.set(dataId, {values, dtype, refCount: 1});
}

numDataIds(): number {
@@ -151,6 +167,20 @@ export class MathBackendCPU extends KernelBackend {
}
}

disposeIntermediateTensorInfo(tensorInfo: TensorInfo): void {
const dataId = tensorInfo.dataId;

if (this.data.has(dataId)) {
const tensorData = this.data.get(dataId);

tensorData.refCount--;

if (tensorData.refCount < 1) {
this.disposeData(dataId);
}
}
}

async time(f: () => void): Promise<BackendTimingInfo> {
const start = util.now();
f();
@@ -464,6 +494,7 @@ export class MathBackendCPU extends KernelBackend {
mapActivation(this, result, activation, preluActivationWeights) as
Tensor3D;
}

return result;
}

@@ -2061,27 +2092,6 @@ export class MathBackendCPU extends KernelBackend {
return tile(this.bufferSync(x), reps) as T;
}

pad<T extends Tensor>(
x: T, paddings: Array<[number, number]>, constantValue: number): T {
assertNotComplex(x, 'pad');

const outShape = paddings.map(
(p, i) => p[0] /* beforePad */ + x.shape[i] + p[1] /* afterPad */);
const start = paddings.map(p => p[0]);
const xBuffer = this.bufferSync(x);
const buffer = tf.buffer(outShape, x.dtype as 'float32');
if (constantValue !== 0) {
buffer.values.fill(constantValue);
}

for (let i = 0; i < x.size; i++) {
const coords = xBuffer.indexToLoc(i);
const outCoords = coords.map((c, i) => c + start[i]);
buffer.set(xBuffer.get(...coords), ...outCoords);
}
return buffer.toTensor() as T;
}

gather<T extends Tensor>(x: T, indices: Tensor1D, axis: number): T {
assertNotComplex([x, indices], 'gather');

@@ -2124,33 +2134,6 @@ export class MathBackendCPU extends KernelBackend {
.slice(sliceBeginCoords, sliceSize) as T;
}

spaceToBatchND<T extends Tensor>(
x: T, blockShape: number[], paddings: Array<[number, number]>): T {
assertNotComplex([x], 'spaceToBatchND');

const prod = blockShape.reduce((a, b) => a * b);

const completePaddings: Array<[number, number]> = [[0, 0]];
completePaddings.push(...paddings);
for (let i = 1 + blockShape.length; i < x.shape.length; ++i) {
completePaddings.push([0, 0]);
}

const paddedX = x.pad(completePaddings);

const reshapedPaddedShape =
backend_util.getReshaped(paddedX.shape, blockShape, prod, false);
const permutedReshapedPaddedPermutation = backend_util.getPermuted(
reshapedPaddedShape.length, blockShape.length, false);
const flattenShape = backend_util.getReshapedPermuted(
paddedX.shape, blockShape, prod, false);

const paddedXT = tf.transpose(
paddedX.reshape(reshapedPaddedShape),
permutedReshapedPaddedPermutation);
return reshape(paddedXT, flattenShape) as T;
}

maxPool(x: Tensor4D, convInfo: backend_util.Conv2DInfo): Tensor4D {
assertNotComplex(x, 'maxPool');
const xValues = this.readSync(x.dataId) as TypedArray;
@@ -2631,10 +2614,6 @@ export class MathBackendCPU extends KernelBackend {
return backend_util.castTensor(x, dtype, this);
}

reshape<R extends Rank>(x: Tensor, shape: ShapeMap[R]): Tensor<R> {
return backend_util.reshapeTensor(x, shape);
}

avgPool(x: Tensor4D, convInfo: backend_util.Conv2DInfo): Tensor4D {
assertNotComplex(x, 'avgPool');
assertNotComplex(x, 'maxPool');
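The `refCount` field, `incRef`/`decRef`, and `disposeIntermediateTensorInfo` added above give modularized kernels a way to release intermediates they create directly on the backend. A hedged sketch of how a kernel body might use them (the kernel itself is illustrative and not part of this PR; it assumes the file lives alongside backend_cpu.ts):

import {TensorInfo} from '@tensorflow/tfjs-core';

import {MathBackendCPU} from './backend_cpu';

// Illustrative kernel body: computes -(x * x) and releases the squared
// intermediate through the backend's new refCount machinery.
export function negSquare(x: TensorInfo, backend: MathBackendCPU): TensorInfo {
  const xVals = backend.readSync(x.dataId) as Float32Array;

  // Intermediate written straight to the backend; write() starts refCount at 1.
  const squaredVals = xVals.map(v => v * v);
  const squared: TensorInfo = {
    dataId: backend.write(squaredVals, x.shape, 'float32'),
    shape: x.shape,
    dtype: 'float32'
  };

  const outVals = (backend.readSync(squared.dataId) as Float32Array).map(v => -v);
  const out: TensorInfo = {
    dataId: backend.write(outVals, x.shape, 'float32'),
    shape: x.shape,
    dtype: 'float32'
  };

  // Drop the kernel's reference to the intermediate: this decrements refCount
  // and frees the underlying values once nothing else points at them.
  backend.disposeIntermediateTensorInfo(squared);
  return out;
}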