diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts
index fedf6fa82ba..eff9036ea1f 100644
--- a/tfjs-backend-webgpu/src/backend_webgpu.ts
+++ b/tfjs-backend-webgpu/src/backend_webgpu.ts
@@ -21,6 +21,10 @@ import './flags_webgpu';
 
 import {backend_util, DataStorage, DataType, engine, env, findBackend, KernelBackend, Rank, RecursiveArray, ShapeMap, slice_util, Tensor, Tensor2D, Tensor3D, Tensor4D, TimingInfo, util} from '@tensorflow/tfjs-core';
 import {Glslang} from '@webgpu/glslang/dist/web-devel/glslang.onefile';
 
+// TODO(xing.xu): use FusedConv2DConfig from backend_util:
+// https://github.com/tensorflow/tfjs/issues/2471
+// tslint:disable-next-line: no-imports-from-dist
+import {FusedConv2DConfig} from '@tensorflow/tfjs-core/dist/ops/fused_util';
 import {BufferManager} from './buffer_manager';
 import {ArgMinMaxProgram} from './kernels/argminmax_webgpu';
@@ -747,6 +751,65 @@ export class WebGPUBackend extends KernelBackend {
     return this.compileAndRun(program, [x, filter], null, dimensions);
   }
 
+  mapActivationToShaderProgram(
+      activation: backend_util.Activation, packed = false): string {
+    if (activation === 'linear') {
+      return unary_op.LINEAR;
+    } else if (activation === 'relu') {
+      return unary_op.RELU;
+    } else if (activation === 'elu') {
+      return unary_op.ELU;
+    } else if (activation === 'relu6') {
+      return unary_op.RELU6;
+    } else if (activation === 'prelu') {
+      return binary_op.PRELU;
+    }
+    throw new Error(`Activation ${
+        activation} has not been implemented for the WebGPU backend.`);
+  }
+
+  fusedConv2d(
+      {input, filter, convInfo, bias, activation, preluActivationWeights}:
+          FusedConv2DConfig): Tensor4D {
+    const dataId = this.write(null /*values*/, convInfo.outShape, input.dtype);
+    const output = engine().makeTensorFromDataId(
+        dataId, convInfo.outShape, input.dtype, this);
+
+    const hasBias = bias != null;
+    const hasPreluActivationWeights = preluActivationWeights != null;
+    const fusedActivation = activation ?
+        this.mapActivationToShaderProgram(activation, false) :
+        null;
+    let program: Conv2DMMProgram|Conv2DNaiveProgram;
+
+    const workPerThread = env().get('WEBGPU_CONV2D_WORK_PER_THREAD') as number;
+    if (workPerThread === -1) {
+      // TODO(kainino0x): This may be obsolete, but is kept for reference.
+      program = new Conv2DNaiveProgram(
+          convInfo, hasBias, fusedActivation, hasPreluActivationWeights);
+    } else {
+      program = new Conv2DMMProgram(
+          convInfo, workPerThread, hasBias, fusedActivation,
+          hasPreluActivationWeights);
+    }
+
+    const pad = convInfo.padInfo.type === 'VALID' ?
+        [0, 0] :
+        convInfo.padInfo.type === 'SAME' ?
+        [
+          -Math.floor((convInfo.filterShape[0] - 1) / 2),
+          -Math.floor((convInfo.filterShape[1] - 1) / 2)
+        ] :
+        [convInfo.padInfo.top, convInfo.padInfo.left];
+
+    const dimensions = [
+      convInfo.filterHeight, convInfo.filterWidth, ...pad,
+      convInfo.strideHeight, convInfo.strideWidth
+    ];
+
+    return this.compileAndRun(program, [input, filter], output, dimensions);
+  }
+
   private argMinMaxReduce(x: Tensor, axis: number, reduceType: 'min'|'max'):
       Tensor {
     const program = new ArgMinMaxProgram(x.shape, axis, reduceType);
diff --git a/tfjs-backend-webgpu/src/kernels/conv2d_mm_webgpu.ts b/tfjs-backend-webgpu/src/kernels/conv2d_mm_webgpu.ts
index 673a337d532..0bf263aaf14 100644
--- a/tfjs-backend-webgpu/src/kernels/conv2d_mm_webgpu.ts
+++ b/tfjs-backend-webgpu/src/kernels/conv2d_mm_webgpu.ts
@@ -33,7 +33,9 @@ export class Conv2DMMProgram implements WebGPUProgram {
   uniforms = 'ivec2 filterDims, pad, stride, dilation;';
   workGroupSize: [number, number, number];
 
-  constructor(convInfo: backend_util.Conv2DInfo, workPerThread: number) {
+  constructor(
+      convInfo: backend_util.Conv2DInfo, workPerThread: number, addBias = false,
+      activation: string = null, hasPreluActivationWeights = false) {
     this.outputShape = convInfo.outShape;
 
     util.assert(
@@ -75,7 +77,35 @@ export class Conv2DMMProgram implements WebGPUProgram {
         this.dispatchLayout, this.outputShape, this.workGroupSize,
         elementsPerThread);
 
+    let activationSnippet = '', applyActivationSnippet = '';
+    if (activation) {
+      if (hasPreluActivationWeights) {
+        activationSnippet = `float activation(float a) {
+          float b = getPreluActivationWeightsAtOutCoords();
+          ${activation}
+        }`;
+      } else {
+        activationSnippet = `
+          float activation(float a) {
+            ${activation}
+          }
+        `;
+      }
+
+      applyActivationSnippet = `value = activation(value);`;
+    }
+
+    const addBiasSnippet = addBias ? 'value += getBiasAtOutCoords();' : '';
+    if (addBias) {
+      this.variableNames.push('bias');
+    }
+
+    if (hasPreluActivationWeights) {
+      this.variableNames.push('preluActivationWeights');
+    }
+
     this.userCode = `
+      ${activationSnippet}
       ${matMulSource}
 
       int batch;
@@ -115,6 +145,8 @@ export class Conv2DMMProgram implements WebGPUProgram {
             col % outShape[2],
             row);
         if (coordsInBounds(outCoord, outShape)) {
+          ${addBiasSnippet}
+          ${applyActivationSnippet}
           result[getFlatIndex(outCoord, outShape)] = value;
         }
       }
diff --git a/tfjs-backend-webgpu/src/kernels/conv2d_naive_webgpu.ts b/tfjs-backend-webgpu/src/kernels/conv2d_naive_webgpu.ts
index 0d8899ae73f..6da178256ea 100644
--- a/tfjs-backend-webgpu/src/kernels/conv2d_naive_webgpu.ts
+++ b/tfjs-backend-webgpu/src/kernels/conv2d_naive_webgpu.ts
@@ -31,7 +31,9 @@ export class Conv2DNaiveProgram implements WebGPUProgram {
   uniforms = 'ivec2 filterDims, pad, stride, dilation;';
   workGroupSize: [number, number, number] = [4, 8, 4];
 
-  constructor(convInfo: backend_util.Conv2DInfo) {
+  constructor(
+      convInfo: backend_util.Conv2DInfo, addBias = false,
+      activation: string = null, hasPreluActivationWeights = false) {
     this.outputShape = convInfo.outShape;
     this.dispatchLayout = {x: [2], y: [1], z: [0, 3]};
     this.dispatch = computeDispatch(
@@ -40,8 +42,35 @@ export class Conv2DNaiveProgram implements WebGPUProgram {
 
     util.assert(
         convInfo.dataFormat === 'channelsLast',
         () => 'TODO: NCHW is unimplemented');
+    let activationSnippet = '', applyActivationSnippet = '';
+    if (activation) {
+      if (hasPreluActivationWeights) {
+        activationSnippet = `float activation(float a) {
+          float b = getPreluActivationWeightsAtOutCoords();
+          ${activation}
+        }`;
+      } else {
+        activationSnippet = `
+          float activation(float a) {
+            ${activation}
+          }
+        `;
+      }
+
+      applyActivationSnippet = `value = activation(value);`;
+    }
+
+    const addBiasSnippet = addBias ? 'value += getBiasAtOutCoords();' : '';
+    if (addBias) {
+      this.variableNames.push('bias');
+    }
+
+    if (hasPreluActivationWeights) {
+      this.variableNames.push('preluActivationWeights');
+    }
     this.userCode = `
+      ${activationSnippet}
       float readInp(int batch, int row, int col, int chan) {
         ivec4 coord = ivec4(batch, row, col, chan);
         return coordsInBounds(coord, xShape) ?
@@ -57,6 +86,8 @@ export class Conv2DNaiveProgram implements WebGPUProgram {
       void writeResult(int batch, int row, int col, int chan, float value) {
         ivec4 coord = ivec4(batch, row, col, chan);
         if (coordsInBounds(coord, outShape)) {
+          ${addBiasSnippet}
+          ${applyActivationSnippet}
           setOutput(batch, row, col, chan, value);
         }
       }
diff --git a/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts b/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts
index 32197797cf7..90c0fe3d3a4 100644
--- a/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts
+++ b/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts
@@ -23,6 +23,8 @@ import {WebGPUProgram} from './webgpu_program';
 
 export const RELU = 'return max(a, 0.0);';
 export const RELU6 = 'return (a < 0.0) ? 0.0 : min(6.0, a);';
+export const LINEAR = `return a;`;
+export const ELU = `return (a >= 0.0) ? a : (exp(a) - 1.0);`;
 
 export const SIGMOID = `return 1.0 / (1.0 + exp(-1.0 * a));`;
 
diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts
index 552108c3611..e47fd960f7f 100644
--- a/tfjs-backend-webgpu/src/setup_test.ts
+++ b/tfjs-backend-webgpu/src/setup_test.ts
@@ -181,6 +181,21 @@ const TEST_FILTERS: TestFilter[] = [
     ]
   },
   {include: 'floor divide ', excludes: []},
+  {
+    include: 'fused',
+    excludes: [
+      'A x B',                                // fusedBatchMatMul not yet implemented.
+      'A x B with elu',                       // elu not yet implemented.
+      'A x B with elu and broadcasted bias',  // elu not yet implemented.
+      'A x B with bias only',                 // fusedBatchMatMul not yet implemented.
+      'basic with elu',                       // elu not yet implemented.
+      'gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0',  // conv2dDerInput not yet
+                                                   // implemented.
+      'gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0 with bias',  // conv2dDerInput
+                                                             // not yet
+                                                             // implemented.
+    ]
+  },
   {
     include: 'maxPool',
     excludes: [
@@ -257,7 +272,6 @@ const TEST_FILTERS: TestFilter[] = [
     excludes: [
       'NCHW',              // Not yet implemented.
       'gradient',          // 'conv2dDerInput' not yet implemented
-      'fused',             // Not yet implemented.
       'conv2dTranspose',   // DerInput is not Implemented.
     ]
   },
diff --git a/tfjs-core/src/backends/backend_util.ts b/tfjs-core/src/backends/backend_util.ts
index be150e53e7c..7bf8f9eed64 100644
--- a/tfjs-core/src/backends/backend_util.ts
+++ b/tfjs-core/src/backends/backend_util.ts
@@ -29,7 +29,7 @@ export * from '../ops/axis_util';
 export * from '../ops/broadcast_util';
 export * from '../ops/concat_util';
 export * from '../ops/conv_util';
-export {Activation} from '../ops/fused_util';
+export {Activation, FusedConv2DConfig} from '../ops/fused_util';
 
 export {BackendValues, TypedArray, upcastType, PixelData} from '../types';
 export {MemoryInfo, TimingInfo} from '../engine';
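
Note (not part of the patch): to make the string splicing in Conv2DMMProgram and Conv2DNaiveProgram easier to follow, this is roughly what the three snippets evaluate to for a fused bias + relu program. It is a sketch derived from the code above with whitespace simplified; getBiasAtOutCoords is the accessor the WebGPU shader preamble generates for the pushed 'bias' variable.

```ts
// Sketch only: the pieces produced when the backend constructs
// new Conv2DMMProgram(convInfo, workPerThread, /* addBias= */ true,
//                     unary_op.RELU, /* hasPreluActivationWeights= */ false).
const activationSnippet = `
  float activation(float a) {
    return max(a, 0.0);  // unary_op.RELU spliced in here.
  }
`;
const addBiasSnippet = 'value += getBiasAtOutCoords();';
const applyActivationSnippet = 'value = activation(value);';

// activationSnippet is placed at the top of userCode; the other two run just
// before the result is written, so the bias add and the relu happen in the
// same shader invocation as the convolution itself.
```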
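
A minimal end-to-end sketch of how this path is exercised follows. It assumes the @tensorflow/tfjs-backend-webgpu entry point registers the 'webgpu' backend as a side effect and uses the object-config form of tf.fused.conv2d; the shapes and values are illustrative, not taken from the test suite.

```ts
import * as tf from '@tensorflow/tfjs-core';
// Assumed to register the 'webgpu' backend as a side effect.
import '@tensorflow/tfjs-backend-webgpu';

async function main() {
  await tf.setBackend('webgpu');

  // NHWC input and HWIO filter.
  const x = tf.randomNormal([1, 8, 8, 4]) as tf.Tensor4D;
  const filter = tf.randomNormal([3, 3, 4, 16]) as tf.Tensor4D;
  const bias = tf.zeros([16]);

  // Routes to WebGPUBackend.fusedConv2d: the bias add and relu are folded
  // into the conv2d shader instead of running as separate kernels.
  const y = tf.fused.conv2d(
      {x, filter, strides: 1, pad: 'same', bias, activation: 'relu'});

  console.log(y.shape);  // [1, 8, 8, 16]
}

main();
```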