diff --git a/tfjs-backend-webgpu/src/kernels/LRNGrad.ts b/tfjs-backend-webgpu/src/kernels/LRNGrad.ts
new file mode 100644
index 00000000000..63fcfccf2c0
--- /dev/null
+++ b/tfjs-backend-webgpu/src/kernels/LRNGrad.ts
@@ -0,0 +1,66 @@
+/**
+ * @license
+ * Copyright 2023 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {KernelConfig, KernelFunc, LRNGrad, LRNGradAttrs, LRNGradInputs, TensorInfo} from '@tensorflow/tfjs-core';
+
+import {WebGPUBackend} from '../backend_webgpu';
+import {LRNGradProgram} from '../lrn_grad_webgpu';
+
+/**
+ * WebGPU kernel function for `LRNGrad`.
+ *
+ * Builds an `LRNGradProgram` for `x.shape` and runs it on `x`, `y` and `dy`,
+ * passed in the same order as the program's `variableNames` (`inputImage`,
+ * `outputImage`, `dy`).
+ *
+ * @param args.inputs `x`, `y` and `dy` tensors (presumably the forward input,
+ *     the forward output and the upstream gradient; confirm against the op).
+ * @param args.backend WebGPU backend that runs the program.
+ * @param args.attrs `depthRadius`, `bias`, `alpha` and `beta`, supplied to
+ *     the program as uniform data.
+ * @returns The result of `runWebGPUProgram`, requested with `x.dtype`.
+ */
+export function lrnGrad(
+    args: {inputs: LRNGradInputs, backend: WebGPUBackend, attrs: LRNGradAttrs}):
+    TensorInfo {
+  const {backend} = args;
+  const {x, y, dy} = args.inputs;
+  const {depthRadius, bias, alpha, beta} = args.attrs;
+
+  // Same order and types as the `uniforms` string declared by LRNGradProgram:
+  // depthRadius : i32, then bias, alpha and beta : f32.
+  const asFloat = (value: number) => ({type: 'float32', data: [value]});
+  const uniformData = [
+    {type: 'int32', data: [depthRadius]},
+    asFloat(bias),
+    asFloat(alpha),
+    asFloat(beta),
+  ];
+
+  return backend.runWebGPUProgram(
+      new LRNGradProgram(x.shape), [x, y, dy], x.dtype, uniformData);
+}
+
+/**
+ * Kernel configuration that associates the `LRNGrad` kernel name with
+ * `lrnGrad` on the `webgpu` backend.
+ */
+export const lrnGradConfig: KernelConfig = {
+  kernelName: LRNGrad,
+  backendName: 'webgpu',
+  kernelFunc: lrnGrad as unknown as KernelFunc
+};
diff --git a/tfjs-backend-webgpu/src/lrn_grad_webgpu.ts b/tfjs-backend-webgpu/src/lrn_grad_webgpu.ts
new file mode 100644
index
00000000000..b065477f5bf
--- /dev/null
+++ b/tfjs-backend-webgpu/src/lrn_grad_webgpu.ts
@@ -0,0 +1,102 @@
+/**
+ * @license
+ * Copyright 2023 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program';
+import {computeDispatch, flatDispatchLayout} from './webgpu_util';
+
+/**
+ * WGSL program behind the `LRNGrad` kernel.
+ *
+ * Every invocation writes one element of the output, whose shape equals the
+ * `inputShape` passed to the constructor. Tensors are addressed with four
+ * coordinates `(b, r, c, d)` and the window runs along the last one.
+ *
+ * For output channel `k` the shader accumulates, over every `d` with
+ * `|d - k| <= depthRadius` that lies inside the tensor:
+ *
+ *   dy[d] * ((d == k ? pow(norm_d, -beta) : 0)
+ *            - 2 * alpha * beta * x[k] * y[d] / norm_d)
+ *
+ * where `norm_d = bias + alpha * sum(x[j]^2)` over the channels `j` within
+ * `depthRadius` of `d`, `x` is `inputImage` and `y` is `outputImage`.
+ */
+export class LRNGradProgram implements WebGPUProgram {
+  outputShape: number[] = [];
+  shaderKey: string;
+  dispatchLayout: {x: number[]};
+  dispatch: [number, number, number];
+  variableNames = ['inputImage', 'outputImage', 'dy'];
+  uniforms = 'depthRadius : i32, bias : f32, alpha : f32, beta : f32,';
+  workgroupSize: [number, number, number] = [64, 1, 1];
+  size = true;
+
+  /** @param inputShape Shape of `inputImage`; reused as the output shape. */
+  constructor(inputShape: number[]) {
+    this.outputShape = inputShape;
+    this.dispatchLayout = flatDispatchLayout(this.outputShape);
+    this.dispatch = computeDispatch(
+        this.dispatchLayout, this.outputShape, this.workgroupSize);
+    // getUserCode() interpolates nothing shape-specific, so the key is fixed.
+    this.shaderKey = 'lrn_grad';
+  }
+
+  /** Returns the WGSL entry point that computes one output element. */
+  getUserCode(): string {
+    return `
+      ${main('index')} {
+        if (index < uniforms.size) {
+          let coords = getOutputCoords();
+          let b = coords[0];
+          let r = coords[1];
+          let c = coords[2];
+          let outChannel = coords[3];
+          let numChannels = uniforms.outShape[3];
+
+          // The window centred on d spans [d - depthRadius, d + depthRadius],
+          // so it contains outChannel only when d is within depthRadius of
+          // outChannel. Every other d contributes nothing and is skipped.
+          let windowBegin = max(0, outChannel - uniforms.depthRadius);
+          let windowEnd =
+              min(numChannels, outChannel + uniforms.depthRadius + 1);
+          let inputAtOut = getInputImage(b, r, c, outChannel);
+
+          var result = 0.0;
+          for (var d = windowBegin; d < windowEnd; d++) {
+            let depthBegin = max(0, d - uniforms.depthRadius);
+            let depthEnd = min(numChannels, d + uniforms.depthRadius + 1);
+
+            var norm = 0.0;
+            for (var k = depthBegin; k < depthEnd; k++) {
+              norm += getInputImage(b, r, c, k) * getInputImage(b, r, c, k);
+            }
+            norm = uniforms.alpha * norm + uniforms.bias;
+
+            var dyi = -2.0 * uniforms.alpha * uniforms.beta
+              * inputAtOut * getOutputImage(b, r, c, d) / norm;
+            if (d == outChannel) {
+              dyi += pow(norm, -1.0 * uniforms.beta);
+            }
+            dyi *= getDy(b, r, c, d);
+            result += dyi;
+          }
+
+          setOutputAtIndex(index, result);
+        }
+      }
+    `;
+  }
+}
diff --git a/tfjs-backend-webgpu/src/register_all_kernels.ts b/tfjs-backend-webgpu/src/register_all_kernels.ts
index 8b5df62b6c3..c339ac07565 100644
--- a/tfjs-backend-webgpu/src/register_all_kernels.ts
+++ b/tfjs-backend-webgpu/src/register_all_kernels.ts
@@ -95,6 +95,7 @@ import {logicalAndConfig} from './kernels/LogicalAnd';
 import {logicalNotConfig} from './kernels/LogicalNot';
 import {logicalOrConfig} from './kernels/LogicalOr';
 import {lrnConfig} from './kernels/LRN';
+import {lrnGradConfig} from './kernels/LRNGrad';
 import {maxConfig} from './kernels/Max';
 import {maximumConfig} from './kernels/Maximum';
 import {maxPoolConfig} from './kernels/MaxPool';
@@ -243,6 +244,7 @@ const kernelConfigs: KernelConfig[] = [
   logicalNotConfig,
   logicalOrConfig,
   lrnConfig,
+  lrnGradConfig,
   maxConfig,
   maximumConfig,
   maxPoolConfig,
diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts
index e36c9a107eb..ec945142622 100644
--- a/tfjs-backend-webgpu/src/setup_test.ts
+++ b/tfjs-backend-webgpu/src/setup_test.ts
@@ -115,12 +115,6 @@ const TEST_FILTERS:
TestFilter[] = [ 'gradient' // gradient function not found. ] }, - { - startsWith: 'localResponseNormalization ', - excludes: [ - 'gradient', // Not yet implemented. - ] - }, { startsWith: 'matmul', excludes: [