From 696e4592a75c1d2df820b7d955284df1eae148e4 Mon Sep 17 00:00:00 2001 From: Jing Jin <8752427+jinjingforever@users.noreply.github.com> Date: Thu, 24 Sep 2020 14:54:06 -0700 Subject: [PATCH 1/3] save --- tfjs-backend-webgl/src/backend_webgl.ts | 38 +-------- tfjs-backend-webgl/src/kernels/AvgPool.ts | 62 +++++++++++++++ .../src/kernels/AvgPoolBackprop.ts | 56 ++++++++++++++ tfjs-backend-webgl/src/kernels/MaxPool.ts | 61 +++++++++++++++ .../src/kernels/MaxPoolBackprop.ts | 77 +++++++++++++++++++ .../src/register_all_kernels.ts | 31 ++++++-- tfjs-backend-webgl/src/webgl_util.ts | 17 +++- 7 files changed, 301 insertions(+), 41 deletions(-) create mode 100644 tfjs-backend-webgl/src/kernels/AvgPool.ts create mode 100644 tfjs-backend-webgl/src/kernels/AvgPoolBackprop.ts create mode 100644 tfjs-backend-webgl/src/kernels/MaxPool.ts create mode 100644 tfjs-backend-webgl/src/kernels/MaxPoolBackprop.ts diff --git a/tfjs-backend-webgl/src/backend_webgl.ts b/tfjs-backend-webgl/src/backend_webgl.ts index 18c238c6c1b..d6a41def911 100644 --- a/tfjs-backend-webgl/src/backend_webgl.ts +++ b/tfjs-backend-webgl/src/backend_webgl.ts @@ -33,7 +33,7 @@ import {AddNProgram} from './addn_gpu'; import {AddNPackedProgram} from './addn_packed_gpu'; import {ArgMinMaxProgram} from './argminmax_gpu'; import {ArgMinMaxPackedProgram} from './argminmax_packed_gpu'; -import {AvgPool2DBackpropProgram, AvgPool3DBackpropProgram} from './avg_pool_backprop_gpu'; +import {AvgPool3DBackpropProgram} from './avg_pool_backprop_gpu'; import * as binaryop_complex_gpu from './binaryop_complex_gpu'; import {BinaryOpComplexProgram} from './binaryop_complex_gpu'; import * as binaryop_gpu from './binaryop_gpu'; @@ -73,14 +73,14 @@ import {Im2ColPackedProgram} from './im2col_packed_gpu'; import {LRNProgram} from './lrn_gpu'; import {LRNGradProgram} from './lrn_grad_gpu'; import {LRNPackedProgram} from './lrn_packed_gpu'; -import {MaxPool2DBackpropProgram, MaxPool3DBackpropProgram} from './max_pool_backprop_gpu'; +import {MaxPool3DBackpropProgram} from './max_pool_backprop_gpu'; import {MatMulPackedProgram} from './mulmat_packed_gpu'; import {MultinomialProgram} from './multinomial_gpu'; import {OneHotProgram} from './onehot_gpu'; import {PackProgram} from './pack_gpu'; import {PadProgram} from './pad_gpu'; import {PadPackedProgram} from './pad_packed_gpu'; -import {Pool2DProgram, Pool3DProgram} from './pool_gpu'; +import {Pool3DProgram} from './pool_gpu'; import {ReduceProgram} from './reduce_gpu'; import {ReshapePackedProgram} from './reshape_packed_gpu'; import {ResizeBilinearBackpropProgram} from './resize_bilinear_backprop_gpu'; @@ -2097,38 +2097,6 @@ export class MathBackendWebGL extends KernelBackend { return this.compileAndRun(program, [x, dy]); } - maxPool(x: Tensor4D, convInfo: backend_util.Conv2DInfo): Tensor4D { - const program = new Pool2DProgram(convInfo, 'max', false); - return this.compileAndRun(program, [x]); - } - - avgPool(x: Tensor4D, convInfo: backend_util.Conv2DInfo): Tensor4D { - const program = new Pool2DProgram(convInfo, 'avg', false); - return this.compileAndRun(program, [x], 'float32'); - } - - maxPoolBackprop( - dy: Tensor4D, x: Tensor4D, y: Tensor4D, - convInfo: backend_util.Conv2DInfo): Tensor4D { - const getPositions = true; - const maxPoolPositionsProgram = - new Pool2DProgram(convInfo, 'max', getPositions); - const maxPoolPositions: Tensor4D = - this.compileAndRun(maxPoolPositionsProgram, [x]); - - const maxPoolBackPropProgram = new MaxPool2DBackpropProgram(convInfo); - const result = this.compileAndRun( - maxPoolBackPropProgram, [dy, maxPoolPositions], x.dtype); - maxPoolPositions.dispose(); - return result as Tensor4D; - } - - avgPoolBackprop(dy: Tensor4D, x: Tensor4D, convInfo: backend_util.Conv2DInfo): - Tensor4D { - const avgPoolBackpropProgram = new AvgPool2DBackpropProgram(convInfo); - return this.compileAndRun(avgPoolBackpropProgram, [dy], x.dtype); - } - cast(x: T, dtype: DataType): T { return backend_util.castTensor(x, dtype, this); } diff --git a/tfjs-backend-webgl/src/kernels/AvgPool.ts b/tfjs-backend-webgl/src/kernels/AvgPool.ts new file mode 100644 index 00000000000..9eb2f51d023 --- /dev/null +++ b/tfjs-backend-webgl/src/kernels/AvgPool.ts @@ -0,0 +1,62 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {AvgPool, AvgPoolAttrs, AvgPoolInputs, backend_util, KernelConfig, KernelFunc, TensorInfo, util} from '@tensorflow/tfjs-core'; + +import {MathBackendWebGL} from '../backend_webgl'; +import {Pool2DProgram} from '../pool_gpu'; +import {assertNotComplex} from '../webgl_util'; + +export function avgPool(args: { + inputs: AvgPoolInputs, + backend: MathBackendWebGL, + attrs: AvgPoolAttrs +}): TensorInfo { + const {inputs, backend, attrs} = args; + const {x} = inputs; + assertNotComplex(x, 'avgPool'); + const {filterSize, strides, pad, dimRoundingMode} = attrs; + const dilations = 1; + + util.assert( + backend_util.eitherStridesOrDilationsAreOne(strides, dilations), + () => 'Error in avgPool: Either strides or dilations must be 1. ' + + `Got strides ${strides} and dilations '${dilations}'`); + + const xRank = x.shape.length; + util.assert( + xRank === 4, + () => `Error in avgPool: input must be rank 4 but got rank ${ + x.shape.length}}.`); + if (dimRoundingMode != null) { + util.assert( + util.isInt(pad as number), + () => `Error in avgPool: pad must be an integer when using, ` + + `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); + } + + const convInfo = backend_util.computePool2DInfo( + x.shape as [number, number, number, number], filterSize, strides, + dilations, pad, dimRoundingMode); + const program = new Pool2DProgram(convInfo, 'avg', false); + return backend.runWebGLProgram(program, [x], 'float32'); +} + +export const avgPoolConfig: KernelConfig = { + kernelName: AvgPool, + backendName: 'webgl', + kernelFunc: avgPool as {} as KernelFunc +}; diff --git a/tfjs-backend-webgl/src/kernels/AvgPoolBackprop.ts b/tfjs-backend-webgl/src/kernels/AvgPoolBackprop.ts new file mode 100644 index 00000000000..046d3664411 --- /dev/null +++ b/tfjs-backend-webgl/src/kernels/AvgPoolBackprop.ts @@ -0,0 +1,56 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {AvgPoolBackprop, AvgPoolBackpropAttrs, AvgPoolBackpropInputs, backend_util, KernelConfig, KernelFunc, TensorInfo, util} from '@tensorflow/tfjs-core'; + +import {AvgPool2DBackpropProgram} from '../avg_pool_backprop_gpu'; +import {MathBackendWebGL} from '../backend_webgl'; +import {assertNotComplex} from '../webgl_util'; + +export function avgPoolBackprop(args: { + inputs: AvgPoolBackpropInputs, + backend: MathBackendWebGL, + attrs: AvgPoolBackpropAttrs +}): TensorInfo { + const {inputs, backend, attrs} = args; + const {dy, input} = inputs; + const x = input; + assertNotComplex([dy, input], 'avgPoolBackprop'); + const {filterSize, strides, pad} = attrs; + + const inputRank = input.shape.length; + const dyRank = dy.shape.length; + util.assert( + dyRank === 4, + () => `Error in avgPoolBackprop: dy must be rank 4 but got rank ` + + `${dyRank}.`); + util.assert( + inputRank === 4, + () => `Error in avgPoolBackprop: input must be rank 4 but got rank ` + + `${inputRank}.`); + + const convInfo = backend_util.computePool2DInfo( + x.shape as [number, number, number, number], filterSize, strides, + 1 /* dilations */, pad); + const avgPoolBackpropProgram = new AvgPool2DBackpropProgram(convInfo); + return backend.runWebGLProgram(avgPoolBackpropProgram, [dy], x.dtype); +} + +export const avgPoolBackpropConfig: KernelConfig = { + kernelName: AvgPoolBackprop, + backendName: 'webgl', + kernelFunc: avgPoolBackprop as {} as KernelFunc +}; diff --git a/tfjs-backend-webgl/src/kernels/MaxPool.ts b/tfjs-backend-webgl/src/kernels/MaxPool.ts new file mode 100644 index 00000000000..668832d9e57 --- /dev/null +++ b/tfjs-backend-webgl/src/kernels/MaxPool.ts @@ -0,0 +1,61 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {backend_util, KernelConfig, KernelFunc, MaxPool, MaxPoolAttrs, MaxPoolInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; + +import {MathBackendWebGL} from '../backend_webgl'; +import {Pool2DProgram} from '../pool_gpu'; +import {assertNotComplex} from '../webgl_util'; + +export function maxPool(args: { + inputs: MaxPoolInputs, + backend: MathBackendWebGL, + attrs: MaxPoolAttrs +}): TensorInfo { + const {inputs, backend, attrs} = args; + const {x} = inputs; + assertNotComplex(x, 'maxPool'); + const {filterSize, strides, pad, dimRoundingMode} = attrs; + const dilations = 1; + + const xRank = x.shape.length; + util.assert( + xRank === 4, + () => `Error in maxPool: input must be rank 4 but got rank ${ + x.shape.length}}.`); + util.assert( + backend_util.eitherStridesOrDilationsAreOne(strides, dilations), + () => 'Error in maxPool: Either strides or dilations must be 1. ' + + `Got strides ${strides} and dilations '${dilations}'`); + if (dimRoundingMode != null) { + util.assert( + util.isInt(pad as number), + () => `Error in maxPool: pad must be an integer when using, ` + + `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); + } + + const convInfo = backend_util.computePool2DInfo( + x.shape as [number, number, number, number], filterSize, strides, + dilations, pad, dimRoundingMode); + const program = new Pool2DProgram(convInfo, 'max', false); + return backend.runWebGLProgram(program, [x], x.dtype); +} + +export const maxPoolConfig: KernelConfig = { + kernelName: MaxPool, + backendName: 'webgl', + kernelFunc: maxPool as {} as KernelFunc +}; diff --git a/tfjs-backend-webgl/src/kernels/MaxPoolBackprop.ts b/tfjs-backend-webgl/src/kernels/MaxPoolBackprop.ts new file mode 100644 index 00000000000..3174e5a7f2a --- /dev/null +++ b/tfjs-backend-webgl/src/kernels/MaxPoolBackprop.ts @@ -0,0 +1,77 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {backend_util, KernelConfig, KernelFunc, MaxPoolBackprop, MaxPoolBackpropAttrs, MaxPoolBackpropInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; + +import {MathBackendWebGL} from '../backend_webgl'; +import {MaxPool2DBackpropProgram} from '../max_pool_backprop_gpu'; +import {Pool2DProgram} from '../pool_gpu'; +import {assertNotComplex} from '../webgl_util'; + +export function maxPoolBackprop(args: { + inputs: MaxPoolBackpropInputs, + backend: MathBackendWebGL, + attrs: MaxPoolBackpropAttrs +}): TensorInfo { + const {inputs, backend, attrs} = args; + const {dy, input, output} = inputs; + const x = input; + assertNotComplex([input, output], 'maxPoolBackprop'); + const {filterSize, strides, pad, dimRoundingMode} = attrs; + + const inputRank = input.shape.length; + const dyRank = dy.shape.length; + util.assert( + inputRank === dyRank, + () => `Rank of input (${inputRank}) does not match rank of dy ` + + `(${dyRank})`); + + util.assert( + dyRank === 4, + () => `Error in maxPoolBackprop: dy must be rank 4 but got rank ` + + `${dyRank}.`); + util.assert( + inputRank === 4, + () => `Error in maxPoolBackprop: input must be rank 4 but got rank ` + + `${inputRank}.`); + if (dimRoundingMode != null) { + util.assert( + util.isInt(pad as number), + () => `Error in maxPoolBackprop: pad must be an integer when using, ` + + `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); + } + + const convInfo = backend_util.computePool2DInfo( + x.shape as [number, number, number, number], filterSize, strides, + 1 /* dilations */, pad, dimRoundingMode); + const getPositions = true; + const maxPoolPositionsProgram = + new Pool2DProgram(convInfo, 'max', getPositions); + const maxPoolPositions: TensorInfo = + backend.runWebGLProgram(maxPoolPositionsProgram, [x], x.dtype); + + const maxPoolBackPropProgram = new MaxPool2DBackpropProgram(convInfo); + const result = backend.runWebGLProgram( + maxPoolBackPropProgram, [dy, maxPoolPositions], x.dtype); + backend.disposeIntermediateTensorInfo(maxPoolPositions); + return result; +} + +export const maxPoolBackpropConfig: KernelConfig = { + kernelName: MaxPoolBackprop, + backendName: 'webgl', + kernelFunc: maxPoolBackprop as {} as KernelFunc +}; diff --git a/tfjs-backend-webgl/src/register_all_kernels.ts b/tfjs-backend-webgl/src/register_all_kernels.ts index ccf46782f38..eab74e21a7a 100644 --- a/tfjs-backend-webgl/src/register_all_kernels.ts +++ b/tfjs-backend-webgl/src/register_all_kernels.ts @@ -17,12 +17,16 @@ import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core'; import {atan2Config} from './kernels/Atan2'; +import {avgPoolConfig} from './kernels/AvgPool'; +import {avgPoolBackpropConfig} from './kernels/AvgPoolBackprop'; import {batchNormConfig} from './kernels/BatchNorm'; import {cosConfig} from './kernels/Cos'; import {divConfig} from './kernels/Div'; import {flipLeftRightConfig} from './kernels/FlipLeftRight'; import {fromPixelsConfig} from './kernels/FromPixels'; import {maxConfig} from './kernels/Max'; +import {maxPoolConfig} from './kernels/MaxPool'; +import {maxPoolBackpropConfig} from './kernels/MaxPoolBackprop'; import {maxPoolWithArgmaxConfig} from './kernels/MaxPoolWithArgmax'; import {nonMaxSuppressionV3Config} from './kernels/NonMaxSuppressionV3'; import {nonMaxSuppressionV4Config} from './kernels/NonMaxSuppressionV4'; @@ -37,11 +41,28 @@ import {transposeConfig} from './kernels/Transpose'; // List all kernel configs here const kernelConfigs: KernelConfig[] = [ - atan2Config, batchNormConfig, cosConfig, maxConfig, flipLeftRightConfig, - fromPixelsConfig, divConfig, maxPoolWithArgmaxConfig, - nonMaxSuppressionV3Config, nonMaxSuppressionV4Config, - nonMaxSuppressionV5Config, reshapeConfig, rotateWithOffsetConfig, sinConfig, - squareConfig, squaredDifferenceConfig, tanConfig, transposeConfig + atan2Config, + avgPoolConfig, + avgPoolBackpropConfig, + batchNormConfig, + cosConfig, + maxConfig, + flipLeftRightConfig, + fromPixelsConfig, + divConfig, + maxPoolConfig, + maxPoolBackpropConfig, + maxPoolWithArgmaxConfig, + nonMaxSuppressionV3Config, + nonMaxSuppressionV4Config, + nonMaxSuppressionV5Config, + reshapeConfig, + rotateWithOffsetConfig, + sinConfig, + squareConfig, + squaredDifferenceConfig, + tanConfig, + transposeConfig ]; for (const kernelConfig of kernelConfigs) { diff --git a/tfjs-backend-webgl/src/webgl_util.ts b/tfjs-backend-webgl/src/webgl_util.ts index 5832e8d9205..bb360b6c1a8 100644 --- a/tfjs-backend-webgl/src/webgl_util.ts +++ b/tfjs-backend-webgl/src/webgl_util.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {env, util} from '@tensorflow/tfjs-core'; +import {env, TensorInfo, util} from '@tensorflow/tfjs-core'; import {getWebGLContext} from './canvas_util'; import {getTextureConfig} from './tex_util'; @@ -673,3 +673,18 @@ export function isWebGLFenceEnabled(webGLVersion: number) { const isEnabled = (gl as any).fenceSync != null; return isEnabled; } + +export function assertNotComplex( + tensor: TensorInfo|TensorInfo[], opName: string): void { + if (!Array.isArray(tensor)) { + tensor = [tensor]; + } + tensor.forEach(t => { + if (t != null) { + util.assert( + t.dtype !== 'complex64', + () => `${opName} does not support complex64 tensors ` + + 'in the WebGL backend.'); + } + }); +} From 2ad8b0188ddacb7d9eb86f40f7d2a37e6a2fbf3c Mon Sep 17 00:00:00 2001 From: Jing Jin <8752427+jinjingforever@users.noreply.github.com> Date: Mon, 28 Sep 2020 13:59:39 -0700 Subject: [PATCH 2/3] save --- tfjs-backend-webgl/src/kernels/AvgPool.ts | 12 ---------- .../src/kernels/AvgPoolBackprop.ts | 13 +--------- tfjs-backend-webgl/src/kernels/MaxPool.ts | 11 --------- .../src/kernels/MaxPoolBackprop.ts | 24 +------------------ 4 files changed, 2 insertions(+), 58 deletions(-) diff --git a/tfjs-backend-webgl/src/kernels/AvgPool.ts b/tfjs-backend-webgl/src/kernels/AvgPool.ts index 9eb2f51d023..db01863e296 100644 --- a/tfjs-backend-webgl/src/kernels/AvgPool.ts +++ b/tfjs-backend-webgl/src/kernels/AvgPool.ts @@ -36,18 +36,6 @@ export function avgPool(args: { () => 'Error in avgPool: Either strides or dilations must be 1. ' + `Got strides ${strides} and dilations '${dilations}'`); - const xRank = x.shape.length; - util.assert( - xRank === 4, - () => `Error in avgPool: input must be rank 4 but got rank ${ - x.shape.length}}.`); - if (dimRoundingMode != null) { - util.assert( - util.isInt(pad as number), - () => `Error in avgPool: pad must be an integer when using, ` + - `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); - } - const convInfo = backend_util.computePool2DInfo( x.shape as [number, number, number, number], filterSize, strides, dilations, pad, dimRoundingMode); diff --git a/tfjs-backend-webgl/src/kernels/AvgPoolBackprop.ts b/tfjs-backend-webgl/src/kernels/AvgPoolBackprop.ts index 046d3664411..a209d459fe5 100644 --- a/tfjs-backend-webgl/src/kernels/AvgPoolBackprop.ts +++ b/tfjs-backend-webgl/src/kernels/AvgPoolBackprop.ts @@ -14,7 +14,7 @@ * limitations under the License. * ============================================================================= */ -import {AvgPoolBackprop, AvgPoolBackpropAttrs, AvgPoolBackpropInputs, backend_util, KernelConfig, KernelFunc, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {AvgPoolBackprop, AvgPoolBackpropAttrs, AvgPoolBackpropInputs, backend_util, KernelConfig, KernelFunc, TensorInfo} from '@tensorflow/tfjs-core'; import {AvgPool2DBackpropProgram} from '../avg_pool_backprop_gpu'; import {MathBackendWebGL} from '../backend_webgl'; @@ -31,17 +31,6 @@ export function avgPoolBackprop(args: { assertNotComplex([dy, input], 'avgPoolBackprop'); const {filterSize, strides, pad} = attrs; - const inputRank = input.shape.length; - const dyRank = dy.shape.length; - util.assert( - dyRank === 4, - () => `Error in avgPoolBackprop: dy must be rank 4 but got rank ` + - `${dyRank}.`); - util.assert( - inputRank === 4, - () => `Error in avgPoolBackprop: input must be rank 4 but got rank ` + - `${inputRank}.`); - const convInfo = backend_util.computePool2DInfo( x.shape as [number, number, number, number], filterSize, strides, 1 /* dilations */, pad); diff --git a/tfjs-backend-webgl/src/kernels/MaxPool.ts b/tfjs-backend-webgl/src/kernels/MaxPool.ts index 668832d9e57..2e2e4a68206 100644 --- a/tfjs-backend-webgl/src/kernels/MaxPool.ts +++ b/tfjs-backend-webgl/src/kernels/MaxPool.ts @@ -31,21 +31,10 @@ export function maxPool(args: { const {filterSize, strides, pad, dimRoundingMode} = attrs; const dilations = 1; - const xRank = x.shape.length; - util.assert( - xRank === 4, - () => `Error in maxPool: input must be rank 4 but got rank ${ - x.shape.length}}.`); util.assert( backend_util.eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool: Either strides or dilations must be 1. ' + `Got strides ${strides} and dilations '${dilations}'`); - if (dimRoundingMode != null) { - util.assert( - util.isInt(pad as number), - () => `Error in maxPool: pad must be an integer when using, ` + - `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); - } const convInfo = backend_util.computePool2DInfo( x.shape as [number, number, number, number], filterSize, strides, diff --git a/tfjs-backend-webgl/src/kernels/MaxPoolBackprop.ts b/tfjs-backend-webgl/src/kernels/MaxPoolBackprop.ts index 3174e5a7f2a..140b9455a02 100644 --- a/tfjs-backend-webgl/src/kernels/MaxPoolBackprop.ts +++ b/tfjs-backend-webgl/src/kernels/MaxPoolBackprop.ts @@ -14,7 +14,7 @@ * limitations under the License. * ============================================================================= */ -import {backend_util, KernelConfig, KernelFunc, MaxPoolBackprop, MaxPoolBackpropAttrs, MaxPoolBackpropInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {backend_util, KernelConfig, KernelFunc, MaxPoolBackprop, MaxPoolBackpropAttrs, MaxPoolBackpropInputs, TensorInfo} from '@tensorflow/tfjs-core'; import {MathBackendWebGL} from '../backend_webgl'; import {MaxPool2DBackpropProgram} from '../max_pool_backprop_gpu'; @@ -32,28 +32,6 @@ export function maxPoolBackprop(args: { assertNotComplex([input, output], 'maxPoolBackprop'); const {filterSize, strides, pad, dimRoundingMode} = attrs; - const inputRank = input.shape.length; - const dyRank = dy.shape.length; - util.assert( - inputRank === dyRank, - () => `Rank of input (${inputRank}) does not match rank of dy ` + - `(${dyRank})`); - - util.assert( - dyRank === 4, - () => `Error in maxPoolBackprop: dy must be rank 4 but got rank ` + - `${dyRank}.`); - util.assert( - inputRank === 4, - () => `Error in maxPoolBackprop: input must be rank 4 but got rank ` + - `${inputRank}.`); - if (dimRoundingMode != null) { - util.assert( - util.isInt(pad as number), - () => `Error in maxPoolBackprop: pad must be an integer when using, ` + - `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); - } - const convInfo = backend_util.computePool2DInfo( x.shape as [number, number, number, number], filterSize, strides, 1 /* dilations */, pad, dimRoundingMode); From 117420070599eff6eb2d16682f02ec8b7398d254 Mon Sep 17 00:00:00 2001 From: Jing Jin <8752427+jinjingforever@users.noreply.github.com> Date: Wed, 30 Sep 2020 09:14:17 -0700 Subject: [PATCH 3/3] address comments --- tfjs-backend-webgl/src/kernels/AvgPool.ts | 9 +++-- tfjs-backend-webgl/src/kernels/Identity.ts | 35 +++++++++++++++++++ tfjs-backend-webgl/src/kernels/MaxPool.ts | 9 +++-- .../src/register_all_kernels.ts | 6 ++-- 4 files changed, 53 insertions(+), 6 deletions(-) create mode 100644 tfjs-backend-webgl/src/kernels/Identity.ts diff --git a/tfjs-backend-webgl/src/kernels/AvgPool.ts b/tfjs-backend-webgl/src/kernels/AvgPool.ts index db01863e296..b9fcd0c48c7 100644 --- a/tfjs-backend-webgl/src/kernels/AvgPool.ts +++ b/tfjs-backend-webgl/src/kernels/AvgPool.ts @@ -19,6 +19,7 @@ import {AvgPool, AvgPoolAttrs, AvgPoolInputs, backend_util, KernelConfig, Kernel import {MathBackendWebGL} from '../backend_webgl'; import {Pool2DProgram} from '../pool_gpu'; import {assertNotComplex} from '../webgl_util'; +import {identity} from './Identity'; export function avgPool(args: { inputs: AvgPoolInputs, @@ -39,8 +40,12 @@ export function avgPool(args: { const convInfo = backend_util.computePool2DInfo( x.shape as [number, number, number, number], filterSize, strides, dilations, pad, dimRoundingMode); - const program = new Pool2DProgram(convInfo, 'avg', false); - return backend.runWebGLProgram(program, [x], 'float32'); + if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && + util.arraysEqual(convInfo.inShape, convInfo.outShape)) { + return identity({inputs: {x}, backend}); + } + const avgPoolProgram = new Pool2DProgram(convInfo, 'avg', false); + return backend.runWebGLProgram(avgPoolProgram, [x], 'float32'); } export const avgPoolConfig: KernelConfig = { diff --git a/tfjs-backend-webgl/src/kernels/Identity.ts b/tfjs-backend-webgl/src/kernels/Identity.ts new file mode 100644 index 00000000000..b3c8eb8f7e6 --- /dev/null +++ b/tfjs-backend-webgl/src/kernels/Identity.ts @@ -0,0 +1,35 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Identity, IdentityInputs, KernelConfig, KernelFunc, TensorInfo} from '@tensorflow/tfjs-core'; +import {MathBackendWebGL} from '../backend_webgl'; + +export function identity( + args: {inputs: IdentityInputs, backend: MathBackendWebGL}): TensorInfo { + const {inputs, backend} = args; + const {x} = inputs; + + backend.incRef(x.dataId); + + return {dataId: x.dataId, shape: x.shape, dtype: x.dtype}; +} + +export const identityConfig: KernelConfig = { + kernelName: Identity, + backendName: 'webgl', + kernelFunc: identity as {} as KernelFunc +}; diff --git a/tfjs-backend-webgl/src/kernels/MaxPool.ts b/tfjs-backend-webgl/src/kernels/MaxPool.ts index 2e2e4a68206..fd69af1de68 100644 --- a/tfjs-backend-webgl/src/kernels/MaxPool.ts +++ b/tfjs-backend-webgl/src/kernels/MaxPool.ts @@ -19,6 +19,7 @@ import {backend_util, KernelConfig, KernelFunc, MaxPool, MaxPoolAttrs, MaxPoolIn import {MathBackendWebGL} from '../backend_webgl'; import {Pool2DProgram} from '../pool_gpu'; import {assertNotComplex} from '../webgl_util'; +import {identity} from './Identity'; export function maxPool(args: { inputs: MaxPoolInputs, @@ -39,8 +40,12 @@ export function maxPool(args: { const convInfo = backend_util.computePool2DInfo( x.shape as [number, number, number, number], filterSize, strides, dilations, pad, dimRoundingMode); - const program = new Pool2DProgram(convInfo, 'max', false); - return backend.runWebGLProgram(program, [x], x.dtype); + if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && + util.arraysEqual(convInfo.inShape, convInfo.outShape)) { + return identity({inputs: {x}, backend}); + } + const maxPoolProgram = new Pool2DProgram(convInfo, 'max', false); + return backend.runWebGLProgram(maxPoolProgram, [x], x.dtype); } export const maxPoolConfig: KernelConfig = { diff --git a/tfjs-backend-webgl/src/register_all_kernels.ts b/tfjs-backend-webgl/src/register_all_kernels.ts index eab74e21a7a..48161cc1f2b 100644 --- a/tfjs-backend-webgl/src/register_all_kernels.ts +++ b/tfjs-backend-webgl/src/register_all_kernels.ts @@ -24,6 +24,7 @@ import {cosConfig} from './kernels/Cos'; import {divConfig} from './kernels/Div'; import {flipLeftRightConfig} from './kernels/FlipLeftRight'; import {fromPixelsConfig} from './kernels/FromPixels'; +import {identityConfig} from './kernels/Identity'; import {maxConfig} from './kernels/Max'; import {maxPoolConfig} from './kernels/MaxPool'; import {maxPoolBackpropConfig} from './kernels/MaxPoolBackprop'; @@ -46,10 +47,11 @@ const kernelConfigs: KernelConfig[] = [ avgPoolBackpropConfig, batchNormConfig, cosConfig, - maxConfig, + divConfig, flipLeftRightConfig, fromPixelsConfig, - divConfig, + identityConfig, + maxConfig, maxPoolConfig, maxPoolBackpropConfig, maxPoolWithArgmaxConfig,