From 3510998a9b48282cdf6a66db0a5a8b7afcd1125b Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Fri, 10 Jan 2020 10:13:53 +0800 Subject: [PATCH] [webgpu] Enable test for fusedConv2D FIX https://github.com/tensorflow/tfjs/issues/2660 --- tfjs-backend-webgpu/src/backend_webgpu.ts | 9 ++++++++- .../src/kernels/conv2d_mm_webgpu.ts | 18 +++++++++++------- .../src/kernels/unary_op_webgpu.ts | 4 ++-- tfjs-backend-webgpu/src/setup_test.ts | 10 ++++++++++ 4 files changed, 31 insertions(+), 10 deletions(-) diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts index 87e20812ca3..40564d3936a 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu.ts +++ b/tfjs-backend-webgpu/src/backend_webgpu.ts @@ -807,7 +807,14 @@ export class WebGPUBackend extends KernelBackend { convInfo.strideHeight, convInfo.strideWidth ]; - return this.compileAndRun(program, [input, filter], output, dimensions); + const inputs: Tensor[] = [input, filter]; + if (hasBias) { + inputs.push(bias); + } + if (hasPreluActivationWeights) { + inputs.push(preluActivationWeights); + } + return this.compileAndRun(program, inputs, output, dimensions); } private argMinMaxReduce(x: Tensor, axis: number, reduceType: 'min'|'max'): diff --git a/tfjs-backend-webgpu/src/kernels/conv2d_mm_webgpu.ts b/tfjs-backend-webgpu/src/kernels/conv2d_mm_webgpu.ts index 9f6f6d0b93f..34e59169c81 100644 --- a/tfjs-backend-webgpu/src/kernels/conv2d_mm_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/conv2d_mm_webgpu.ts @@ -58,9 +58,12 @@ export class Conv2DMMProgram implements WebGPUProgram { const tileAOuter = this.workGroupSize[1] * elementsPerThread[1]; const tileBOuter = this.workGroupSize[0] * elementsPerThread[0]; const tileInner = tileAOuter > tileBOuter ? tileAOuter : tileBOuter; - util.assert(tileInner % this.workGroupSize[0] === 0 && - tileInner % this.workGroupSize[1] === 0, - () => 'tileInner must be multiple of workgroupsize.x and workgroupsize.y'); + util.assert( + tileInner % this.workGroupSize[0] === 0 && + tileInner % this.workGroupSize[1] === 0, + () => + // tslint:disable-next-line: max-line-length + 'tileInner must be multiple of workgroupsize.x and workgroupsize.y'); const tileSizeA = [tileAOuter, tileInner]; const tileSizeB = [tileInner, tileBOuter]; const dimAOuter = this.outputShape[1] * this.outputShape[2]; @@ -72,8 +75,8 @@ export class Conv2DMMProgram implements WebGPUProgram { `x[getFlatIndex(coord, xShape)]` : `coordsInBounds(coord, xShape) ? x[getFlatIndex(coord, xShape)] : 0`; const fitB = tilesFitEvenlyIntoShape(tileSizeB, [dimInner, dimBOuter]); - const sampleB = - fitB ? `W[row * dimBOuter + col]` : + const sampleB = fitB ? + `W[row * dimBOuter + col]` : `coordsInBounds(ivec2(row, col), ivec2(dimInner, dimBOuter)) ? W[row * dimBOuter + col] : 0`; @@ -90,7 +93,7 @@ export class Conv2DMMProgram implements WebGPUProgram { }`; } else { activationSnippet = ` - float activation(float x) { + float activation(float a) { ${activation} } `; @@ -155,6 +158,7 @@ export class Conv2DMMProgram implements WebGPUProgram { mm_matMul(dimAOuter, dimInner, dimBOuter); } `; - this.shaderKey = `conv2dmm'${elementsPerThread.join('')}${fitA}${fitB}`; + this.shaderKey = `conv2dmm'${elementsPerThread.join('')}${fitA}${fitB}${ + addBiasSnippet}${activationSnippet}`; } } diff --git a/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts b/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts index a3c2cd7f6d3..326e5799dc1 100644 --- a/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts @@ -23,8 +23,8 @@ import {WebGPUProgram} from './webgpu_program'; export const RELU = 'return max(a, 0.0);'; export const RELU6 = 'return (a < 0.0) ? 0.0 : min(6.0, a);'; -export const LINEAR = `return x;`; -export const ELU = `return (x >= 0.0) ? x : (exp(x) - 1.0);`; +export const LINEAR = `return a;`; +export const ELU = `return (a >= 0.0) ? a : (exp(a) - 1.0);`; export const SIGMOID = `return 1.0 / (1.0 + exp(-1.0 * a));`; export const ABS = `return abs(a);`; diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts index ed08a305af6..ba374db3145 100644 --- a/tfjs-backend-webgpu/src/setup_test.ts +++ b/tfjs-backend-webgpu/src/setup_test.ts @@ -126,6 +126,16 @@ const TEST_FILTERS: TestFilter[] = [ 'fused', // Not yet implemented. ] }, + { + include: 'fused conv2d', + excludes: [ + 'im2row with prelu', // Actual != expected. + 'pointwise with prelu', // Actual != expected. + 'gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0', // conv2dDerInput not yet + // implemented + 'fused matmul with relu6', // step not yet implemented + ] + }, { include: 'fromPixels', excludes: [