From c2bac43c2126d90677525dffedf3479ebf7d1d30 Mon Sep 17 00:00:00 2001 From: Xinghua Cao Date: Thu, 5 Jan 2023 14:00:50 +0800 Subject: [PATCH 1/4] webgpu: support BroadcastArgs kernel --- .../src/broadcast_args_webgpu.ts | 69 +++++++++++++++++++ .../src/kernels/BroadcastArgs.ts | 44 ++++++++++++ .../src/register_all_kernels.ts | 2 + tfjs-backend-webgpu/src/setup_test.ts | 1 - 4 files changed, 115 insertions(+), 1 deletion(-) create mode 100644 tfjs-backend-webgpu/src/broadcast_args_webgpu.ts create mode 100644 tfjs-backend-webgpu/src/kernels/BroadcastArgs.ts diff --git a/tfjs-backend-webgpu/src/broadcast_args_webgpu.ts b/tfjs-backend-webgpu/src/broadcast_args_webgpu.ts new file mode 100644 index 00000000000..c91f5c4cc06 --- /dev/null +++ b/tfjs-backend-webgpu/src/broadcast_args_webgpu.ts @@ -0,0 +1,69 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program'; +import {computeDispatch, flatDispatchLayout} from './webgpu_util'; + +export class BroadcastArgsProgram implements WebGPUProgram { + outputShape: number[] = []; + shaderKey: string; + dispatchLayout: {x: number[]}; + dispatch: [number, number, number]; + variableNames = ['s0', 's1']; + uniforms = 's0Length : i32, s1Length : i32, '; + workgroupSize: [number, number, number] = [64, 1, 1]; + size = true; + + constructor(shape: number) { + this.outputShape = [shape]; + this.dispatchLayout = flatDispatchLayout(this.outputShape); + this.dispatch = computeDispatch( + this.dispatchLayout, this.outputShape, this.workgroupSize); + + this.shaderKey = 'broadcastArgs'; + } + + getUserCode(): string { + const userCode = ` + ${main('index')} { + if (index < uniforms.size) { + var s0 = 1.0; + var s1 = 1.0; + let indexS0 = index - uniforms.size + uniforms.s0Length; + let indexS1 = index - uniforms.size + uniforms.s1Length; + if (indexS0 >= 0) { + s0 = getS0(indexS0); + } + if (indexS1 >= 0) { + s1 = getS1(indexS1); + } + + if (s0 == 1.0) { + setOutputAtIndex(index, s1); + } else if (s1 == 1.0) { + setOutputAtIndex(index, s0); + } else if (s0 != s1) { + setOutputAtIndex(index, uniforms.NAN); + } else { + setOutputAtIndex(index, s0); + } + } + } + `; + return userCode; + } +} diff --git a/tfjs-backend-webgpu/src/kernels/BroadcastArgs.ts b/tfjs-backend-webgpu/src/kernels/BroadcastArgs.ts new file mode 100644 index 00000000000..79a8c240716 --- /dev/null +++ b/tfjs-backend-webgpu/src/kernels/BroadcastArgs.ts @@ -0,0 +1,44 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {BroadcastArgs, BroadcastArgsInputs, KernelConfig, KernelFunc, TensorInfo, util} from '@tensorflow/tfjs-core'; + +import {WebGPUBackend} from '../backend_webgpu'; +import {BroadcastArgsProgram} from '../broadcast_args_webgpu'; + +export function broadcastArgs(args: { + inputs: BroadcastArgsInputs, + backend: WebGPUBackend, +}): TensorInfo { + const {inputs, backend} = args; + const {s0, s1} = inputs; + + const s0Size = util.sizeFromShape(s0.shape); + const s1Size = util.sizeFromShape(s1.shape); + const outputSize = Math.max(s0Size, s1Size); + + const program = new BroadcastArgsProgram(outputSize); + const uniformData = + [{type: 'int32', data: [s0Size]}, {type: 'int32', data: [s1Size]}]; + return backend.runWebGPUProgram(program, [s0, s1], 'int32', uniformData); +} + +export const broadcastArgsConfig: KernelConfig = { + kernelName: BroadcastArgs, + backendName: 'webgpu', + kernelFunc: broadcastArgs as unknown as KernelFunc +}; diff --git a/tfjs-backend-webgpu/src/register_all_kernels.ts b/tfjs-backend-webgpu/src/register_all_kernels.ts index 7071f0c997f..ecaed0a109b 100644 --- a/tfjs-backend-webgpu/src/register_all_kernels.ts +++ b/tfjs-backend-webgpu/src/register_all_kernels.ts @@ -36,6 +36,7 @@ import {avgPoolGradConfig} from './kernels/AvgPoolGrad'; import {batchMatMulConfig} from './kernels/BatchMatMul'; import {batchToSpaceNDConfig} from './kernels/BatchToSpaceND'; import {bincountConfig} from './kernels/Bincount'; +import {broadcastArgsConfig} from './kernels/BroadcastArgs'; import {castConfig} from './kernels/Cast'; import {ceilConfig} from './kernels/Ceil'; import {clipByValueConfig} from './kernels/ClipByValue'; @@ -182,6 +183,7 @@ const kernelConfigs: KernelConfig[] = [ batchMatMulConfig, batchToSpaceNDConfig, bincountConfig, + broadcastArgsConfig, castConfig, ceilConfig, clipByValueConfig, diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts index 55b6a7386a8..7506c043191 100644 --- a/tfjs-backend-webgpu/src/setup_test.ts +++ b/tfjs-backend-webgpu/src/setup_test.ts @@ -240,7 +240,6 @@ const TEST_FILTERS: TestFilter[] = [ // Not implemented kernel list. 'avgPool3d ', 'avgPool3dBackprop ', - 'broadcastArgs ', 'conv2DBackpropFilter ', 'gradient with clones, input=2x2x1,d2=1,f=1,s=1,d=1,p=same', // Conv2DBackpropFilter 'conv1d gradients', // Conv2DBackpropFilter From 53d120c6a5fbfab85f203bdb0cf269585a5391dc Mon Sep 17 00:00:00 2001 From: Xinghua Cao Date: Thu, 5 Jan 2023 16:29:42 +0800 Subject: [PATCH 2/4] Address Yang's comment --- tfjs-backend-webgpu/src/broadcast_args_webgpu.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tfjs-backend-webgpu/src/broadcast_args_webgpu.ts b/tfjs-backend-webgpu/src/broadcast_args_webgpu.ts index c91f5c4cc06..1e0075437e8 100644 --- a/tfjs-backend-webgpu/src/broadcast_args_webgpu.ts +++ b/tfjs-backend-webgpu/src/broadcast_args_webgpu.ts @@ -24,7 +24,7 @@ export class BroadcastArgsProgram implements WebGPUProgram { dispatchLayout: {x: number[]}; dispatch: [number, number, number]; variableNames = ['s0', 's1']; - uniforms = 's0Length : i32, s1Length : i32, '; + uniforms = 's0Size : i32, s1Size : i32, '; workgroupSize: [number, number, number] = [64, 1, 1]; size = true; @@ -43,8 +43,8 @@ export class BroadcastArgsProgram implements WebGPUProgram { if (index < uniforms.size) { var s0 = 1.0; var s1 = 1.0; - let indexS0 = index - uniforms.size + uniforms.s0Length; - let indexS1 = index - uniforms.size + uniforms.s1Length; + let indexS0 = index - uniforms.size + uniforms.s0Size; + let indexS1 = index - uniforms.size + uniforms.s1Size; if (indexS0 >= 0) { s0 = getS0(indexS0); } From 54a4f4bd16976114d25256d280d8c2389ab375cb Mon Sep 17 00:00:00 2001 From: Xinghua Cao Date: Thu, 12 Jan 2023 14:54:18 +0800 Subject: [PATCH 3/4] Add a cpu path for this operator --- tfjs-backend-webgpu/src/kernels/BroadcastArgs.ts | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tfjs-backend-webgpu/src/kernels/BroadcastArgs.ts b/tfjs-backend-webgpu/src/kernels/BroadcastArgs.ts index 79a8c240716..c1ad3b34039 100644 --- a/tfjs-backend-webgpu/src/kernels/BroadcastArgs.ts +++ b/tfjs-backend-webgpu/src/kernels/BroadcastArgs.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {BroadcastArgs, BroadcastArgsInputs, KernelConfig, KernelFunc, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {backend_util, BroadcastArgs, BroadcastArgsInputs, KernelConfig, KernelFunc, TensorInfo, TypedArray, util} from '@tensorflow/tfjs-core'; import {WebGPUBackend} from '../backend_webgpu'; import {BroadcastArgsProgram} from '../broadcast_args_webgpu'; @@ -27,6 +27,17 @@ export function broadcastArgs(args: { const {inputs, backend} = args; const {s0, s1} = inputs; + if (backend.shouldExecuteOnCPU([s0, s1])) { + const s0BufferInfo = backend.tensorMap.get(s0.dataId); + const s1BufferInfo = backend.tensorMap.get(s1.dataId); + const s0Vals = s0BufferInfo.values as TypedArray; + const s1Vals = s1BufferInfo.values as TypedArray; + const broadcastShape = backend_util.assertAndGetBroadcastShape( + Array.from(s0Vals), Array.from(s1Vals)); + return backend.makeTensorInfo( + [broadcastShape.length], 'int32', Int32Array.from(broadcastShape)); + } + const s0Size = util.sizeFromShape(s0.shape); const s1Size = util.sizeFromShape(s1.shape); const outputSize = Math.max(s0Size, s1Size); From 38b8d17df1e39f08b924df307bc774dde037de33 Mon Sep 17 00:00:00 2001 From: Xinghua Cao Date: Mon, 16 Jan 2023 15:07:01 +0800 Subject: [PATCH 4/4] Address Jiajia's comments --- tfjs-backend-webgpu/src/kernels/BroadcastArgs.ts | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tfjs-backend-webgpu/src/kernels/BroadcastArgs.ts b/tfjs-backend-webgpu/src/kernels/BroadcastArgs.ts index c1ad3b34039..6340995c029 100644 --- a/tfjs-backend-webgpu/src/kernels/BroadcastArgs.ts +++ b/tfjs-backend-webgpu/src/kernels/BroadcastArgs.ts @@ -28,10 +28,10 @@ export function broadcastArgs(args: { const {s0, s1} = inputs; if (backend.shouldExecuteOnCPU([s0, s1])) { - const s0BufferInfo = backend.tensorMap.get(s0.dataId); - const s1BufferInfo = backend.tensorMap.get(s1.dataId); - const s0Vals = s0BufferInfo.values as TypedArray; - const s1Vals = s1BufferInfo.values as TypedArray; + const s0TensorInfo = backend.tensorMap.get(s0.dataId); + const s1TensorInfo = backend.tensorMap.get(s1.dataId); + const s0Vals = s0TensorInfo.values as TypedArray; + const s1Vals = s1TensorInfo.values as TypedArray; const broadcastShape = backend_util.assertAndGetBroadcastShape( Array.from(s0Vals), Array.from(s1Vals)); return backend.makeTensorInfo(