From 3d5fba73664602bf939e095bfe1749fc6805e645 Mon Sep 17 00:00:00 2001 From: Xinghua Cao Date: Thu, 2 Mar 2023 12:08:11 +0800 Subject: [PATCH] [WebGPU] support AvgPool3DGrad kernel --- ..._webgpu.ts => avg_pool_backprop_webgpu.ts} | 78 ++++++++++++++++++- .../src/kernels/AvgPool3DGrad.ts | 71 +++++++++++++++++ .../src/kernels/AvgPoolGrad.ts | 2 +- .../src/register_all_kernels.ts | 2 + tfjs-backend-webgpu/src/setup_test.ts | 1 - 5 files changed, 151 insertions(+), 3 deletions(-) rename tfjs-backend-webgpu/src/{avg_pool2d_backprop_webgpu.ts => avg_pool_backprop_webgpu.ts} (54%) create mode 100644 tfjs-backend-webgpu/src/kernels/AvgPool3DGrad.ts diff --git a/tfjs-backend-webgpu/src/avg_pool2d_backprop_webgpu.ts b/tfjs-backend-webgpu/src/avg_pool_backprop_webgpu.ts similarity index 54% rename from tfjs-backend-webgpu/src/avg_pool2d_backprop_webgpu.ts rename to tfjs-backend-webgpu/src/avg_pool_backprop_webgpu.ts index 9ea0424d666..3e70c96f91e 100644 --- a/tfjs-backend-webgpu/src/avg_pool2d_backprop_webgpu.ts +++ b/tfjs-backend-webgpu/src/avg_pool_backprop_webgpu.ts @@ -39,7 +39,7 @@ export class AvgPool2DBackpropProgram implements WebGPUProgram { this.dispatch = computeDispatch( this.dispatchLayout, this.outputShape, this.workgroupSize); - this.shaderKey = `avg_pool2d_backprop`; + this.shaderKey = `avgPool2DBackprop`; } getUserCode(): string { @@ -85,3 +85,79 @@ export class AvgPool2DBackpropProgram implements WebGPUProgram { return userCode; } } + +export class AvgPool3DBackpropProgram implements WebGPUProgram { + outputShape: number[]; + shaderKey: string; + dispatchLayout: {x: number[]}; + dispatch: [number, number, number]; + variableNames = ['dy']; + uniforms = `strides : vec3, pads : vec3, filterDims : vec3, + outDepth : i32, outHeight : i32, outWidth : i32, avgMultiplier : f32,`; + workgroupSize: [number, number, number] = [64, 1, 1]; + size = true; + + constructor(convInfo: backend_util.Conv3DInfo) { + this.outputShape = convInfo.inShape; + + this.dispatchLayout = flatDispatchLayout(this.outputShape); + + this.dispatch = computeDispatch( + this.dispatchLayout, this.outputShape, this.workgroupSize); + + this.shaderKey = `avgPool3DBackprop`; + } + + getUserCode(): string { + const userCode = ` + ${main('index')} { + if (index < uniforms.size) { + let coords = getCoordsFromIndex(index); + let batch = coords.x; + let ch = coords.u; + + let dyCorner = vec3(coords.y, coords.z, coords.w) - uniforms.pads; + let dyDCorner = dyCorner.x; + let dyRCorner = dyCorner.y; + let dyCCorner = dyCorner.z; + + // Convolve dy(?, ?, ?, d) with pos mask(:, :, :, ch) to get + // dx(xD, xR, xC, ch). + // ? = to be determined. : = across all values in that axis. + var dotProd = 0.0; + for (var wD = 0; wD < uniforms.filterDims[0]; wD++) { + let dyD = f32(dyDCorner + wD) / f32(uniforms.strides[0]); + + if (dyD < 0.0 || dyD >= f32(uniforms.outDepth) || fract(dyD) > 0.0) { + continue; + } + let idyD = i32(dyD); + + for (var wR = 0; wR < uniforms.filterDims[1]; wR++) { + let dyR = f32(dyRCorner + wR) / f32(uniforms.strides[1]); + + if (dyR < 0.0 || dyR >= f32(uniforms.outHeight) || fract(dyR) > 0.0) { + continue; + } + let idyR = i32(dyR); + + for (var wC = 0; wC < uniforms.filterDims[2]; wC++) { + let dyC = f32(dyCCorner + wC) / f32(uniforms.strides[2]); + + if (dyC < 0.0 || dyC >= f32(uniforms.outWidth) || fract(dyC) > 0.0) { + continue; + } + let idyC = i32(dyC); + + let dyValue = getDy(batch, idyD, idyR, idyC, ch); + dotProd += dyValue * uniforms.avgMultiplier; + } + } + } + setOutputAtIndex(index, dotProd); + } + } + `; + return userCode; + } +} diff --git a/tfjs-backend-webgpu/src/kernels/AvgPool3DGrad.ts b/tfjs-backend-webgpu/src/kernels/AvgPool3DGrad.ts new file mode 100644 index 00000000000..ae9280fec61 --- /dev/null +++ b/tfjs-backend-webgpu/src/kernels/AvgPool3DGrad.ts @@ -0,0 +1,71 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {AvgPool3DGrad, AvgPool3DGradAttrs, AvgPool3DGradInputs, backend_util, KernelConfig, KernelFunc, TensorInfo} from '@tensorflow/tfjs-core'; + +import {AvgPool3DBackpropProgram} from '../avg_pool_backprop_webgpu'; +import {WebGPUBackend} from '../backend_webgpu'; + +export function avgPool3DGrad(args: { + inputs: AvgPool3DGradInputs, + backend: WebGPUBackend, + attrs: AvgPool3DGradAttrs +}): TensorInfo { + const {inputs, backend, attrs} = args; + const {dy, input} = inputs; + const x = input; + const {filterSize, strides, pad, dimRoundingMode} = attrs; + + const convInfo = backend_util.computePool3DInfo( + x.shape as [number, number, number, number, number], filterSize, strides, + 1 /* dilations */, pad, dimRoundingMode); + const program = new AvgPool3DBackpropProgram(convInfo); + const avgMultiplier = + 1 / (convInfo.filterDepth * convInfo.filterHeight * convInfo.filterWidth); + const uniformData = [ + { + type: 'int32', + data: [convInfo.strideDepth, convInfo.strideHeight, convInfo.strideWidth] + }, + { + type: 'int32', + data: [ + convInfo.effectiveFilterDepth - 1 - convInfo.padInfo.front, + convInfo.effectiveFilterHeight - 1 - convInfo.padInfo.top, + convInfo.effectiveFilterWidth - 1 - convInfo.padInfo.left + ] + }, + { + type: 'int32', + data: [ + convInfo.effectiveFilterDepth, convInfo.effectiveFilterHeight, + convInfo.effectiveFilterWidth + ] + }, + {type: 'int32', data: [convInfo.outDepth]}, + {type: 'int32', data: [convInfo.outHeight]}, + {type: 'int32', data: [convInfo.outWidth]}, + {type: 'float32', data: [avgMultiplier]} + ]; + return backend.runWebGPUProgram(program, [dy], x.dtype, uniformData); +} + +export const avgPool3DGradConfig: KernelConfig = { + kernelName: AvgPool3DGrad, + backendName: 'webgpu', + kernelFunc: avgPool3DGrad as unknown as KernelFunc +}; diff --git a/tfjs-backend-webgpu/src/kernels/AvgPoolGrad.ts b/tfjs-backend-webgpu/src/kernels/AvgPoolGrad.ts index b838f2db0e6..b021bcff839 100644 --- a/tfjs-backend-webgpu/src/kernels/AvgPoolGrad.ts +++ b/tfjs-backend-webgpu/src/kernels/AvgPoolGrad.ts @@ -17,7 +17,7 @@ import {AvgPoolGrad, AvgPoolGradAttrs, AvgPoolGradInputs, backend_util, KernelConfig, KernelFunc, TensorInfo} from '@tensorflow/tfjs-core'; -import {AvgPool2DBackpropProgram} from '../avg_pool2d_backprop_webgpu'; +import {AvgPool2DBackpropProgram} from '../avg_pool_backprop_webgpu'; import {WebGPUBackend} from '../backend_webgpu'; import {assertNotComplex} from '../webgpu_util'; diff --git a/tfjs-backend-webgpu/src/register_all_kernels.ts b/tfjs-backend-webgpu/src/register_all_kernels.ts index 4eda3ab7ea7..8d922b95dd4 100644 --- a/tfjs-backend-webgpu/src/register_all_kernels.ts +++ b/tfjs-backend-webgpu/src/register_all_kernels.ts @@ -33,6 +33,7 @@ import {atan2Config} from './kernels/Atan2'; import {atanhConfig} from './kernels/Atanh'; import {avgPoolConfig} from './kernels/AvgPool'; import {avgPool3DConfig} from './kernels/AvgPool3D'; +import {avgPool3DGradConfig} from './kernels/AvgPool3DGrad'; import {avgPoolGradConfig} from './kernels/AvgPoolGrad'; import {batchMatMulConfig} from './kernels/BatchMatMul'; import {batchToSpaceNDConfig} from './kernels/BatchToSpaceND'; @@ -189,6 +190,7 @@ const kernelConfigs: KernelConfig[] = [ atanhConfig, avgPoolConfig, avgPool3DConfig, + avgPool3DGradConfig, avgPoolGradConfig, batchMatMulConfig, batchToSpaceNDConfig, diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts index 8d71390edb9..a1fd48c0416 100644 --- a/tfjs-backend-webgpu/src/setup_test.ts +++ b/tfjs-backend-webgpu/src/setup_test.ts @@ -75,7 +75,6 @@ const TEST_FILTERS: TestFilter[] = [ { include: ' webgpu ', excludes: [ - 'avgPool3dBackprop ', 'raggedGather ', 'raggedRange ', 'raggedTensorToTensor ',