diff --git a/tfjs-backend-webgpu/src/binary_op_util.ts b/tfjs-backend-webgpu/src/binary_op_util.ts index bfdde48a331..f6ca1048231 100644 --- a/tfjs-backend-webgpu/src/binary_op_util.ts +++ b/tfjs-backend-webgpu/src/binary_op_util.ts @@ -21,6 +21,7 @@ export enum BinaryOpType { COMPLEX_MULTIPLY_IMAG, COMPLEX_MULTIPLY_REAL, DIV, + ELU_DER, EQUAL, GREATER, GREATER_EQUAL, @@ -59,6 +60,9 @@ const ADD = 'return a + b;'; const COMPLEX_MULTIPLY_REAL = 'return areal * breal - aimag * bimag;'; const COMPLEX_MULTIPLY_IMAG = 'return areal * bimag + aimag * breal;'; const DIV = 'return a / b;'; +/* ELU derivative shader snippet: with WGSL select(f, t, cond) this yields a where b >= 0 and a * (b + 1.0) elsewhere. */ const ELU_DER = 'return select(a * (b + 1.0), a, b >= 0.);'; +/* Vec4 form of ELU_DER: the same selection, applied per component. */ const ELU_DER_VEC4 = + 'return select(a * (b + vec4(1.0)), a, b >= vec4(0.));'; const EQUAL = 'return f32(a == b);'; const EQUAL_VEC4 = 'return vec4(a == b);'; const GREATER = 'return f32(a > b);'; @@ -223,6 +227,8 @@ export function getBinaryOpString( return COMPLEX_MULTIPLY_REAL; case BinaryOpType.DIV: return DIV; + case BinaryOpType.ELU_DER: + return useVec4 ? ELU_DER_VEC4 : ELU_DER; case BinaryOpType.EQUAL: return useVec4 ? EQUAL_VEC4 : EQUAL; case BinaryOpType.GREATER: diff --git a/tfjs-backend-webgpu/src/kernels/EluGrad.ts b/tfjs-backend-webgpu/src/kernels/EluGrad.ts new file mode 100644 index 00000000000..ed5b387b5a9 --- /dev/null +++ b/tfjs-backend-webgpu/src/kernels/EluGrad.ts @@ -0,0 +1,38 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ * ============================================================================= + */ + +import {EluGrad, EluGradInputs, KernelConfig, KernelFunc, TensorInfo} from '@tensorflow/tfjs-core'; + +import {WebGPUBackend} from '../backend_webgpu'; +import {BinaryOpType} from '../binary_op_util'; +import {BinaryOpProgram} from '../binary_op_webgpu'; + +/** WebGPU kernel function for EluGrad: builds a BinaryOpProgram with BinaryOpType.ELU_DER from the shapes of dy and y, runs it on [dy, y] and returns the result with dy's dtype. NOTE(review): presumably dy is the incoming gradient and y the forward ELU output; confirm against EluGradInputs in tfjs-core. */ export const eluGrad = + (args: {inputs: EluGradInputs, backend: WebGPUBackend}): TensorInfo => { + const {inputs, backend} = args; + const {dy, y} = inputs; + + const program = + new BinaryOpProgram(BinaryOpType.ELU_DER, dy.shape, y.shape); + return backend.runWebGPUProgram(program, [dy, y], dy.dtype); + }; + +/** Kernel config that exposes eluGrad under the EluGrad kernel name for the 'webgpu' backend; eluGrad is cast through unknown to KernelFunc. */ export const eluGradConfig: KernelConfig = { + kernelName: EluGrad, + backendName: 'webgpu', + kernelFunc: eluGrad as unknown as KernelFunc +}; diff --git a/tfjs-backend-webgpu/src/register_all_kernels.ts b/tfjs-backend-webgpu/src/register_all_kernels.ts index 91fce062269..938b53f54d4 100644 --- a/tfjs-backend-webgpu/src/register_all_kernels.ts +++ b/tfjs-backend-webgpu/src/register_all_kernels.ts @@ -58,6 +58,7 @@ import {diagConfig} from './kernels/Diag'; import {dilation2DConfig} from './kernels/Dilation2D'; import {einsumConfig} from './kernels/Einsum'; import {eluConfig} from './kernels/Elu'; +import {eluGradConfig} from './kernels/EluGrad'; import {equalConfig} from './kernels/Equal'; import {erfConfig} from './kernels/Erf'; import {expConfig} from './kernels/Exp'; @@ -201,6 +202,7 @@ const kernelConfigs: KernelConfig[] = [ dilation2DConfig, einsumConfig, eluConfig, + eluGradConfig, equalConfig, erfConfig, expConfig, diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts index a88393ab4a6..01c07652bb1 100644 --- a/tfjs-backend-webgpu/src/setup_test.ts +++ b/tfjs-backend-webgpu/src/setup_test.ts @@ -81,13 +81,6 @@ const TEST_FILTERS: TestFilter[] = [ 'gradient' // gradient function not found.
] }, - { - startsWith: 'elu ', - excludes: [ - 'derivative', // gradient function not found. - 'gradient' // gradient function not found. - ] - }, { startsWith: 'exp ', excludes: [