diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts index 8fb175d3ee5..40c9b7fc3df 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu.ts +++ b/tfjs-backend-webgpu/src/backend_webgpu.ts @@ -813,7 +813,14 @@ export class WebGPUBackend extends KernelBackend { convInfo.strideHeight, convInfo.strideWidth ]; - return this.compileAndRun(program, [input, filter], output, dimensions); + const inputs: Tensor[] = [input, filter]; + if (hasBias) { + inputs.push(bias); + } + if (hasPreluActivationWeights) { + inputs.push(preluActivationWeights); + } + return this.compileAndRun(program, inputs, output, dimensions); } private argMinMaxReduce(x: Tensor, axis: number, reduceType: 'min'|'max'): diff --git a/tfjs-backend-webgpu/src/kernels/conv2d_mm_webgpu.ts b/tfjs-backend-webgpu/src/kernels/conv2d_mm_webgpu.ts index ca870610254..a0c483d4fba 100644 --- a/tfjs-backend-webgpu/src/kernels/conv2d_mm_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/conv2d_mm_webgpu.ts @@ -61,8 +61,9 @@ export class Conv2DMMProgram implements WebGPUProgram { util.assert( tileInner % this.workGroupSize[0] === 0 && tileInner % this.workGroupSize[1] === 0, - () => `tileInner must be multiple of workgroupsize.x and - workgroupsize.y`); + () => + // tslint:disable-next-line: max-line-length + 'tileInner must be multiple of workgroupsize.x and workgroupsize.y'); const tileSizeA = [tileAOuter, tileInner]; const tileSizeB = [tileInner, tileBOuter]; const dimAOuter = this.outputShape[1] * this.outputShape[2]; @@ -92,7 +93,7 @@ export class Conv2DMMProgram implements WebGPUProgram { }`; } else { activationSnippet = ` - float activation(float x) { + float activation(float a) { ${activation} } `; @@ -155,6 +156,7 @@ export class Conv2DMMProgram implements WebGPUProgram { mm_matMul(dimAOuter, dimInner, dimBOuter); } `; - this.shaderKey = `conv2dmm'${elementsPerThread.join('')}${fitA}${fitB}`; + this.shaderKey = `conv2dmm'${elementsPerThread.join('')}${fitA}${fitB}${ + addBiasSnippet}${activationSnippet}`; } } diff --git a/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts b/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts index a3c2cd7f6d3..326e5799dc1 100644 --- a/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts @@ -23,8 +23,8 @@ import {WebGPUProgram} from './webgpu_program'; export const RELU = 'return max(a, 0.0);'; export const RELU6 = 'return (a < 0.0) ? 0.0 : min(6.0, a);'; -export const LINEAR = `return x;`; -export const ELU = `return (x >= 0.0) ? x : (exp(x) - 1.0);`; +export const LINEAR = `return a;`; +export const ELU = `return (a >= 0.0) ? a : (exp(a) - 1.0);`; export const SIGMOID = `return 1.0 / (1.0 + exp(-1.0 * a));`; export const ABS = `return abs(a);`; diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts index ed08a305af6..ba374db3145 100644 --- a/tfjs-backend-webgpu/src/setup_test.ts +++ b/tfjs-backend-webgpu/src/setup_test.ts @@ -126,6 +126,16 @@ const TEST_FILTERS: TestFilter[] = [ 'fused', // Not yet implemented. ] }, + { + include: 'fused conv2d', + excludes: [ + 'im2row with prelu', // Actual != expected. + 'pointwise with prelu', // Actual != expected. + 'gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0', // conv2dDerInput not yet + // implemented + 'fused matmul with relu6', // step not yet implemented + ] + }, { include: 'fromPixels', excludes: [