From ed8ad464562fb9063b8e340b40c671203bf4cb31 Mon Sep 17 00:00:00 2001 From: Yunfei Hao Date: Thu, 26 Mar 2020 10:03:43 +0800 Subject: [PATCH] [webgpu] Fix conv2d padding issue --- tfjs-backend-webgpu/src/backend_webgpu.ts | 18 ++---------------- .../src/kernels/conv2d_mm_webgpu.ts | 4 ++-- 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts index 40c9b7fc3df..5db5b5cd5a6 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu.ts +++ b/tfjs-backend-webgpu/src/backend_webgpu.ts @@ -726,14 +726,7 @@ export class WebGPUBackend extends KernelBackend { program = new Conv2DMMProgram(convInfo, workPerThread); } - const pad = convInfo.padInfo.type === 'VALID' ? - [0, 0] : - convInfo.padInfo.type === 'SAME' ? - [ - -Math.floor((convInfo.filterShape[0] - 1) / 2), - -Math.floor((convInfo.filterShape[1] - 1) / 2) - ] : - [convInfo.padInfo.top, convInfo.padInfo.left]; + const pad = [convInfo.padInfo.top, convInfo.padInfo.left]; const dimensions = [ convInfo.filterHeight, convInfo.filterWidth, ...pad, @@ -799,14 +792,7 @@ export class WebGPUBackend extends KernelBackend { hasPreluActivationWeights); } - const pad = convInfo.padInfo.type === 'VALID' ? - [0, 0] : - convInfo.padInfo.type === 'SAME' ? - [ - -Math.floor((convInfo.filterShape[0] - 1) / 2), - -Math.floor((convInfo.filterShape[1] - 1) / 2) - ] : - [convInfo.padInfo.top, convInfo.padInfo.left]; + const pad = [convInfo.padInfo.top, convInfo.padInfo.left]; const dimensions = [ convInfo.filterHeight, convInfo.filterWidth, ...pad, diff --git a/tfjs-backend-webgpu/src/kernels/conv2d_mm_webgpu.ts b/tfjs-backend-webgpu/src/kernels/conv2d_mm_webgpu.ts index a0c483d4fba..dd0e83cfd4e 100644 --- a/tfjs-backend-webgpu/src/kernels/conv2d_mm_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/conv2d_mm_webgpu.ts @@ -129,8 +129,8 @@ export class Conv2DMMProgram implements WebGPUProgram { ivec4 coord = ivec4( batch, - pad[0] + outRow * stride[0] + dilation[0] * WRow, - pad[1] + outCol * stride[1] + dilation[1] * WCol, + outRow * stride[0] + dilation[0] * WRow - pad[0], + outCol * stride[1] + dilation[1] * WCol - pad[1], c % xShape[3]); return ${sampleA}; }