Skip to content

Commit 795e771

Browse files
committed
webgpu: support BroadcastArgs kernel
1 parent f14a50e commit 795e771

File tree

4 files changed

+115
-1
lines changed

4 files changed

+115
-1
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/**
2+
* @license
3+
* Copyright 2023 Google LLC.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program';
19+
import {computeDispatch, flatDispatchLayout} from './webgpu_util';
20+
21+
export class BroadcastArgsProgram implements WebGPUProgram {
22+
outputShape: number[] = [];
23+
shaderKey: string;
24+
dispatchLayout: {x: number[]};
25+
dispatch: [number, number, number];
26+
variableNames = ['s0', 's1'];
27+
uniforms = 's0Length : i32, s1Length : i32, ';
28+
workgroupSize: [number, number, number] = [64, 1, 1];
29+
size = true;
30+
31+
constructor(shape: number) {
32+
this.outputShape = [shape];
33+
this.dispatchLayout = flatDispatchLayout(this.outputShape);
34+
this.dispatch = computeDispatch(
35+
this.dispatchLayout, this.outputShape, this.workgroupSize);
36+
37+
this.shaderKey = 'broadcastArgs';
38+
}
39+
40+
getUserCode(): string {
41+
const userCode = `
42+
${main('index')} {
43+
if (index < uniforms.size) {
44+
var s0 = 1.0;
45+
var s1 = 1.0;
46+
let indexS0 = index - uniforms.size + uniforms.s0Length;
47+
let indexS1 = index - uniforms.size + uniforms.s1Length;
48+
if (indexS0 >= 0) {
49+
s0 = getS0(indexS0);
50+
}
51+
if (indexS1 >= 0) {
52+
s1 = getS1(indexS1);
53+
}
54+
55+
if (s0 == 1.0) {
56+
setOutputAtIndex(index, s1);
57+
} else if (s1 == 1.0) {
58+
setOutputAtIndex(index, s0);
59+
} else if (s0 != s1) {
60+
setOutputAtIndex(index, uniforms.NAN);
61+
} else {
62+
setOutputAtIndex(index, s0);
63+
}
64+
}
65+
}
66+
`;
67+
return userCode;
68+
}
69+
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/**
2+
* @license
3+
* Copyright 2023 Google LLC.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {BroadcastArgs, BroadcastArgsInputs, KernelConfig, KernelFunc, TensorInfo, util} from '@tensorflow/tfjs-core';
19+
20+
import {WebGPUBackend} from '../backend_webgpu';
21+
import {BroadcastArgsProgram} from '../broadcast_args_webgpu';
22+
23+
export function broadcastArgs(args: {
24+
inputs: BroadcastArgsInputs,
25+
backend: WebGPUBackend,
26+
}): TensorInfo {
27+
const {inputs, backend} = args;
28+
const {s0, s1} = inputs;
29+
30+
const s0Size = util.sizeFromShape(s0.shape);
31+
const s1Size = util.sizeFromShape(s1.shape);
32+
const outputSize = Math.max(s0Size, s1Size);
33+
34+
const program = new BroadcastArgsProgram(outputSize);
35+
const uniformData =
36+
[{type: 'int32', data: [s0Size]}, {type: 'int32', data: [s1Size]}];
37+
return backend.runWebGPUProgram(program, [s0, s1], 'int32', uniformData);
38+
}
39+
40+
export const broadcastArgsConfig: KernelConfig = {
41+
kernelName: BroadcastArgs,
42+
backendName: 'webgpu',
43+
kernelFunc: broadcastArgs as unknown as KernelFunc
44+
};

tfjs-backend-webgpu/src/register_all_kernels.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import {avgPoolGradConfig} from './kernels/AvgPoolGrad';
3636
import {batchMatMulConfig} from './kernels/BatchMatMul';
3737
import {batchToSpaceNDConfig} from './kernels/BatchToSpaceND';
3838
import {bincountConfig} from './kernels/Bincount';
39+
import {broadcastArgsConfig} from './kernels/BroadcastArgs';
3940
import {castConfig} from './kernels/Cast';
4041
import {ceilConfig} from './kernels/Ceil';
4142
import {clipByValueConfig} from './kernels/ClipByValue';
@@ -177,6 +178,7 @@ const kernelConfigs: KernelConfig[] = [
177178
batchMatMulConfig,
178179
batchToSpaceNDConfig,
179180
bincountConfig,
181+
broadcastArgsConfig,
180182
castConfig,
181183
ceilConfig,
182184
clipByValueConfig,

tfjs-backend-webgpu/src/setup_test.ts

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,6 @@ const TEST_FILTERS: TestFilter[] = [
247247
// Not implemented kernel list.
248248
'avgPool3d ',
249249
'avgPool3dBackprop ',
250-
'broadcastArgs ',
251250
'conv2DBackpropFilter ',
252251
'gradient with clones, input=2x2x1,d2=1,f=1,s=1,d=1,p=same', // Conv2DBackpropFilter
253252
'conv1d gradients', // Conv2DBackpropFilter

0 commit comments

Comments
 (0)