Skip to content

Commit 172c577

Browse files
authored
[webgl] Modularize reshape, add refCounter. (#3910)
FEATURE
1 parent 2ef1c1c commit 172c577

File tree

12 files changed

+278
-66
lines changed

12 files changed

+278
-66
lines changed

tfjs-backend-webgl/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
"publish-npm": "npm publish",
6464
"lint": "tslint -p . -t verbose",
6565
"test": "yarn && yarn build-deps && karma start",
66+
"test-dev": "karma start --testEnv webgl2",
6667
"run-browserstack": "karma start --browserstack",
6768
"test-ci": "./scripts/test-ci.sh"
6869
},

tfjs-backend-webgl/src/backend_webgl.ts

Lines changed: 61 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -279,10 +279,25 @@ export class MathBackendWebGL extends KernelBackend {
279279
}
280280
const dataId = {};
281281
this.texData.set(
282-
dataId, {shape, dtype, values, usage: TextureUsage.UPLOAD});
282+
dataId,
283+
{shape, dtype, values, usage: TextureUsage.UPLOAD, refCount: 1});
283284
return dataId;
284285
}
285286

287+
/** Increase refCount of a `TextureData`. */
288+
incRef(dataId: DataId): void {
289+
const texData = this.texData.get(dataId);
290+
texData.refCount++;
291+
}
292+
293+
/** Decrease refCount of a `TextureData`. */
294+
decRef(dataId: DataId): void {
295+
if (this.texData.has(dataId)) {
296+
const texData = this.texData.get(dataId);
297+
texData.refCount--;
298+
}
299+
}
300+
286301
move(dataId: DataId, values: BackendValues, shape: number[], dtype: DataType):
287302
void {
288303
if (env().getBool('DEBUG')) {
@@ -294,12 +309,31 @@ export class MathBackendWebGL extends KernelBackend {
294309
`Please use tf.complex(real, imag).`);
295310
}
296311
this.texData.set(
297-
dataId, {shape, dtype, values, usage: TextureUsage.UPLOAD});
312+
dataId,
313+
{shape, dtype, values, usage: TextureUsage.UPLOAD, refCount: 1});
314+
}
315+
316+
disposeIntermediateTensorInfo(tensorInfo: TensorInfo): void {
317+
const dataId = tensorInfo.dataId;
318+
319+
if (this.texData.has(dataId)) {
320+
const textureData = this.texData.get(dataId);
321+
322+
textureData.refCount--;
323+
324+
if (textureData.refCount < 1) {
325+
this.disposeData(dataId);
326+
}
327+
}
298328
}
299329

