From dcbcfaa2d488b632d0fa7cba70c4b44b4dd05988 Mon Sep 17 00:00:00 2001 From: Xinghua Cao Date: Thu, 22 Dec 2022 15:08:00 +0800 Subject: [PATCH 1/3] webgpu: support EluGrad kernel --- tfjs-backend-webgpu/src/binary_op_util.ts | 11 ++++++ tfjs-backend-webgpu/src/kernels/EluGrad.ts | 38 +++++++++++++++++++ .../src/register_all_kernels.ts | 2 + tfjs-backend-webgpu/src/setup_test.ts | 7 ---- 4 files changed, 51 insertions(+), 7 deletions(-) create mode 100644 tfjs-backend-webgpu/src/kernels/EluGrad.ts diff --git a/tfjs-backend-webgpu/src/binary_op_util.ts b/tfjs-backend-webgpu/src/binary_op_util.ts index bfdde48a331..631dbd2310b 100644 --- a/tfjs-backend-webgpu/src/binary_op_util.ts +++ b/tfjs-backend-webgpu/src/binary_op_util.ts @@ -21,6 +21,7 @@ export enum BinaryOpType { COMPLEX_MULTIPLY_IMAG, COMPLEX_MULTIPLY_REAL, DIV, + ELU_DER, EQUAL, GREATER, GREATER_EQUAL, @@ -59,6 +60,14 @@ const ADD = 'return a + b;'; const COMPLEX_MULTIPLY_REAL = 'return areal * breal - aimag * bimag;'; const COMPLEX_MULTIPLY_IMAG = 'return areal * bimag + aimag * breal;'; const DIV = 'return a / b;'; +const ELU_DER = ` + let bGTEZero = f32(b > 0.); + return (bGTEZero * a) + (1.0 - bGTEZero) * (a * (b + 1.0)); +`; +const ELU_DER_VEC4 = ` + let bGTEZero = vec4(b >= vec4(0.)); + return (bGTEZero * a) + (vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0))); +`; const EQUAL = 'return f32(a == b);'; const EQUAL_VEC4 = 'return vec4(a == b);'; const GREATER = 'return f32(a > b);'; @@ -223,6 +232,8 @@ export function getBinaryOpString( return COMPLEX_MULTIPLY_REAL; case BinaryOpType.DIV: return DIV; + case BinaryOpType.ELU_DER: + return useVec4 ? ELU_DER_VEC4 : ELU_DER; case BinaryOpType.EQUAL: return useVec4 ? EQUAL_VEC4 : EQUAL; case BinaryOpType.GREATER: diff --git a/tfjs-backend-webgpu/src/kernels/EluGrad.ts b/tfjs-backend-webgpu/src/kernels/EluGrad.ts new file mode 100644 index 00000000000..a9c72429ba9 --- /dev/null +++ b/tfjs-backend-webgpu/src/kernels/EluGrad.ts @@ -0,0 +1,38 @@ +/** + * @license + * Copyright 2022 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {EluGrad, EluGradInputs, KernelConfig, KernelFunc, TensorInfo} from '@tensorflow/tfjs-core'; + +import {WebGPUBackend} from '../backend_webgpu'; +import {BinaryOpType} from '../binary_op_util'; +import {BinaryOpProgram} from '../binary_op_webgpu'; + +export const eluGrad = + (args: {inputs: EluGradInputs, backend: WebGPUBackend}): TensorInfo => { + const {inputs, backend} = args; + const {dy, y} = inputs; + + const program = + new BinaryOpProgram(BinaryOpType.ELU_DER, dy.shape, y.shape); + return backend.runWebGPUProgram(program, [dy, y], dy.dtype); + }; + +export const eluGradConfig: KernelConfig = { + kernelName: EluGrad, + backendName: 'webgpu', + kernelFunc: eluGrad as unknown as KernelFunc +}; diff --git a/tfjs-backend-webgpu/src/register_all_kernels.ts b/tfjs-backend-webgpu/src/register_all_kernels.ts index 91fce062269..938b53f54d4 100644 --- a/tfjs-backend-webgpu/src/register_all_kernels.ts +++ b/tfjs-backend-webgpu/src/register_all_kernels.ts @@ -58,6 +58,7 @@ import {diagConfig} from './kernels/Diag'; import {dilation2DConfig} from './kernels/Dilation2D'; import {einsumConfig} from './kernels/Einsum'; import {eluConfig} from './kernels/Elu'; +import {eluGradConfig} from './kernels/EluGrad'; import {equalConfig} from './kernels/Equal'; import {erfConfig} from './kernels/Erf'; import {expConfig} from './kernels/Exp'; @@ -201,6 +202,7 @@ const kernelConfigs: KernelConfig[] = [ dilation2DConfig, einsumConfig, eluConfig, + eluGradConfig, equalConfig, erfConfig, expConfig, diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts index a88393ab4a6..01c07652bb1 100644 --- a/tfjs-backend-webgpu/src/setup_test.ts +++ b/tfjs-backend-webgpu/src/setup_test.ts @@ -81,13 +81,6 @@ const TEST_FILTERS: TestFilter[] = [ 'gradient' // gradient function not found. ] }, - { - startsWith: 'elu ', - excludes: [ - 'derivative', // gradient function not found. - 'gradient' // gradient function not found. - ] - }, { startsWith: 'exp ', excludes: [ From 783269b433e6815f43ffb80bab750bd87e810387 Mon Sep 17 00:00:00 2001 From: Xinghua Cao Date: Fri, 23 Dec 2022 10:06:58 +0800 Subject: [PATCH 2/3] Address Yang's comments --- tfjs-backend-webgpu/src/binary_op_util.ts | 11 +--- tfjs-backend-webgpu/src/kernels/EluGrad.ts | 76 +++++++++++----------- 2 files changed, 41 insertions(+), 46 deletions(-) diff --git a/tfjs-backend-webgpu/src/binary_op_util.ts b/tfjs-backend-webgpu/src/binary_op_util.ts index 631dbd2310b..f6ca1048231 100644 --- a/tfjs-backend-webgpu/src/binary_op_util.ts +++ b/tfjs-backend-webgpu/src/binary_op_util.ts @@ -60,14 +60,9 @@ const ADD = 'return a + b;'; const COMPLEX_MULTIPLY_REAL = 'return areal * breal - aimag * bimag;'; const COMPLEX_MULTIPLY_IMAG = 'return areal * bimag + aimag * breal;'; const DIV = 'return a / b;'; -const ELU_DER = ` - let bGTEZero = f32(b > 0.); - return (bGTEZero * a) + (1.0 - bGTEZero) * (a * (b + 1.0)); -`; -const ELU_DER_VEC4 = ` - let bGTEZero = vec4(b >= vec4(0.)); - return (bGTEZero * a) + (vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0))); -`; +const ELU_DER = 'return select(a * (b + 1.0), a, b >= 0.);'; +const ELU_DER_VEC4 = + 'return select(a * (b + vec4(1.0)), a, b >= vec4(0.));'; const EQUAL = 'return f32(a == b);'; const EQUAL_VEC4 = 'return vec4(a == b);'; const GREATER = 'return f32(a > b);'; diff --git a/tfjs-backend-webgpu/src/kernels/EluGrad.ts b/tfjs-backend-webgpu/src/kernels/EluGrad.ts index a9c72429ba9..53a1152e969 100644 --- a/tfjs-backend-webgpu/src/kernels/EluGrad.ts +++ b/tfjs-backend-webgpu/src/kernels/EluGrad.ts @@ -1,38 +1,38 @@ -/** - * @license - * Copyright 2022 Google LLC. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {EluGrad, EluGradInputs, KernelConfig, KernelFunc, TensorInfo} from '@tensorflow/tfjs-core'; - -import {WebGPUBackend} from '../backend_webgpu'; -import {BinaryOpType} from '../binary_op_util'; -import {BinaryOpProgram} from '../binary_op_webgpu'; - -export const eluGrad = - (args: {inputs: EluGradInputs, backend: WebGPUBackend}): TensorInfo => { - const {inputs, backend} = args; - const {dy, y} = inputs; - - const program = - new BinaryOpProgram(BinaryOpType.ELU_DER, dy.shape, y.shape); - return backend.runWebGPUProgram(program, [dy, y], dy.dtype); - }; - -export const eluGradConfig: KernelConfig = { - kernelName: EluGrad, - backendName: 'webgpu', - kernelFunc: eluGrad as unknown as KernelFunc -}; +/** + * @license + * Copyright 2022 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {EluGrad, EluGradInputs, KernelConfig, KernelFunc, TensorInfo} from '@tensorflow/tfjs-core'; + +import {WebGPUBackend} from '../backend_webgpu'; +import {BinaryOpType} from '../binary_op_util'; +import {BinaryOpProgram} from '../binary_op_webgpu'; + +export const eluGrad = + (args: {inputs: EluGradInputs, backend: WebGPUBackend}): TensorInfo => { + const {inputs, backend} = args; + const {dy, y} = inputs; + + const program = + new BinaryOpProgram(BinaryOpType.ELU_DER, dy.shape, y.shape); + return backend.runWebGPUProgram(program, [dy, y], dy.dtype); + }; + +export const eluGradConfig: KernelConfig = { + kernelName: EluGrad, + backendName: 'webgpu', + kernelFunc: eluGrad as unknown as KernelFunc +}; From 1c404ba9d47bb6475f28f68b302c93af7a71bdf5 Mon Sep 17 00:00:00 2001 From: Xinghua Cao Date: Mon, 9 Jan 2023 10:53:27 +0800 Subject: [PATCH 3/3] Address Jiajia's comment --- tfjs-backend-webgpu/src/kernels/EluGrad.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfjs-backend-webgpu/src/kernels/EluGrad.ts b/tfjs-backend-webgpu/src/kernels/EluGrad.ts index 53a1152e969..ed5b387b5a9 100644 --- a/tfjs-backend-webgpu/src/kernels/EluGrad.ts +++ b/tfjs-backend-webgpu/src/kernels/EluGrad.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2022 Google LLC. + * Copyright 2023 Google LLC. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at