63 changes: 63 additions & 0 deletions tfjs-backend-webgpu/src/backend_webgpu.ts
@@ -21,6 +21,10 @@ import './flags_webgpu';

import {backend_util, DataStorage, DataType, engine, env, findBackend, KernelBackend, Rank, RecursiveArray, ShapeMap, slice_util, Tensor, Tensor2D, Tensor3D, Tensor4D, TimingInfo, util} from '@tensorflow/tfjs-core';
import {Glslang} from '@webgpu/glslang/dist/web-devel/glslang.onefile';
// TODO(xing.xu): use FusedConv2DConfig from backend_util:
// https://github.com/tensorflow/tfjs/issues/2471
// tslint:disable-next-line: no-imports-from-dist
import {FusedConv2DConfig} from '@tensorflow/tfjs-core/dist/ops/fused_util';

import {BufferManager} from './buffer_manager';
import {ArgMinMaxProgram} from './kernels/argminmax_webgpu';
@@ -747,6 +751,65 @@ export class WebGPUBackend extends KernelBackend {
return this.compileAndRun(program, [x, filter], null, dimensions);
}

mapActivationToShaderProgram(
activation: backend_util.Activation, packed = false): string {
if (activation === 'linear') {
return unary_op.LINEAR;
} else if (activation === 'relu') {
return unary_op.RELU;
} else if (activation === 'elu') {
return unary_op.ELU;
} else if (activation === 'relu6') {
return unary_op.RELU6;
} else if (activation === 'prelu') {
return binary_op.PRELU;
}
throw new Error(`Activation ${
activation} has not been implemented for the WebGPU backend.`);
}

fusedConv2d(
{input, filter, convInfo, bias, activation, preluActivationWeights}:
FusedConv2DConfig): Tensor4D {
const dataId = this.write(null /*values*/, convInfo.outShape, input.dtype);
const output = engine().makeTensorFromDataId(
dataId, convInfo.outShape, input.dtype, this);

const hasBias = bias != null;
const hasPreluActivationWeights = preluActivationWeights != null;
const fusedActivation = activation ?
this.mapActivationToShaderProgram(activation, false) :
null;
let program: Conv2DMMProgram|Conv2DNaiveProgram;

const workPerThread = env().get('WEBGPU_CONV2D_WORK_PER_THREAD') as number;
if (workPerThread === -1) {
// TODO(kainino0x): This may be obsolete, but is kept for reference.
program = new Conv2DNaiveProgram(
convInfo, hasBias, fusedActivation, hasPreluActivationWeights);
} else {
program = new Conv2DMMProgram(
convInfo, workPerThread, hasBias, fusedActivation,
hasPreluActivationWeights);
}

const pad = convInfo.padInfo.type === 'VALID' ?
[0, 0] :
convInfo.padInfo.type === 'SAME' ?
[
-Math.floor((convInfo.filterShape[0] - 1) / 2),
-Math.floor((convInfo.filterShape[1] - 1) / 2)
] :
[convInfo.padInfo.top, convInfo.padInfo.left];

const dimensions = [
convInfo.filterHeight, convInfo.filterWidth, ...pad,
convInfo.strideHeight, convInfo.strideWidth
];

return this.compileAndRun(program, [input, filter], output, dimensions);
}

private argMinMaxReduce(x: Tensor, axis: number, reduceType: 'min'|'max'):
Tensor {
const program = new ArgMinMaxProgram(x.shape, axis, reduceType);
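For context, WebGPUBackend.fusedConv2d is not called by users directly; the FusedConv2DConfig is assembled by tfjs-core's fused ops. A minimal sketch of exercising this path from user code, assuming the experimental WebGPU backend has already been registered under the name 'webgpu' (registration is not part of this diff):

import * as tf from '@tensorflow/tfjs-core';

async function fusedConv2dExample() {
  await tf.setBackend('webgpu');  // assumes the WebGPU backend is registered
  await tf.ready();

  const x = tf.randomNormal([1, 8, 8, 4]) as tf.Tensor4D;        // NHWC input
  const filter = tf.randomNormal([3, 3, 4, 16]) as tf.Tensor4D;  // HWIO filter
  const bias = tf.zeros([16]);

  // With a bias and a 'relu' activation, the core fused op can hand the whole
  // config to the backend's fusedConv2d in a single dispatch.
  const y = tf.fused.conv2d(
      {x, filter, strides: 1, pad: 'same', bias, activation: 'relu'});
  y.print();
}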
34 changes: 33 additions & 1 deletion tfjs-backend-webgpu/src/kernels/conv2d_mm_webgpu.ts
@@ -33,7 +33,9 @@ export class Conv2DMMProgram implements WebGPUProgram {
uniforms = 'ivec2 filterDims, pad, stride, dilation;';
workGroupSize: [number, number, number];

constructor(convInfo: backend_util.Conv2DInfo, workPerThread: number) {
constructor(
convInfo: backend_util.Conv2DInfo, workPerThread: number, addBias = false,
activation: string = null, hasPreluActivationWeights = false) {
this.outputShape = convInfo.outShape;

util.assert(
@@ -75,7 +77,35 @@ export class Conv2DMMProgram implements WebGPUProgram {
this.dispatchLayout, this.outputShape, this.workGroupSize,
elementsPerThread);

let activationSnippet = '', applyActivationSnippet = '';
if (activation) {
if (hasPreluActivationWeights) {
activationSnippet = `float activation(float a) {
float b = getPreluActivationWeightsAtOutCoords();
${activation}
}`;
} else {
activationSnippet = `
float activation(float a) {
${activation}
}
`;
}

applyActivationSnippet = `value = activation(value);`;
}

const addBiasSnippet = addBias ? 'value += getBiasAtOutCoords();' : '';
if (addBias) {
this.variableNames.push('bias');
}

if (hasPreluActivationWeights) {
this.variableNames.push('preluActivationWeights');
}

this.userCode = `
${activationSnippet}
${matMulSource}

int batch;
@@ -115,6 +145,8 @@ export class Conv2DMMProgram implements WebGPUProgram {
col % outShape[2],
row);
if (coordsInBounds(outCoord, outShape)) {
${addBiasSnippet}
${applyActivationSnippet}
result[getFlatIndex(outCoord, outShape)] = value;
}
}
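To make the string splicing above concrete, this is roughly what the constructor produces for activation = unary_op.RELU with addBias = true and no PRELU weights (an illustrative expansion, not code from the diff):

// Illustrative expansion only.
const activationSnippet = `
  float activation(float a) {
    return max(a, 0.0);
  }
`;
const addBiasSnippet = 'value += getBiasAtOutCoords();';
const applyActivationSnippet = 'value = activation(value);';
// Inside the bounds check of the matmul write-back, the generated shader then
// runs, in order:
//   value += getBiasAtOutCoords();
//   value = activation(value);
//   result[getFlatIndex(outCoord, outShape)] = value;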
33 changes: 32 additions & 1 deletion tfjs-backend-webgpu/src/kernels/conv2d_naive_webgpu.ts
@@ -31,7 +31,9 @@ export class Conv2DNaiveProgram implements WebGPUProgram {
uniforms = 'ivec2 filterDims, pad, stride, dilation;';
workGroupSize: [number, number, number] = [4, 8, 4];

constructor(convInfo: backend_util.Conv2DInfo) {
constructor(
convInfo: backend_util.Conv2DInfo, addBias = false,
activation: string = null, hasPreluActivationWeights = false) {
this.outputShape = convInfo.outShape;
this.dispatchLayout = {x: [2], y: [1], z: [0, 3]};
this.dispatch = computeDispatch(
@@ -40,8 +42,35 @@
util.assert(
convInfo.dataFormat === 'channelsLast',
() => 'TODO: NCHW is unimplemented');
let activationSnippet = '', applyActivationSnippet = '';
if (activation) {
if (hasPreluActivationWeights) {
activationSnippet = `float activation(float a) {
float b = getPreluActivationWeightsAtOutCoords();
${activation}
}`;
} else {
activationSnippet = `
float activation(float a) {
${activation}
}
`;
}

applyActivationSnippet = `value = activation(value);`;
}

const addBiasSnippet = addBias ? 'value += getBiasAtOutCoords();' : '';
if (addBias) {
this.variableNames.push('bias');
}

if (hasPreluActivationWeights) {
this.variableNames.push('preluActivationWeights');
}

this.userCode = `
${activationSnippet}
float readInp(int batch, int row, int col, int chan) {
ivec4 coord = ivec4(batch, row, col, chan);
return coordsInBounds(coord, xShape) ?
@@ -57,6 +86,8 @@
void writeResult(int batch, int row, int col, int chan, float value) {
ivec4 coord = ivec4(batch, row, col, chan);
if (coordsInBounds(coord, outShape)) {
${addBiasSnippet}
${applyActivationSnippet}
setOutput(batch, row, col, chan, value);
}
}
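In the naive program the same snippets land inside writeResult rather than the matmul write-back; with addBias = true and a fused activation, the generated GLSL reads roughly as follows (an illustrative expansion, not code from the diff):

// Illustrative expansion only.
const generatedWriteResult = `
  void writeResult(int batch, int row, int col, int chan, float value) {
    ivec4 coord = ivec4(batch, row, col, chan);
    if (coordsInBounds(coord, outShape)) {
      value += getBiasAtOutCoords();
      value = activation(value);
      setOutput(batch, row, col, chan, value);
    }
  }
`;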
2 changes: 2 additions & 0 deletions tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts
@@ -23,6 +23,8 @@ import {WebGPUProgram} from './webgpu_program';

export const RELU = 'return max(a, 0.0);';
export const RELU6 = 'return (a < 0.0) ? 0.0 : min(6.0, a);';
export const LINEAR = `return a;`;
export const ELU = `return (a >= 0.0) ? a : (exp(a) - 1.0);`;

export const SIGMOID = `return 1.0 / (1.0 + exp(-1.0 * a));`;

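The 'prelu' case in mapActivationToShaderProgram comes from binary_op.PRELU, which is not shown in this diff. Assuming it follows the same convention as the unary snippets above (operands named a and b), the PRELU branch of the fused programs would expand to something like:

// Assumed body of binary_op.PRELU (not part of this diff).
const PRELU = `return (a < 0.0) ? b * a : a;`;

// With hasPreluActivationWeights = true, the fused conv programs then emit:
const activationSnippet = `float activation(float a) {
  float b = getPreluActivationWeightsAtOutCoords();
  ${PRELU}
}`;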
16 changes: 15 additions & 1 deletion tfjs-backend-webgpu/src/setup_test.ts
@@ -181,6 +181,21 @@ const TEST_FILTERS: TestFilter[] = [
]
},
{include: 'floor divide ', excludes: []},
{
include: 'fused',
excludes: [
'A x B', // fusedBatchMatMul not yet implemented.
'A x B with elu', // elu not yet implemented.
'A x B with elu and broadcasted bias', // elu not yet implemented.
'A x B with bias only', // fusedBatchMatMul not yet implemented.
'basic with elu', // elu not yet implemented.
'gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0', // conv2dDerInput not yet
// implemented.
'gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0 with bias', // conv2dDerInput
// not yet
// implemented.
]
},
{
include: 'maxPool',
excludes: [
@@ -257,7 +272,6 @@ const TEST_FILTERS: TestFilter[] = [
excludes: [
'NCHW', // Not yet implemented.
'gradient', // 'conv2dDerInput' not yet implemented
'fused', // Not yet implemented.
'conv2dTranspose', // DerInput is not Implemented.
]
},
2 changes: 1 addition & 1 deletion tfjs-core/src/backends/backend_util.ts
@@ -29,7 +29,7 @@ export * from '../ops/axis_util';
export * from '../ops/broadcast_util';
export * from '../ops/concat_util';
export * from '../ops/conv_util';
export {Activation} from '../ops/fused_util';
export {Activation, FusedConv2DConfig} from '../ops/fused_util';

export {BackendValues, TypedArray, upcastType, PixelData} from '../types';
export {MemoryInfo, TimingInfo} from '../engine';
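For reference, the FusedConv2DConfig re-exported here has roughly the following shape, inferred from the destructuring in WebGPUBackend.fusedConv2d above; the authoritative definition lives in tfjs-core's fused_util.ts:

import {backend_util, Tensor, Tensor4D} from '@tensorflow/tfjs-core';

// Approximate sketch only; field names inferred from this PR's usage.
interface FusedConv2DConfigSketch {
  input: Tensor4D;
  filter: Tensor4D;
  convInfo: backend_util.Conv2DInfo;
  bias?: Tensor;
  activation?: backend_util.Activation;
  preluActivationWeights?: Tensor;
}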