Skip to content

Commit

Permalink
webgpu: support linSpace operator
Browse files Browse the repository at this point in the history
  • Loading branch information
xhcao committed Nov 30, 2022
1 parent 1073979 commit cee2c04
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 1 deletion.
39 changes: 39 additions & 0 deletions tfjs-backend-webgpu/src/kernels/LinSpace.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/**
* @license
* Copyright 2022 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {KernelConfig, KernelFunc, LinSpace, LinSpaceAttrs, TensorInfo} from '@tensorflow/tfjs-core';

import {WebGPUBackend} from '../backend_webgpu';
import {LinSpaceProgram} from '../lin_space_webgpu';

export function linSpace(args: {backend: WebGPUBackend, attrs: LinSpaceAttrs}):
TensorInfo {
const {backend, attrs} = args;
const {start, stop, num} = attrs;
const step = (stop - start) / (num - 1);

const program = new LinSpaceProgram(num);
const uniformData =
[{type: 'float32', data: [start]}, {type: 'float32', data: [step]}];
return backend.runWebGPUProgram(program, [], 'float32', uniformData);
}

export const linSpaceConfig: KernelConfig = {
kernelName: LinSpace,
backendName: 'webgpu',
kernelFunc: linSpace as unknown as KernelFunc
};
50 changes: 50 additions & 0 deletions tfjs-backend-webgpu/src/lin_space_webgpu.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/**
* @license
* Copyright 2022 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program';
import {computeDispatch, flatDispatchLayout} from './webgpu_util';

export class LinSpaceProgram implements WebGPUProgram {
variableNames: string[] = [];
outputShape: number[] = [];
shaderKey: string;
dispatchLayout: {x: number[]};
dispatch: [number, number, number];
uniforms = 'start : f32, step : f32,';
workgroupSize: [number, number, number] = [64, 1, 1];
size = true;

constructor(shape: number) {
this.outputShape = [shape];
this.dispatchLayout = flatDispatchLayout(this.outputShape);
this.dispatch = computeDispatch(
this.dispatchLayout, this.outputShape, this.workgroupSize);

this.shaderKey = 'linSpace';
}

getUserCode(): string {
const userCode = `
${main('index')} {
if (index < uniforms.size) {
setOutputAtIndex(index, uniforms.start + f32(index) * uniforms.step);
}
}
`;
return userCode;
}
}
2 changes: 2 additions & 0 deletions tfjs-backend-webgpu/src/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ import {isNaNConfig} from './kernels/IsNaN';
import {leakyReluConfig} from './kernels/LeakyRelu';
import {lessConfig} from './kernels/Less';
import {lessEqualConfig} from './kernels/LessEqual';
import {linSpaceConfig} from './kernels/LinSpace';
import {logConfig} from './kernels/Log';
import {log1pConfig} from './kernels/Log1p';
import {logicalAndConfig} from './kernels/LogicalAnd';
Expand Down Expand Up @@ -212,6 +213,7 @@ const kernelConfigs: KernelConfig[] = [
leakyReluConfig,
lessConfig,
lessEqualConfig,
linSpaceConfig,
log1pConfig,
logConfig,
logicalAndConfig,
Expand Down
1 change: 0 additions & 1 deletion tfjs-backend-webgpu/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,6 @@ const TEST_FILTERS: TestFilter[] = [
'diag ',
'dilation2d ',
'encodeWeights ',
'linspace ',
'localResponseNormalization ',
'maxPool3d ',
'maxPool3dBackprop ',
Expand Down

0 comments on commit cee2c04

Please sign in to comment.