diff --git a/e2e/integration_tests/backends_test.ts b/e2e/integration_tests/backends_test.ts new file mode 100644 index 00000000000..139ede3ea84 --- /dev/null +++ b/e2e/integration_tests/backends_test.ts @@ -0,0 +1,119 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import '@tensorflow/tfjs-backend-cpu'; +import '@tensorflow/tfjs-backend-webgl'; + +import * as tfc from '@tensorflow/tfjs-core'; + +import {SMOKE} from './constants'; + +/** + * This file tests backend switching scenario. + */ +describe(`${SMOKE} backends`, () => { + describe('switch', () => { + beforeAll(() => { + tfc.env().set('WEBGL_CPU_FORWARD', false); + }); + + it(`from webgl to cpu.`, async () => { + await tfc.setBackend('webgl'); + + const webglBefore = tfc.engine().backend.numDataIds(); + + const input = tfc.tensor2d([1, 1, 1, 1], [2, 2], 'float32'); + // input is stored in webgl backend. + + const inputReshaped = tfc.reshape(input, [2, 2]); + + const webglAfter = tfc.engine().backend.numDataIds(); + + expect(webglAfter).toEqual(webglBefore + 1); + + await tfc.setBackend('cpu'); + + const cpuBefore = tfc.engine().backend.numDataIds(); + + const inputReshaped2 = tfc.reshape(inputReshaped, [2, 2]); + // input moved to cpu. + + // Because input is moved to cpu, data should be deleted from webgl. + expect(tfc.findBackend('webgl').numDataIds()).toEqual(webglAfter - 1); + + const cpuAfter = tfc.engine().backend.numDataIds(); + + expect(cpuAfter).toEqual(cpuBefore + 1); + + input.dispose(); + expect(tfc.engine().backend.numDataIds()).toEqual(cpuAfter); + + inputReshaped.dispose(); + + expect(tfc.engine().backend.numDataIds()).toEqual(cpuAfter); + + inputReshaped2.dispose(); + + const after = tfc.engine().backend.numDataIds(); + + expect(after).toBe(cpuBefore); + }); + + it(`from cpu to webgl.`, async () => { + await tfc.setBackend('cpu'); + + const cpuBefore = tfc.engine().backend.numDataIds(); + + const input = tfc.tensor2d([1, 1, 1, 1], [2, 2], 'float32'); + // input is stored in cpu backend. + + const inputReshaped = tfc.reshape(input, [2, 2]); + + const cpuAfter = tfc.engine().backend.numDataIds(); + + expect(cpuAfter).toEqual(cpuBefore + 1); + + await tfc.setBackend('webgl'); + + const webglBefore = tfc.engine().backend.numDataIds(); + + const inputReshaped2 = tfc.reshape(inputReshaped, [2, 2]); + // input moved to webgl. + + // Because input is moved to webgl, data should be deleted from cpu. + expect(tfc.findBackend('cpu').numDataIds()).toEqual(cpuAfter - 1); + + const webglAfter = tfc.engine().backend.numDataIds(); + + expect(webglAfter).toEqual(webglBefore + 1); + + input.dispose(); + + expect(tfc.engine().backend.numDataIds()).toEqual(webglAfter); + + inputReshaped.dispose(); + + expect(tfc.engine().backend.numDataIds()).toEqual(webglAfter); + + inputReshaped2.dispose(); + + const after = tfc.engine().backend.numDataIds(); + + expect(after).toBe(webglBefore); + }); + }); +}); diff --git a/e2e/integration_tests/memory_leak_test.ts b/e2e/integration_tests/memory_leak_test.ts new file mode 100644 index 00000000000..9ce0a2c96b0 --- /dev/null +++ b/e2e/integration_tests/memory_leak_test.ts @@ -0,0 +1,101 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import '@tensorflow/tfjs-backend-cpu'; + +import * as tfconverter from '@tensorflow/tfjs-converter'; +import * as tfc from '@tensorflow/tfjs-core'; + +import {SMOKE} from './constants'; + +const HOST = 'http://example.org'; +const MODEL_URL = `${HOST}/model.json`; + +const CUSTOM_OP_MODEL = { + node: [ + { + name: 'Input', + op: 'Placeholder', + attr: { + dtype: { + type: 1, // DT_FLOAT + }, + shape: {shape: {dim: [{size: 4}]}} + } + }, + {name: 'CustomOp', op: 'CustomOp', input: ['Input'], attr: {}} + ], + versions: {producer: 1.0, minConsumer: 3} +}; + +const weightsManifest: tfc.io.WeightsManifestEntry[] = + [{'name': 'Const', 'dtype': 'float32', 'shape': [1]}]; + +const bias = tfc.tensor1d([0], 'float32'); + +const CUSTOM_HTTP_MODEL_LOADER = { + load: async () => { + return { + modelTopology: CUSTOM_OP_MODEL, + weightSpecs: weightsManifest, + weightData: bias.dataSync(), + format: 'tfjs-graph-model', + generatedBy: '1.15', + convertedBy: '1.3.1' + }; + } +}; + +describe( + `${SMOKE} A custom op that calls unmodularized kernels and modularized ` + + `kernels`, + () => { + it('should have no memory leak in a model run.', async () => { + const model = new tfconverter.GraphModel(MODEL_URL); + + spyOn(tfc.io, 'getLoadHandlers').and.returnValue([ + CUSTOM_HTTP_MODEL_LOADER + ]); + + // A custom op that calls unmodularized kernels and modularized kernels. + tfconverter.registerOp('CustomOp', (nodeValue) => { + const x = nodeValue.inputs[0]; + const softMax = tfc.softmax(x); + const clone = tfc.clone(softMax); + return [tfc.reshape(clone, [2, 2])]; + }); + + await model.load(); + + const before = tfc.memory().numTensors; + + const input = tfc.tensor1d([1, 2, 3, 4]); + const output = model.predict(input) as tfc.Tensor; + + tfc.test_util.expectArraysClose(await output.data(), [ + 0.032058604061603546, 0.08714432269334793, 0.23688283562660217, + 0.6439142823219299 + ]); + + input.dispose(); + output.dispose(); + + const after = tfc.memory().numTensors; + + expect(after).toEqual(before); + }); + }); diff --git a/tfjs-backend-cpu/src/backend_cpu.ts b/tfjs-backend-cpu/src/backend_cpu.ts index 8bf6be1794a..62341737027 100644 --- a/tfjs-backend-cpu/src/backend_cpu.ts +++ b/tfjs-backend-cpu/src/backend_cpu.ts @@ -16,10 +16,7 @@ */ import * as tf from '@tensorflow/tfjs-core'; -import {engine, env} from '@tensorflow/tfjs-core'; -import {backend_util, buffer, slice_util, util} from '@tensorflow/tfjs-core'; -import {BackendTimingInfo, DataStorage, DataType, DataValues, KernelBackend, max, NumericDataType, Rank, reshape, Scalar, ShapeMap, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, TypedArray, upcastType} from '@tensorflow/tfjs-core'; -import {kernel_impls} from '@tensorflow/tfjs-core'; +import {backend_util, BackendTimingInfo, buffer, DataStorage, DataType, DataValues, engine, env, kernel_impls, KernelBackend, max, NumericDataType, Rank, Scalar, ShapeMap, slice_util, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, TensorInfo, TypedArray, upcastType, util} from '@tensorflow/tfjs-core'; const nonMaxSuppressionV3Impl = kernel_impls.nonMaxSuppressionV3Impl; const split = kernel_impls.split; @@ -59,6 +56,9 @@ export interface TensorData { // TODO(smilkov): Replace Tensor with TensorInfo when you modularize ops // that work with complex tensors. complexTensors?: {real: Tensor, imag: Tensor}; + // refCount keeps track of how many tensors reference it. Used for memory + // management. + refCount: number; } export class MathBackendCPU extends KernelBackend { @@ -91,14 +91,30 @@ export class MathBackendCPU extends KernelBackend { } } const dataId = {}; - this.data.set(dataId, {values, dtype}); + + this.data.set(dataId, {values, dtype, refCount: 1}); + return dataId; } + /** Increase refCount of a `TensorData`. */ + incRef(dataId: DataId): void { + const tensorData = this.data.get(dataId); + tensorData.refCount++; + } + + /** Decrease refCount of a `TensorData`. */ + decRef(dataId: DataId): void { + if (this.data.has(dataId)) { + const tensorData = this.data.get(dataId); + tensorData.refCount--; + } + } + move( dataId: DataId, values: backend_util.BackendValues, shape: number[], dtype: DataType): void { - this.data.set(dataId, {values, dtype}); + this.data.set(dataId, {values, dtype, refCount: 1}); } numDataIds(): number { @@ -151,6 +167,20 @@ export class MathBackendCPU extends KernelBackend { } } + disposeIntermediateTensorInfo(tensorInfo: TensorInfo): void { + const dataId = tensorInfo.dataId; + + if (this.data.has(dataId)) { + const tensorData = this.data.get(dataId); + + tensorData.refCount--; + + if (tensorData.refCount < 1) { + this.disposeData(dataId); + } + } + } + async time(f: () => void): Promise { const start = util.now(); f(); @@ -464,6 +494,7 @@ export class MathBackendCPU extends KernelBackend { mapActivation(this, result, activation, preluActivationWeights) as Tensor3D; } + return result; } @@ -2061,27 +2092,6 @@ export class MathBackendCPU extends KernelBackend { return tile(this.bufferSync(x), reps) as T; } - pad( - x: T, paddings: Array<[number, number]>, constantValue: number): T { - assertNotComplex(x, 'pad'); - - const outShape = paddings.map( - (p, i) => p[0] /* beforePad */ + x.shape[i] + p[1] /* afterPad */); - const start = paddings.map(p => p[0]); - const xBuffer = this.bufferSync(x); - const buffer = tf.buffer(outShape, x.dtype as 'float32'); - if (constantValue !== 0) { - buffer.values.fill(constantValue); - } - - for (let i = 0; i < x.size; i++) { - const coords = xBuffer.indexToLoc(i); - const outCoords = coords.map((c, i) => c + start[i]); - buffer.set(xBuffer.get(...coords), ...outCoords); - } - return buffer.toTensor() as T; - } - gather(x: T, indices: Tensor1D, axis: number): T { assertNotComplex([x, indices], 'gather'); @@ -2124,33 +2134,6 @@ export class MathBackendCPU extends KernelBackend { .slice(sliceBeginCoords, sliceSize) as T; } - spaceToBatchND( - x: T, blockShape: number[], paddings: Array<[number, number]>): T { - assertNotComplex([x], 'spaceToBatchND'); - - const prod = blockShape.reduce((a, b) => a * b); - - const completePaddings: Array<[number, number]> = [[0, 0]]; - completePaddings.push(...paddings); - for (let i = 1 + blockShape.length; i < x.shape.length; ++i) { - completePaddings.push([0, 0]); - } - - const paddedX = x.pad(completePaddings); - - const reshapedPaddedShape = - backend_util.getReshaped(paddedX.shape, blockShape, prod, false); - const permutedReshapedPaddedPermutation = backend_util.getPermuted( - reshapedPaddedShape.length, blockShape.length, false); - const flattenShape = backend_util.getReshapedPermuted( - paddedX.shape, blockShape, prod, false); - - const paddedXT = tf.transpose( - paddedX.reshape(reshapedPaddedShape), - permutedReshapedPaddedPermutation); - return reshape(paddedXT, flattenShape) as T; - } - maxPool(x: Tensor4D, convInfo: backend_util.Conv2DInfo): Tensor4D { assertNotComplex(x, 'maxPool'); const xValues = this.readSync(x.dataId) as TypedArray; @@ -2631,10 +2614,6 @@ export class MathBackendCPU extends KernelBackend { return backend_util.castTensor(x, dtype, this); } - reshape(x: Tensor, shape: ShapeMap[R]): Tensor { - return backend_util.reshapeTensor(x, shape); - } - avgPool(x: Tensor4D, convInfo: backend_util.Conv2DInfo): Tensor4D { assertNotComplex(x, 'avgPool'); assertNotComplex(x, 'maxPool'); diff --git a/tfjs-backend-cpu/src/backend_cpu_test.ts b/tfjs-backend-cpu/src/backend_cpu_test.ts index 768ffe578db..b4aa455dfdc 100644 --- a/tfjs-backend-cpu/src/backend_cpu_test.ts +++ b/tfjs-backend-cpu/src/backend_cpu_test.ts @@ -17,7 +17,8 @@ import * as tf from '@tensorflow/tfjs-core'; import {engine, test_util, util} from '@tensorflow/tfjs-core'; -const {expectArraysClose, expectArraysEqual} = test_util; + +const {expectArraysEqual} = test_util; // tslint:disable-next-line: no-imports-from-dist import {describeWithFlags, ALL_ENVS} from '@tensorflow/tfjs-core/dist/jasmine_util'; @@ -132,18 +133,3 @@ describeWithFlags('memory cpu', ALL_ENVS, () => { expect(mem.reasons.indexOf(expectedReasonString) >= 0).toBe(true); }); }); - -describeWithFlags('CPU backend has sync init', ALL_ENVS, () => { - it('can do matmul without waiting for ready', async () => { - tf.registerBackend('my-cpu', () => { - return new MathBackendCPU(); - }); - tf.setBackend('my-cpu'); - const a = tf.tensor1d([5]); - const b = tf.tensor1d([3]); - const res = tf.dot(a, b); - expectArraysClose(await res.data(), 15); - tf.dispose([a, b, res]); - tf.removeBackend('my-cpu'); - }); -}); diff --git a/tfjs-backend-cpu/src/kernels/PadV2.ts b/tfjs-backend-cpu/src/kernels/PadV2.ts new file mode 100644 index 00000000000..d82462a8c86 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/PadV2.ts @@ -0,0 +1,69 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, KernelFunc, NumericDataType, PadV2, PadV2Attrs, PadV2Inputs, TensorInfo, TypedArray, util} from '@tensorflow/tfjs-core'; + +import {MathBackendCPU} from '../backend_cpu'; +import {assertNotComplex} from '../cpu_util'; + +export function padV2( + args: {inputs: PadV2Inputs, backend: MathBackendCPU, attrs: PadV2Attrs}): + TensorInfo { + const {inputs, backend, attrs} = args; + const {x} = inputs; + const {paddings, constantValue} = attrs; + + assertNotComplex(x, 'pad'); + + const outShape = paddings.map( + (p, i) => p[0] /* beforePad */ + x.shape[i] + p[1] /* afterPad */); + + const start = paddings.map(p => p[0]); + + const xVals = backend.data.get(x.dataId).values as TypedArray; + const xSize = util.sizeFromShape(x.shape); + const xRank = x.shape.length; + const xStrides = util.computeStrides(x.shape); + + const resultSize = util.sizeFromShape(outShape); + const resultRank = outShape.length; + const resultStrides = util.computeStrides(outShape); + const resVals = + util.getTypedArrayFromDType(x.dtype as NumericDataType, resultSize); + + if (constantValue !== 0) { + resVals.fill(constantValue); + } + + for (let i = 0; i < xSize; i++) { + const coords = util.indexToLoc(i, xRank, xStrides); + const outCoords = coords.map((c, i) => c + start[i]); + const outIndex = util.locToIndex(outCoords, resultRank, resultStrides); + + resVals[outIndex] = xVals[i]; + } + + const outId = backend.write(resVals, outShape, x.dtype); + + return {dataId: outId, shape: outShape, dtype: x.dtype}; +} + +export const padV2Config: KernelConfig = { + kernelName: PadV2, + backendName: 'cpu', + kernelFunc: padV2 as {} as KernelFunc +}; diff --git a/tfjs-backend-cpu/src/kernels/Reshape.ts b/tfjs-backend-cpu/src/kernels/Reshape.ts new file mode 100644 index 00000000000..e7d5e359266 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Reshape.ts @@ -0,0 +1,39 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, KernelFunc, Reshape, ReshapeAttrs, ReshapeInputs, TensorInfo} from '@tensorflow/tfjs-core'; + +import {MathBackendCPU} from '../backend_cpu'; + +export function reshape( + args: + {inputs: ReshapeInputs, backend: MathBackendCPU, attrs: ReshapeAttrs}): + TensorInfo { + const {inputs, backend, attrs} = args; + const {x} = inputs; + const {shape} = attrs; + + backend.incRef(x.dataId); + + return {dataId: x.dataId, shape, dtype: x.dtype}; +} + +export const reshapeConfig: KernelConfig = { + kernelName: Reshape, + backendName: 'cpu', + kernelFunc: reshape as {} as KernelFunc +}; diff --git a/tfjs-backend-cpu/src/kernels/Reshape_test.ts b/tfjs-backend-cpu/src/kernels/Reshape_test.ts new file mode 100644 index 00000000000..e92fb728819 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Reshape_test.ts @@ -0,0 +1,78 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '@tensorflow/tfjs-core'; +import {Tensor, test_util} from '@tensorflow/tfjs-core'; + +const {expectArraysClose, expectArraysEqual} = test_util; +// tslint:disable-next-line: no-imports-from-dist +import {describeWithFlags, ALL_ENVS} from '@tensorflow/tfjs-core/dist/jasmine_util'; + +describeWithFlags('Reshape.', ALL_ENVS, () => { + it('does not have memory leak.', async () => { + const beforeDataIds = tf.engine().backend.numDataIds(); + + const x = tf.tensor1d([1, 1, 1, 1]); + const res = + tf.engine().runKernel('Reshape', {x}, {shape: [2, 2]}) as Tensor; + + expectArraysClose(await res.data(), [1, 1, 1, 1]); + expectArraysEqual(res.shape, [2, 2]); + + const afterResDataIds = tf.engine().backend.numDataIds(); + expect(afterResDataIds).toEqual(beforeDataIds + 1); + + x.dispose(); + res.dispose(); + + const afterDisposeDataIds = tf.engine().backend.numDataIds(); + expect(afterDisposeDataIds).toEqual(beforeDataIds); + }); + + it('does not have memory leak calling reshape twice.', async () => { + const beforeDataIds = tf.engine().backend.numDataIds(); + + // Adding 1 new dataId. + const x = tf.tensor1d([1, 1, 1, 1]); + + // Does not add new dataId; + const res = + tf.engine().runKernel('Reshape', {x}, {shape: [2, 2]}) as Tensor; + + expectArraysEqual(res.shape, [2, 2]); + + // Does not add new dataId. + const res2 = + tf.engine().runKernel('Reshape', {x: res}, {shape: [1, 4]}) as Tensor; + expectArraysEqual(res2.shape, [1, 4]); + + const afterRes2DataIds = tf.engine().backend.numDataIds(); + expect(afterRes2DataIds).toEqual(beforeDataIds + 1); + + res.dispose(); + + const afterResDataIds = tf.engine().backend.numDataIds(); + expect(afterResDataIds).toEqual(beforeDataIds + 1); + + x.dispose(); + res2.dispose(); + + const afterDisposeDataIds = tf.engine().backend.numDataIds(); + // Should be able to dispose the dataId. + expect(afterDisposeDataIds).toEqual(beforeDataIds); + }); +}); diff --git a/tfjs-backend-cpu/src/kernels/SpaceToBatchND.ts b/tfjs-backend-cpu/src/kernels/SpaceToBatchND.ts new file mode 100644 index 00000000000..1e2796d4e44 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/SpaceToBatchND.ts @@ -0,0 +1,89 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util, KernelConfig, KernelFunc, ReshapeAttrs, ReshapeInputs, SpaceToBatchND, SpaceToBatchNDAttrs, SpaceToBatchNDInputs, TensorInfo, TransposeAttrs, TransposeInputs, util} from '@tensorflow/tfjs-core'; + +import {MathBackendCPU} from '../backend_cpu'; +import {assertNotComplex} from '../cpu_util'; + +import {padV2Config} from './PadV2'; +import {reshape} from './Reshape'; +import {transpose} from './Transpose'; + +export function spaceToBatchND(args: { + inputs: SpaceToBatchNDInputs, + backend: MathBackendCPU, + attrs: SpaceToBatchNDAttrs +}): TensorInfo { + const {inputs, backend, attrs} = args; + const {x} = inputs; + const {blockShape, paddings} = attrs; + + assertNotComplex([x], 'spaceToBatchND'); + + const prod = util.sizeFromShape(blockShape); + + const completePaddings: Array<[number, number]> = [[0, 0]]; + completePaddings.push(...(paddings as Array<[number, number]>)); + + for (let i = 1 + blockShape.length; i < x.shape.length; ++i) { + completePaddings.push([0, 0]); + } + + const paddedX = padV2Config.kernelFunc({ + inputs: {x}, + backend, + attrs: {paddings: completePaddings, constantValue: 0} + }) as TensorInfo; + + const reshapedPaddedShape = + backend_util.getReshaped(paddedX.shape, blockShape, prod, false); + + const permutedReshapedPaddedPermutation = backend_util.getPermuted( + reshapedPaddedShape.length, blockShape.length, false); + + const flattenShape = + backend_util.getReshapedPermuted(paddedX.shape, blockShape, prod, false); + + const reshapeInputs: ReshapeInputs = {x: paddedX}; + const reshapeAttrs: ReshapeAttrs = {shape: reshapedPaddedShape}; + const paddedXReshaped = + reshape({inputs: reshapeInputs, backend, attrs: reshapeAttrs}); + + const transposeInputs: TransposeInputs = {x: paddedXReshaped}; + const transposeAttrs: + TransposeAttrs = {perm: permutedReshapedPaddedPermutation}; + const paddedXT = + transpose({inputs: transposeInputs, backend, attrs: transposeAttrs}); + + const resultReshapeInputs: ReshapeInputs = {x: paddedXT}; + const resultReshapeAttrs: ReshapeAttrs = {shape: flattenShape}; + const result = reshape( + {inputs: resultReshapeInputs, backend, attrs: resultReshapeAttrs}); + + backend.disposeIntermediateTensorInfo(paddedX); + backend.disposeIntermediateTensorInfo(paddedXReshaped); + backend.disposeIntermediateTensorInfo(paddedXT); + + return result; +} + +export const spaceToBatchNDConfig: KernelConfig = { + kernelName: SpaceToBatchND, + backendName: 'cpu', + kernelFunc: spaceToBatchND as {} as KernelFunc +}; diff --git a/tfjs-backend-cpu/src/kernels/Transpose.ts b/tfjs-backend-cpu/src/kernels/Transpose.ts index 05ae13f03cb..6935c118550 100644 --- a/tfjs-backend-cpu/src/kernels/Transpose.ts +++ b/tfjs-backend-cpu/src/kernels/Transpose.ts @@ -15,35 +15,40 @@ * ============================================================================= */ -import {KernelConfig, TypedArray} from '@tensorflow/tfjs-core'; +import {KernelConfig, KernelFunc, TensorInfo, Transpose, TransposeAttrs, TransposeInputs, TypedArray} from '@tensorflow/tfjs-core'; -import {Transpose, TransposeAttrs, TransposeInputs} from '@tensorflow/tfjs-core'; import {MathBackendCPU} from '../backend_cpu'; import {assertNotComplex} from '../cpu_util'; import {transposeImpl} from './Transpose_impl'; -export const transposeConfig: KernelConfig = { - kernelName: Transpose, - backendName: 'cpu', - kernelFunc: ({inputs, attrs, backend}) => { - const {x} = inputs as TransposeInputs; - const {perm} = attrs as {} as TransposeAttrs; - const cpuBackend = backend as MathBackendCPU; +export function transpose(args: { + inputs: TransposeInputs, + attrs: TransposeAttrs, + backend: MathBackendCPU +}): TensorInfo { + const {inputs, attrs, backend} = args; + const {x} = inputs; + const {perm} = attrs; - assertNotComplex(x, 'transpose'); + assertNotComplex(x, 'transpose'); - const xRank = x.shape.length; + const xRank = x.shape.length; - const newShape: number[] = new Array(xRank); - for (let i = 0; i < newShape.length; i++) { - newShape[i] = x.shape[perm[i]]; - } + const newShape: number[] = new Array(xRank); + for (let i = 0; i < newShape.length; i++) { + newShape[i] = x.shape[perm[i]]; + } - const values = cpuBackend.data.get(x.dataId).values as TypedArray; - const result = transposeImpl(values, x.shape, x.dtype, perm, newShape); + const values = backend.data.get(x.dataId).values as TypedArray; + const result = transposeImpl(values, x.shape, x.dtype, perm, newShape); - const dataId = cpuBackend.write(result, newShape, x.dtype); - return {dataId, shape: newShape, dtype: x.dtype}; - } + const dataId = backend.write(result, newShape, x.dtype); + return {dataId, shape: newShape, dtype: x.dtype}; +} + +export const transposeConfig: KernelConfig = { + kernelName: Transpose, + backendName: 'cpu', + kernelFunc: transpose as {} as KernelFunc }; diff --git a/tfjs-backend-cpu/src/register_all_kernels.ts b/tfjs-backend-cpu/src/register_all_kernels.ts index a03d8fcca72..70ec59226df 100644 --- a/tfjs-backend-cpu/src/register_all_kernels.ts +++ b/tfjs-backend-cpu/src/register_all_kernels.ts @@ -28,7 +28,10 @@ import {maxConfig} from './kernels/Max'; import {maxPoolWithArgmaxConfig} from './kernels/MaxPoolWithArgmax'; import {nonMaxSuppressionV4Config} from './kernels/NonMaxSuppressionV4'; import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5'; +import {padV2Config} from './kernels/PadV2'; +import {reshapeConfig} from './kernels/Reshape'; import {rotateWithOffsetConfig} from './kernels/RotateWithOffset'; +import {spaceToBatchNDConfig} from './kernels/SpaceToBatchND'; import {squareConfig} from './kernels/Square'; import {squaredDifferenceConfig} from './kernels/SquaredDifference'; import {transposeConfig} from './kernels/Transpose'; @@ -38,8 +41,8 @@ const kernelConfigs: KernelConfig[] = [ dilation2dConfig, dilation2dBackpropInputConfig, dilation2dBackpropFilterConfig, divConfig, flipLeftRightConfig, maxPoolWithArgmaxConfig, maxConfig, nonMaxSuppressionV4Config, - nonMaxSuppressionV5Config, rotateWithOffsetConfig, squareConfig, - squaredDifferenceConfig, transposeConfig + nonMaxSuppressionV5Config, padV2Config, reshapeConfig, rotateWithOffsetConfig, + spaceToBatchNDConfig, squareConfig, squaredDifferenceConfig, transposeConfig ]; for (const kernelConfig of kernelConfigs) { diff --git a/tfjs-backend-webgl/src/backend_webgl.ts b/tfjs-backend-webgl/src/backend_webgl.ts index cf3c8d64258..a939a659304 100644 --- a/tfjs-backend-webgl/src/backend_webgl.ts +++ b/tfjs-backend-webgl/src/backend_webgl.ts @@ -2200,6 +2200,7 @@ export class MathBackendWebGL extends KernelBackend { reshape(x: Tensor, shape: ShapeMap[R]): Tensor { const texData = this.texData.get(x.dataId); + if (texData.isPacked && !webgl_util.isReshapeFree(x.shape, shape) && !(texData.texture !== null && webgl_util.isReshapeFree(texData.shape, shape))) { diff --git a/tfjs-core/src/engine.ts b/tfjs-core/src/engine.ts index 6d6c309df1a..3412fa01b78 100644 --- a/tfjs-core/src/engine.ts +++ b/tfjs-core/src/engine.ts @@ -802,7 +802,9 @@ export class Engine implements TensorTracker, DataMover { }); this.state.numBytes += bytes; } + this.state.tensorInfo.get(a.dataId).refCount++; + if (!(a instanceof Variable)) { this.track(a); } @@ -819,6 +821,7 @@ export class Engine implements TensorTracker, DataMover { } const info = this.state.tensorInfo.get(a.dataId); const refCount = info.refCount; + if (refCount <= 1) { // Don't count bytes for complex numbers as they are counted by their // components. @@ -826,6 +829,7 @@ export class Engine implements TensorTracker, DataMover { this.state.numBytes -= info.bytes; } this.state.numDataBuffers--; + info.backend.disposeData(a.dataId); this.state.tensorInfo.delete(a.dataId); } else { diff --git a/tfjs-core/src/engine_test.ts b/tfjs-core/src/engine_test.ts index f68c0ae42df..f6178b2103f 100644 --- a/tfjs-core/src/engine_test.ts +++ b/tfjs-core/src/engine_test.ts @@ -766,7 +766,7 @@ describe('Memory allocation outside a test scope', () => { }, read: async (dataId: object) => storedValues, dispose: () => null, - disposeData: (dataId: {}) => null, + disposeData: (dataId: {}) => null } as TestStorage; }); tf.setBackend(backendName);