From 23861b0ce7a768796c2f357637fcfea67a040cd8 Mon Sep 17 00:00:00 2001 From: Yunfei Hao Date: Thu, 23 Jan 2020 17:05:49 +0800 Subject: [PATCH] [webgpu] Use conv2dByMatMul by default when filter size is 1x1 --- tfjs-backend-webgpu/src/backend_webgpu.ts | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts index 87e20812ca3..80499ba0d46 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu.ts +++ b/tfjs-backend-webgpu/src/backend_webgpu.ts @@ -694,11 +694,6 @@ export class WebGPUBackend extends KernelBackend { conv2d(x: Tensor4D, filter: Tensor4D, convInfo: backend_util.Conv2DInfo): Tensor4D { - if (env().getBool('WEBGPU_CONV_SEPARATE_IM2COL_SHADER') && - x.shape[0] === 1) { - return this.conv2dWithIm2Col(x, filter, convInfo); - } - if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 && convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 && convInfo.strideHeight === 1 && convInfo.strideWidth === 1 && @@ -707,6 +702,11 @@ export class WebGPUBackend extends KernelBackend { return this.conv2dByMatMul(x, filter, convInfo); } + if (env().getBool('WEBGPU_CONV_SEPARATE_IM2COL_SHADER') && + x.shape[0] === 1) { + return this.conv2dWithIm2Col(x, filter, convInfo); + } + const dataId = this.write(null /*values*/, convInfo.outShape, x.dtype); const output = engine().makeTensorFromDataId(dataId, convInfo.outShape, x.dtype, this);