300330
readSync(dataId: DataId): BackendValues {
301331
const texData = this.texData.get(dataId);
302332
const {values, dtype, complexTensors, slice, shape, isPacked} = texData;
333+
334+
// The presence of `slice` indicates this tensor is a shallow slice of a
335+
// different tensor, and is using that original tensor's texture. Run
336+
// `clone` in order to copy that texture and read from it.
303337
if (slice != null) {
304338
let program;
305339
if (isPacked) {
@@ -310,7 +344,7 @@ export class MathBackendWebGL extends KernelBackend {
310344
const res =
311345
this.runWebGLProgram(program, [{dataId, shape, dtype}], dtype);
312346
const data = this.readSync(res.dataId);
313-
this.disposeData(res.dataId);
347+
this.disposeIntermediateTensorInfo(res);
314348
return data;
315349
}
316350
if (values != null) {
@@ -348,6 +382,9 @@ export class MathBackendWebGL extends KernelBackend {
348382
const texData = this.texData.get(dataId);
349383
const {values, shape, slice, dtype, complexTensors, isPacked} = texData;
350384

385+
// The presence of `slice` indicates this tensor is a shallow slice of a
386+
// different tensor, and is using that original tensor's texture. Run
387+
// `clone` in order to copy that texture and read from it.
351388
if (slice != null) {
352389
let program;
353390
if (isPacked) {
@@ -358,7 +395,7 @@ export class MathBackendWebGL extends KernelBackend {
358395
const res =
359396
this.runWebGLProgram(program, [{dataId, shape, dtype}], dtype);
360397
const data = this.read(res.dataId);
361-
this.disposeData(res.dataId);
398+
this.disposeIntermediateTensorInfo(res);
362399
return data;
363400
}
364401

@@ -408,7 +445,7 @@ export class MathBackendWebGL extends KernelBackend {
408445
vals = this.gpgpu.downloadFloat32MatrixFromBuffer(buffer, size);
409446
}
410447
if (tmpDownloadTarget != null) {
411-
this.disposeData(tmpDownloadTarget.dataId);
448+
this.disposeIntermediateTensorInfo(tmpDownloadTarget);
412449
}
413450
const dTypeVals = this.convertAndCacheOnCPU(dataId, vals);
414451

@@ -454,7 +491,7 @@ export class MathBackendWebGL extends KernelBackend {
454491
tmpData.texture, ...tex_util.getDenseTexShape(shape))
455492
.subarray(0, size);
456493

457-
this.disposeData(tmpTarget.dataId);
494+
this.disposeIntermediateTensorInfo(tmpTarget);
458495

459496
return vals;
460497
}
@@ -474,7 +511,7 @@ export class MathBackendWebGL extends KernelBackend {
474511
.downloadByteEncodedFloatMatrixFromOutputTexture(
475512
tmpData.texture, tmpData.texShape[0], tmpData.texShape[1])
476513
.subarray(0, size);
477-
this.disposeData(output.dataId);
514+
this.disposeIntermediateTensorInfo(output);
478515

479516
return vals;
480517
}
@@ -1820,21 +1857,20 @@ export class MathBackendWebGL extends KernelBackend {
18201857
!reshapeWillBeExpensive) {
18211858
const targetShape = isChannelsLast ? xShape[0] * xShape[1] * xShape[2] :
18221859
xShape[0] * xShape[2] * xShape[3];
1823-
const xReshaped = this.reshape(x, [1, targetShape, convInfo.inChannels]);
1860+
const xReshaped = reshape(x, [1, targetShape, convInfo.inChannels]);
18241861
const filterReshaped =
1825-
this.reshape(filter, [1, convInfo.inChannels, convInfo.outChannels]);
1826-
1827-
return this.reshape<Rank.R4>(
1828-
this.fusedBatchMatMul({
1829-
a: xReshaped as Tensor3D,
1830-
b: filterReshaped as Tensor3D,
1831-
transposeA,
1832-
transposeB,
1833-
bias,
1834-
activation,
1835-
preluActivationWeights
1836-
}),
1837-
convInfo.outShape);
1862+
reshape(filter, [1, convInfo.inChannels, convInfo.outChannels]);
1863+
1864+
const result = this.fusedBatchMatMul({
1865+
a: xReshaped as Tensor3D,
1866+
b: filterReshaped as Tensor3D,
1867+
transposeA,
1868+
transposeB,
1869+
bias,
1870+
activation,
1871+
preluActivationWeights
1872+
});
1873+
return reshape(result, convInfo.outShape);
18381874
}
18391875

18401876
// Following optimization is specific to packed |x| with odd row count
@@ -1869,7 +1905,7 @@ export class MathBackendWebGL extends KernelBackend {
18691905
() => `packed reshape ${xTexData.shape} to ${
18701906
xReshaped.shape} isn't free`);
18711907
const filterReshaped =
1872-
this.reshape(filter, [1, convInfo.inChannels, convInfo.outChannels]);
1908+
reshape(filter, [1, convInfo.inChannels, convInfo.outChannels]);
18731909

18741910
const pointwiseConv = this.fusedBatchMatMul({
18751911
a: xReshaped as Tensor3D,
@@ -2182,19 +2218,6 @@ export class MathBackendWebGL extends KernelBackend {
21822218
return result as Tensor5D;
21832219
}
21842220

2185-
reshape<R extends Rank>(x: Tensor, shape: ShapeMap[R]): Tensor<R> {
2186-
const texData = this.texData.get(x.dataId);
2187-
2188-
if (texData.isPacked && !webgl_util.isReshapeFree(x.shape, shape) &&
2189-
!(texData.texture !== null &&
2190-
webgl_util.isReshapeFree(texData.shape, shape))) {
2191-
const info = this.packedReshape(x, shape);
2192-
return engine().makeTensorFromDataId(
2193-
info.dataId, info.shape, info.dtype) as Tensor<R>;
2194-
}
2195-
return backend_util.reshapeTensor(x, shape);
2196-
}
2197-
21982221
resizeBilinear(
21992222
x: Tensor4D, newHeight: number, newWidth: number,
22002223
alignCorners: boolean): Tensor4D {
@@ -2575,7 +2598,7 @@ export class MathBackendWebGL extends KernelBackend {
25752598
gpgpu_math.runProgram(
25762599
this.gpgpu, binary, inputsData, outputData, customSetup);
25772600

2578-
dataToDispose.forEach(info => this.disposeData(info.dataId));
2601+
dataToDispose.forEach(info => this.disposeIntermediateTensorInfo(info));
25792602

25802603
if (shouldTimeProgram) {
25812604
query = this.endTimer(query);
@@ -2586,7 +2609,7 @@ export class MathBackendWebGL extends KernelBackend {
25862609
if (!env().getBool('WEBGL_LAZILY_UNPACK') && outData.isPacked &&
25872610
preventEagerUnpackingOfOutput === false) {
25882611
const unpacked = this.unpackTensor(output);
2589-
this.disposeData(output.dataId);
2612+
this.disposeIntermediateTensorInfo(output);
25902613
return unpacked;
25912614
}
25922615
return output;
@@ -2733,7 +2756,7 @@ export class MathBackendWebGL extends KernelBackend {
27332756
texData.isPacked = outputTexData.isPacked;
27342757
texData.usage = outputTexData.usage;
27352758

2736-
this.disposeData(tempDenseInputHandle.dataId);
2759+
this.disposeIntermediateTensorInfo(tempDenseInputHandle);
27372760
this.texData.delete(encodedOutputTarget.dataId);
27382761

27392762
// Once uploaded, don't store the values on cpu.

tfjs-backend-webgl/src/backend_webgl_test.ts

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ describeWithFlags('backendWebGL', WEBGL_ENVS, () => {
182182
it('read packed and then use by an unpacked op', async () => {
183183
const backend = new MathBackendWebGL(null);
184184
tf.registerBackend('test-storage', () => backend);
185+
tf.copyRegisteredKernels('webgl', 'test-storage');
185186
tf.setBackend('test-storage');
186187

187188
const webglPackFlagSaved = tf.env().getBool('WEBGL_PACK');
@@ -711,15 +712,19 @@ describeWithFlags('caching on cpu', WEBGL_ENVS, () => {
711712

712713
describeWithFlags('WebGL backend has sync init', WEBGL_ENVS, () => {
713714
it('can do matmul without waiting for ready', async () => {
714-
tf.registerBackend('my-webgl', () => {
715+
const customWebGLBackendName = 'my-webgl';
716+
717+
tf.copyRegisteredKernels('webgl', customWebGLBackendName);
718+
719+
tf.registerBackend(customWebGLBackendName, () => {
715720
return new MathBackendWebGL();
716721
});
717-
tf.setBackend('my-webgl');
722+
tf.setBackend(customWebGLBackendName);
718723
const a = tf.tensor1d([5]);
719724
const b = tf.tensor1d([3]);
720725
const res = tf.dot(a, b);
721726
expectArraysClose(await res.data(), 15);
722727
tf.dispose([a, b, res]);
723-
tf.removeBackend('my-webgl');
728+
tf.removeBackend(customWebGLBackendName);
724729
});
725730
});

tfjs-backend-webgl/src/kernel_utils/reshape.ts

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ import {TensorInfo} from '@tensorflow/tfjs-core';
1919

2020
import {MathBackendWebGL} from '../backend_webgl';
2121
import {ReshapePackedProgram} from '../reshape_packed_gpu';
22-
import {getBatchDim, getRowsCols, isReshapeFree} from '../webgl_util';
22+
import {getBatchDim, getRowsCols} from '../webgl_util';
2323

24-
function packedReshape(
24+
export function packedReshape(
2525
input: TensorInfo, afterShape: number[],
2626
backend: MathBackendWebGL): TensorInfo {
2727
const input3DShape =
@@ -43,16 +43,3 @@ function packedReshape(
4343
preventEagerUnpackingOfOutput);
4444
return {dataId: output.dataId, shape: afterShape, dtype: output.dtype};
4545
}
46-
47-
export function reshape(
48-
x: TensorInfo, afterShape: number[],
49-
backend: MathBackendWebGL): TensorInfo {
50-
const xTexData = backend.texData.get(x.dataId);
51-
if (xTexData.isPacked && !isReshapeFree(x.shape, afterShape) &&
52-
!(xTexData.texture !== null &&
53-
isReshapeFree(xTexData.shape, afterShape))) {
54-
return packedReshape(x, afterShape, backend);
55-
}
56-
57-
return {dataId: x.dataId, shape: afterShape, dtype: x.dtype};
58-
}

tfjs-backend-webgl/src/kernels/Max.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ export const maxConfig: KernelConfig = {
8989
}
9090

9191
if (maxInputIsTransposed) {
92-
webglBackend.disposeData(maxInput.dataId);
92+
webglBackend.disposeIntermediateTensorInfo(maxInput);
9393
}
9494

9595
return out;

tfjs-backend-webgl/src/kernels/Max_impl.ts

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,23 @@ import {TensorInfo, util} from '@tensorflow/tfjs-core';
1919

2020
import {MathBackendWebGL} from '../backend_webgl';
2121
import {reduce} from '../kernel_utils/reduce';
22-
import {reshape} from '../kernel_utils/reshape';
22+
import {reshape} from '../kernels/Reshape';
2323

2424
export function maxImpl(
2525
x: TensorInfo, reduceShape: number[], outShape: number[],
2626
backend: MathBackendWebGL): TensorInfo {
2727
const inSize = util.sizeFromShape(reduceShape);
2828
const xSize = util.sizeFromShape(x.shape);
2929
const batchSize = xSize / inSize;
30-
const reshapedInput = reshape(x, [batchSize, inSize], backend);
30+
const reshapedInput =
31+
reshape({inputs: {x}, attrs: {shape: [batchSize, inSize]}, backend});
32+
3133
const reduced = reduce(reshapedInput, x.dtype, 'max', backend);
34+
const reshapedOutput =
35+
reshape({inputs: {x: reduced}, attrs: {shape: outShape}, backend});
3236

33-
if (reshapedInput.dataId !== x.dataId) {
34-
// dispose the output of the packed reshape.
35-
backend.disposeData(reshapedInput.dataId);
36-
}
37+
backend.disposeIntermediateTensorInfo(reshapedInput);
38+
backend.disposeIntermediateTensorInfo(reduced);
3739

38-
return reshape(reduced, outShape, backend);
40+
return reshapedOutput;
3941
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/**
2+
* @license
3+
* Copyright 2020 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {KernelConfig, KernelFunc, Reshape, ReshapeAttrs, ReshapeInputs, TensorInfo, util} from '@tensorflow/tfjs-core';
19+
20+
import {MathBackendWebGL} from '../backend_webgl';
21+
import {packedReshape} from '../kernel_utils/reshape';
22+
import {isReshapeFree} from '../webgl_util';
23+
24+
export function reshape(args: {
25+
inputs: ReshapeInputs,
26+
backend: MathBackendWebGL,
27+
attrs: ReshapeAttrs
28+
}): TensorInfo {
29+
const {inputs, backend, attrs} = args;
30+
const {x} = inputs;
31+
const {shape} = attrs;
32+
const webglBackend = backend;
33+
34+
const xSize = util.sizeFromShape(x.shape);
35+
const $shape = util.inferFromImplicitShape(shape, xSize);
36+
const $xSize = util.sizeFromShape($shape);
37+
38+
util.assert(
39+
xSize === $xSize,
40+
() => `The new shape (${$shape}) has ${$xSize} elements and the old ` +
41+
`shape (${x.shape}) has ${xSize} elements. The new shape and old ` +
42+
`shape must have the same number of elements.`);
43+
44+
const xTexData = webglBackend.texData.get(x.dataId);
45+
if (xTexData.isPacked && !isReshapeFree(x.shape, $shape) &&
46+
!(xTexData.texture !== null && isReshapeFree(xTexData.shape, $shape))) {
47+
return packedReshape(x, $shape, webglBackend);
48+
}
49+
50+
webglBackend.incRef(x.dataId);
51+
52+
return {dataId: x.dataId, shape: $shape, dtype: x.dtype};
53+
}
54+
55+
export const reshapeConfig: KernelConfig = {
56+
kernelName: Reshape,
57+
backendName: 'webgl',
58+
kernelFunc: reshape as {} as KernelFunc
59+
};

0 commit comments

Comments
 (0)