diff --git a/tfjs-backend-webgl/src/backend_webgl.ts b/tfjs-backend-webgl/src/backend_webgl.ts index 18c238c6c1b..d6a41def911 100644 --- a/tfjs-backend-webgl/src/backend_webgl.ts +++ b/tfjs-backend-webgl/src/backend_webgl.ts @@ -33,7 +33,7 @@ import {AddNProgram} from './addn_gpu'; import {AddNPackedProgram} from './addn_packed_gpu'; import {ArgMinMaxProgram} from './argminmax_gpu'; import {ArgMinMaxPackedProgram} from './argminmax_packed_gpu'; -import {AvgPool2DBackpropProgram, AvgPool3DBackpropProgram} from './avg_pool_backprop_gpu'; +import {AvgPool3DBackpropProgram} from './avg_pool_backprop_gpu'; import * as binaryop_complex_gpu from './binaryop_complex_gpu'; import {BinaryOpComplexProgram} from './binaryop_complex_gpu'; import * as binaryop_gpu from './binaryop_gpu'; @@ -73,14 +73,14 @@ import {Im2ColPackedProgram} from './im2col_packed_gpu'; import {LRNProgram} from './lrn_gpu'; import {LRNGradProgram} from './lrn_grad_gpu'; import {LRNPackedProgram} from './lrn_packed_gpu'; -import {MaxPool2DBackpropProgram, MaxPool3DBackpropProgram} from './max_pool_backprop_gpu'; +import {MaxPool3DBackpropProgram} from './max_pool_backprop_gpu'; import {MatMulPackedProgram} from './mulmat_packed_gpu'; import {MultinomialProgram} from './multinomial_gpu'; import {OneHotProgram} from './onehot_gpu'; import {PackProgram} from './pack_gpu'; import {PadProgram} from './pad_gpu'; import {PadPackedProgram} from './pad_packed_gpu'; -import {Pool2DProgram, Pool3DProgram} from './pool_gpu'; +import {Pool3DProgram} from './pool_gpu'; import {ReduceProgram} from './reduce_gpu'; import {ReshapePackedProgram} from './reshape_packed_gpu'; import {ResizeBilinearBackpropProgram} from './resize_bilinear_backprop_gpu'; @@ -2097,38 +2097,6 @@ export class MathBackendWebGL extends KernelBackend { return this.compileAndRun(program, [x, dy]); } - maxPool(x: Tensor4D, convInfo: backend_util.Conv2DInfo): Tensor4D { - const program = new Pool2DProgram(convInfo, 'max', false); - return this.compileAndRun(program, [x]); - } - - avgPool(x: Tensor4D, convInfo: backend_util.Conv2DInfo): Tensor4D { - const program = new Pool2DProgram(convInfo, 'avg', false); - return this.compileAndRun(program, [x], 'float32'); - } - - maxPoolBackprop( - dy: Tensor4D, x: Tensor4D, y: Tensor4D, - convInfo: backend_util.Conv2DInfo): Tensor4D { - const getPositions = true; - const maxPoolPositionsProgram = - new Pool2DProgram(convInfo, 'max', getPositions); - const maxPoolPositions: Tensor4D = - this.compileAndRun(maxPoolPositionsProgram, [x]); - - const maxPoolBackPropProgram = new MaxPool2DBackpropProgram(convInfo); - const result = this.compileAndRun( - maxPoolBackPropProgram, [dy, maxPoolPositions], x.dtype); - maxPoolPositions.dispose(); - return result as Tensor4D; - } - - avgPoolBackprop(dy: Tensor4D, x: Tensor4D, convInfo: backend_util.Conv2DInfo): - Tensor4D { - const avgPoolBackpropProgram = new AvgPool2DBackpropProgram(convInfo); - return this.compileAndRun(avgPoolBackpropProgram, [dy], x.dtype); - } - cast(x: T, dtype: DataType): T { return backend_util.castTensor(x, dtype, this); } diff --git a/tfjs-backend-webgl/src/kernels/AvgPool.ts b/tfjs-backend-webgl/src/kernels/AvgPool.ts new file mode 100644 index 00000000000..b9fcd0c48c7 --- /dev/null +++ b/tfjs-backend-webgl/src/kernels/AvgPool.ts @@ -0,0 +1,55 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {AvgPool, AvgPoolAttrs, AvgPoolInputs, backend_util, KernelConfig, KernelFunc, TensorInfo, util} from '@tensorflow/tfjs-core'; + +import {MathBackendWebGL} from '../backend_webgl'; +import {Pool2DProgram} from '../pool_gpu'; +import {assertNotComplex} from '../webgl_util'; +import {identity} from './Identity'; + +export function avgPool(args: { + inputs: AvgPoolInputs, + backend: MathBackendWebGL, + attrs: AvgPoolAttrs +}): TensorInfo { + const {inputs, backend, attrs} = args; + const {x} = inputs; + assertNotComplex(x, 'avgPool'); + const {filterSize, strides, pad, dimRoundingMode} = attrs; + const dilations = 1; + + util.assert( + backend_util.eitherStridesOrDilationsAreOne(strides, dilations), + () => 'Error in avgPool: Either strides or dilations must be 1. ' + + `Got strides ${strides} and dilations '${dilations}'`); + + const convInfo = backend_util.computePool2DInfo( + x.shape as [number, number, number, number], filterSize, strides, + dilations, pad, dimRoundingMode); + if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && + util.arraysEqual(convInfo.inShape, convInfo.outShape)) { + return identity({inputs: {x}, backend}); + } + const avgPoolProgram = new Pool2DProgram(convInfo, 'avg', false); + return backend.runWebGLProgram(avgPoolProgram, [x], 'float32'); +} + +export const avgPoolConfig: KernelConfig = { + kernelName: AvgPool, + backendName: 'webgl', + kernelFunc: avgPool as {} as KernelFunc +}; diff --git a/tfjs-backend-webgl/src/kernels/AvgPoolBackprop.ts b/tfjs-backend-webgl/src/kernels/AvgPoolBackprop.ts new file mode 100644 index 00000000000..a209d459fe5 --- /dev/null +++ b/tfjs-backend-webgl/src/kernels/AvgPoolBackprop.ts @@ -0,0 +1,45 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {AvgPoolBackprop, AvgPoolBackpropAttrs, AvgPoolBackpropInputs, backend_util, KernelConfig, KernelFunc, TensorInfo} from '@tensorflow/tfjs-core'; + +import {AvgPool2DBackpropProgram} from '../avg_pool_backprop_gpu'; +import {MathBackendWebGL} from '../backend_webgl'; +import {assertNotComplex} from '../webgl_util'; + +export function avgPoolBackprop(args: { + inputs: AvgPoolBackpropInputs, + backend: MathBackendWebGL, + attrs: AvgPoolBackpropAttrs +}): TensorInfo { + const {inputs, backend, attrs} = args; + const {dy, input} = inputs; + const x = input; + assertNotComplex([dy, input], 'avgPoolBackprop'); + const {filterSize, strides, pad} = attrs; + + const convInfo = backend_util.computePool2DInfo( + x.shape as [number, number, number, number], filterSize, strides, + 1 /* dilations */, pad); + const avgPoolBackpropProgram = new AvgPool2DBackpropProgram(convInfo); + return backend.runWebGLProgram(avgPoolBackpropProgram, [dy], x.dtype); +} + +export const avgPoolBackpropConfig: KernelConfig = { + kernelName: AvgPoolBackprop, + backendName: 'webgl', + kernelFunc: avgPoolBackprop as {} as KernelFunc +}; diff --git a/tfjs-backend-webgl/src/kernels/Identity.ts b/tfjs-backend-webgl/src/kernels/Identity.ts new file mode 100644 index 00000000000..b3c8eb8f7e6 --- /dev/null +++ b/tfjs-backend-webgl/src/kernels/Identity.ts @@ -0,0 +1,35 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Identity, IdentityInputs, KernelConfig, KernelFunc, TensorInfo} from '@tensorflow/tfjs-core'; +import {MathBackendWebGL} from '../backend_webgl'; + +export function identity( + args: {inputs: IdentityInputs, backend: MathBackendWebGL}): TensorInfo { + const {inputs, backend} = args; + const {x} = inputs; + + backend.incRef(x.dataId); + + return {dataId: x.dataId, shape: x.shape, dtype: x.dtype}; +} + +export const identityConfig: KernelConfig = { + kernelName: Identity, + backendName: 'webgl', + kernelFunc: identity as {} as KernelFunc +}; diff --git a/tfjs-backend-webgl/src/kernels/MaxPool.ts b/tfjs-backend-webgl/src/kernels/MaxPool.ts new file mode 100644 index 00000000000..fd69af1de68 --- /dev/null +++ b/tfjs-backend-webgl/src/kernels/MaxPool.ts @@ -0,0 +1,55 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {backend_util, KernelConfig, KernelFunc, MaxPool, MaxPoolAttrs, MaxPoolInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; + +import {MathBackendWebGL} from '../backend_webgl'; +import {Pool2DProgram} from '../pool_gpu'; +import {assertNotComplex} from '../webgl_util'; +import {identity} from './Identity'; + +export function maxPool(args: { + inputs: MaxPoolInputs, + backend: MathBackendWebGL, + attrs: MaxPoolAttrs +}): TensorInfo { + const {inputs, backend, attrs} = args; + const {x} = inputs; + assertNotComplex(x, 'maxPool'); + const {filterSize, strides, pad, dimRoundingMode} = attrs; + const dilations = 1; + + util.assert( + backend_util.eitherStridesOrDilationsAreOne(strides, dilations), + () => 'Error in maxPool: Either strides or dilations must be 1. ' + + `Got strides ${strides} and dilations '${dilations}'`); + + const convInfo = backend_util.computePool2DInfo( + x.shape as [number, number, number, number], filterSize, strides, + dilations, pad, dimRoundingMode); + if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && + util.arraysEqual(convInfo.inShape, convInfo.outShape)) { + return identity({inputs: {x}, backend}); + } + const maxPoolProgram = new Pool2DProgram(convInfo, 'max', false); + return backend.runWebGLProgram(maxPoolProgram, [x], x.dtype); +} + +export const maxPoolConfig: KernelConfig = { + kernelName: MaxPool, + backendName: 'webgl', + kernelFunc: maxPool as {} as KernelFunc +}; diff --git a/tfjs-backend-webgl/src/kernels/MaxPoolBackprop.ts b/tfjs-backend-webgl/src/kernels/MaxPoolBackprop.ts new file mode 100644 index 00000000000..140b9455a02 --- /dev/null +++ b/tfjs-backend-webgl/src/kernels/MaxPoolBackprop.ts @@ -0,0 +1,55 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {backend_util, KernelConfig, KernelFunc, MaxPoolBackprop, MaxPoolBackpropAttrs, MaxPoolBackpropInputs, TensorInfo} from '@tensorflow/tfjs-core'; + +import {MathBackendWebGL} from '../backend_webgl'; +import {MaxPool2DBackpropProgram} from '../max_pool_backprop_gpu'; +import {Pool2DProgram} from '../pool_gpu'; +import {assertNotComplex} from '../webgl_util'; + +export function maxPoolBackprop(args: { + inputs: MaxPoolBackpropInputs, + backend: MathBackendWebGL, + attrs: MaxPoolBackpropAttrs +}): TensorInfo { + const {inputs, backend, attrs} = args; + const {dy, input, output} = inputs; + const x = input; + assertNotComplex([input, output], 'maxPoolBackprop'); + const {filterSize, strides, pad, dimRoundingMode} = attrs; + + const convInfo = backend_util.computePool2DInfo( + x.shape as [number, number, number, number], filterSize, strides, + 1 /* dilations */, pad, dimRoundingMode); + const getPositions = true; + const maxPoolPositionsProgram = + new Pool2DProgram(convInfo, 'max', getPositions); + const maxPoolPositions: TensorInfo = + backend.runWebGLProgram(maxPoolPositionsProgram, [x], x.dtype); + + const maxPoolBackPropProgram = new MaxPool2DBackpropProgram(convInfo); + const result = backend.runWebGLProgram( + maxPoolBackPropProgram, [dy, maxPoolPositions], x.dtype); + backend.disposeIntermediateTensorInfo(maxPoolPositions); + return result; +} + +export const maxPoolBackpropConfig: KernelConfig = { + kernelName: MaxPoolBackprop, + backendName: 'webgl', + kernelFunc: maxPoolBackprop as {} as KernelFunc +}; diff --git a/tfjs-backend-webgl/src/register_all_kernels.ts b/tfjs-backend-webgl/src/register_all_kernels.ts index ccf46782f38..48161cc1f2b 100644 --- a/tfjs-backend-webgl/src/register_all_kernels.ts +++ b/tfjs-backend-webgl/src/register_all_kernels.ts @@ -17,12 +17,17 @@ import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core'; import {atan2Config} from './kernels/Atan2'; +import {avgPoolConfig} from './kernels/AvgPool'; +import {avgPoolBackpropConfig} from './kernels/AvgPoolBackprop'; import {batchNormConfig} from './kernels/BatchNorm'; import {cosConfig} from './kernels/Cos'; import {divConfig} from './kernels/Div'; import {flipLeftRightConfig} from './kernels/FlipLeftRight'; import {fromPixelsConfig} from './kernels/FromPixels'; +import {identityConfig} from './kernels/Identity'; import {maxConfig} from './kernels/Max'; +import {maxPoolConfig} from './kernels/MaxPool'; +import {maxPoolBackpropConfig} from './kernels/MaxPoolBackprop'; import {maxPoolWithArgmaxConfig} from './kernels/MaxPoolWithArgmax'; import {nonMaxSuppressionV3Config} from './kernels/NonMaxSuppressionV3'; import {nonMaxSuppressionV4Config} from './kernels/NonMaxSuppressionV4'; @@ -37,11 +42,29 @@ import {transposeConfig} from './kernels/Transpose'; // List all kernel configs here const kernelConfigs: KernelConfig[] = [ - atan2Config, batchNormConfig, cosConfig, maxConfig, flipLeftRightConfig, - fromPixelsConfig, divConfig, maxPoolWithArgmaxConfig, - nonMaxSuppressionV3Config, nonMaxSuppressionV4Config, - nonMaxSuppressionV5Config, reshapeConfig, rotateWithOffsetConfig, sinConfig, - squareConfig, squaredDifferenceConfig, tanConfig, transposeConfig + atan2Config, + avgPoolConfig, + avgPoolBackpropConfig, + batchNormConfig, + cosConfig, + divConfig, + flipLeftRightConfig, + fromPixelsConfig, + identityConfig, + maxConfig, + maxPoolConfig, + maxPoolBackpropConfig, + maxPoolWithArgmaxConfig, + nonMaxSuppressionV3Config, + nonMaxSuppressionV4Config, + nonMaxSuppressionV5Config, + reshapeConfig, + rotateWithOffsetConfig, + sinConfig, + squareConfig, + squaredDifferenceConfig, + tanConfig, + transposeConfig ]; for (const kernelConfig of kernelConfigs) { diff --git a/tfjs-backend-webgl/src/webgl_util.ts b/tfjs-backend-webgl/src/webgl_util.ts index 5832e8d9205..bb360b6c1a8 100644 --- a/tfjs-backend-webgl/src/webgl_util.ts +++ b/tfjs-backend-webgl/src/webgl_util.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {env, util} from '@tensorflow/tfjs-core'; +import {env, TensorInfo, util} from '@tensorflow/tfjs-core'; import {getWebGLContext} from './canvas_util'; import {getTextureConfig} from './tex_util'; @@ -673,3 +673,18 @@ export function isWebGLFenceEnabled(webGLVersion: number) { const isEnabled = (gl as any).fenceSync != null; return isEnabled; } + +export function assertNotComplex( + tensor: TensorInfo|TensorInfo[], opName: string): void { + if (!Array.isArray(tensor)) { + tensor = [tensor]; + } + tensor.forEach(t => { + if (t != null) { + util.assert( + t.dtype !== 'complex64', + () => `${opName} does not support complex64 tensors ` + + 'in the WebGL backend.'); + } + }); +}