From cee2c0429ef494daa1d97fbb22c0052637654743 Mon Sep 17 00:00:00 2001 From: Xinghua Cao Date: Mon, 28 Nov 2022 13:43:55 +0800 Subject: [PATCH] webgpu: support linSpace operator --- tfjs-backend-webgpu/src/kernels/LinSpace.ts | 39 +++++++++++++++ tfjs-backend-webgpu/src/lin_space_webgpu.ts | 50 +++++++++++++++++++ .../src/register_all_kernels.ts | 2 + tfjs-backend-webgpu/src/setup_test.ts | 1 - 4 files changed, 91 insertions(+), 1 deletion(-) create mode 100644 tfjs-backend-webgpu/src/kernels/LinSpace.ts create mode 100644 tfjs-backend-webgpu/src/lin_space_webgpu.ts diff --git a/tfjs-backend-webgpu/src/kernels/LinSpace.ts b/tfjs-backend-webgpu/src/kernels/LinSpace.ts new file mode 100644 index 0000000000..15f4699a04 --- /dev/null +++ b/tfjs-backend-webgpu/src/kernels/LinSpace.ts @@ -0,0 +1,39 @@ +/** + * @license + * Copyright 2022 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, KernelFunc, LinSpace, LinSpaceAttrs, TensorInfo} from '@tensorflow/tfjs-core'; + +import {WebGPUBackend} from '../backend_webgpu'; +import {LinSpaceProgram} from '../lin_space_webgpu'; + +export function linSpace(args: {backend: WebGPUBackend, attrs: LinSpaceAttrs}): + TensorInfo { + const {backend, attrs} = args; + const {start, stop, num} = attrs; + const step = (stop - start) / (num - 1); + + const program = new LinSpaceProgram(num); + const uniformData = + [{type: 'float32', data: [start]}, {type: 'float32', data: [step]}]; + return backend.runWebGPUProgram(program, [], 'float32', uniformData); +} + +export const linSpaceConfig: KernelConfig = { + kernelName: LinSpace, + backendName: 'webgpu', + kernelFunc: linSpace as unknown as KernelFunc +}; diff --git a/tfjs-backend-webgpu/src/lin_space_webgpu.ts b/tfjs-backend-webgpu/src/lin_space_webgpu.ts new file mode 100644 index 0000000000..07668c2b43 --- /dev/null +++ b/tfjs-backend-webgpu/src/lin_space_webgpu.ts @@ -0,0 +1,50 @@ +/** + * @license + * Copyright 2022 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program'; +import {computeDispatch, flatDispatchLayout} from './webgpu_util'; + +export class LinSpaceProgram implements WebGPUProgram { + variableNames: string[] = []; + outputShape: number[] = []; + shaderKey: string; + dispatchLayout: {x: number[]}; + dispatch: [number, number, number]; + uniforms = 'start : f32, step : f32,'; + workgroupSize: [number, number, number] = [64, 1, 1]; + size = true; + + constructor(shape: number) { + this.outputShape = [shape]; + this.dispatchLayout = flatDispatchLayout(this.outputShape); + this.dispatch = computeDispatch( + this.dispatchLayout, this.outputShape, this.workgroupSize); + + this.shaderKey = 'linSpace'; + } + + getUserCode(): string { + const userCode = ` + ${main('index')} { + if (index < uniforms.size) { + setOutputAtIndex(index, uniforms.start + f32(index) * uniforms.step); + } + } + `; + return userCode; + } +} diff --git a/tfjs-backend-webgpu/src/register_all_kernels.ts b/tfjs-backend-webgpu/src/register_all_kernels.ts index 913db77d04..c1aee4c42b 100644 --- a/tfjs-backend-webgpu/src/register_all_kernels.ts +++ b/tfjs-backend-webgpu/src/register_all_kernels.ts @@ -79,6 +79,7 @@ import {isNaNConfig} from './kernels/IsNaN'; import {leakyReluConfig} from './kernels/LeakyRelu'; import {lessConfig} from './kernels/Less'; import {lessEqualConfig} from './kernels/LessEqual'; +import {linSpaceConfig} from './kernels/LinSpace'; import {logConfig} from './kernels/Log'; import {log1pConfig} from './kernels/Log1p'; import {logicalAndConfig} from './kernels/LogicalAnd'; @@ -212,6 +213,7 @@ const kernelConfigs: KernelConfig[] = [ leakyReluConfig, lessConfig, lessEqualConfig, + linSpaceConfig, log1pConfig, logConfig, logicalAndConfig, diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts index 8905a2c954..aa915e4b17 100644 --- a/tfjs-backend-webgpu/src/setup_test.ts +++ b/tfjs-backend-webgpu/src/setup_test.ts @@ -252,7 +252,6 @@ const TEST_FILTERS: TestFilter[] = [ 'diag ', 'dilation2d ', 'encodeWeights ', - 'linspace ', 'localResponseNormalization ', 'maxPool3d ', 'maxPool3dBackprop ',