diff --git a/tfjs-backend-webgpu/src/broadcast_args_webgpu.ts b/tfjs-backend-webgpu/src/broadcast_args_webgpu.ts new file mode 100644 index 00000000000..1e0075437e8 --- /dev/null +++ b/tfjs-backend-webgpu/src/broadcast_args_webgpu.ts @@ -0,0 +1,69 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program'; +import {computeDispatch, flatDispatchLayout} from './webgpu_util'; + +export class BroadcastArgsProgram implements WebGPUProgram { + outputShape: number[] = []; + shaderKey: string; + dispatchLayout: {x: number[]}; + dispatch: [number, number, number]; + variableNames = ['s0', 's1']; + uniforms = 's0Size : i32, s1Size : i32, '; + workgroupSize: [number, number, number] = [64, 1, 1]; + size = true; + + constructor(shape: number) { + this.outputShape = [shape]; + this.dispatchLayout = flatDispatchLayout(this.outputShape); + this.dispatch = computeDispatch( + this.dispatchLayout, this.outputShape, this.workgroupSize); + + this.shaderKey = 'broadcastArgs'; + } + + getUserCode(): string { + const userCode = ` + ${main('index')} { + if (index < uniforms.size) { + var s0 = 1.0; + var s1 = 1.0; + let indexS0 = index - uniforms.size + uniforms.s0Size; + let indexS1 = index - uniforms.size + uniforms.s1Size; + if (indexS0 >= 0) { + s0 = getS0(indexS0); + } + if (indexS1 >= 0) { + s1 = getS1(indexS1); + } + + if (s0 == 1.0) { + setOutputAtIndex(index, s1); + } else if (s1 == 1.0) { + setOutputAtIndex(index, s0); + } else if (s0 != s1) { + setOutputAtIndex(index, uniforms.NAN); + } else { + setOutputAtIndex(index, s0); + } + } + } + `; + return userCode; + } +} diff --git a/tfjs-backend-webgpu/src/kernels/BroadcastArgs.ts b/tfjs-backend-webgpu/src/kernels/BroadcastArgs.ts new file mode 100644 index 00000000000..6340995c029 --- /dev/null +++ b/tfjs-backend-webgpu/src/kernels/BroadcastArgs.ts @@ -0,0 +1,55 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util, BroadcastArgs, BroadcastArgsInputs, KernelConfig, KernelFunc, TensorInfo, TypedArray, util} from '@tensorflow/tfjs-core'; + +import {WebGPUBackend} from '../backend_webgpu'; +import {BroadcastArgsProgram} from '../broadcast_args_webgpu'; + +export function broadcastArgs(args: { + inputs: BroadcastArgsInputs, + backend: WebGPUBackend, +}): TensorInfo { + const {inputs, backend} = args; + const {s0, s1} = inputs; + + if (backend.shouldExecuteOnCPU([s0, s1])) { + const s0TensorInfo = backend.tensorMap.get(s0.dataId); + const s1TensorInfo = backend.tensorMap.get(s1.dataId); + const s0Vals = s0TensorInfo.values as TypedArray; + const s1Vals = s1TensorInfo.values as TypedArray; + const broadcastShape = backend_util.assertAndGetBroadcastShape( + Array.from(s0Vals), Array.from(s1Vals)); + return backend.makeTensorInfo( + [broadcastShape.length], 'int32', Int32Array.from(broadcastShape)); + } + + const s0Size = util.sizeFromShape(s0.shape); + const s1Size = util.sizeFromShape(s1.shape); + const outputSize = Math.max(s0Size, s1Size); + + const program = new BroadcastArgsProgram(outputSize); + const uniformData = + [{type: 'int32', data: [s0Size]}, {type: 'int32', data: [s1Size]}]; + return backend.runWebGPUProgram(program, [s0, s1], 'int32', uniformData); +} + +export const broadcastArgsConfig: KernelConfig = { + kernelName: BroadcastArgs, + backendName: 'webgpu', + kernelFunc: broadcastArgs as unknown as KernelFunc +}; diff --git a/tfjs-backend-webgpu/src/register_all_kernels.ts b/tfjs-backend-webgpu/src/register_all_kernels.ts index 014fd6a2180..8b5df62b6c3 100644 --- a/tfjs-backend-webgpu/src/register_all_kernels.ts +++ b/tfjs-backend-webgpu/src/register_all_kernels.ts @@ -36,6 +36,7 @@ import {avgPoolGradConfig} from './kernels/AvgPoolGrad'; import {batchMatMulConfig} from './kernels/BatchMatMul'; import {batchToSpaceNDConfig} from './kernels/BatchToSpaceND'; import {bincountConfig} from './kernels/Bincount'; +import {broadcastArgsConfig} from './kernels/BroadcastArgs'; import {castConfig} from './kernels/Cast'; import {ceilConfig} from './kernels/Ceil'; import {clipByValueConfig} from './kernels/ClipByValue'; @@ -183,6 +184,7 @@ const kernelConfigs: KernelConfig[] = [ batchMatMulConfig, batchToSpaceNDConfig, bincountConfig, + broadcastArgsConfig, castConfig, ceilConfig, clipByValueConfig, diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts index 4af0b7f1796..e36c9a107eb 100644 --- a/tfjs-backend-webgpu/src/setup_test.ts +++ b/tfjs-backend-webgpu/src/setup_test.ts @@ -240,7 +240,6 @@ const TEST_FILTERS: TestFilter[] = [ // Not implemented kernel list. 'avgPool3d ', 'avgPool3dBackprop ', - 'broadcastArgs ', 'conv2DBackpropFilter ', 'gradient with clones, input=2x2x1,d2=1,f=1,s=1,d=1,p=same', // Conv2DBackpropFilter 'conv1d gradients', // Conv2DBackpropFilter