diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts index fbcc3fb0f3d..8fb175d3ee5 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu.ts +++ b/tfjs-backend-webgpu/src/backend_webgpu.ts @@ -700,11 +700,6 @@ export class WebGPUBackend extends KernelBackend { conv2d(x: Tensor4D, filter: Tensor4D, convInfo: backend_util.Conv2DInfo): Tensor4D { - if (env().getBool('WEBGPU_CONV_SEPARATE_IM2COL_SHADER') && - x.shape[0] === 1) { - return this.conv2dWithIm2Col(x, filter, convInfo); - } - if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 && convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 && convInfo.strideHeight === 1 && convInfo.strideWidth === 1 && @@ -713,6 +708,11 @@ export class WebGPUBackend extends KernelBackend { return this.conv2dByMatMul(x, filter, convInfo); } + if (env().getBool('WEBGPU_CONV_SEPARATE_IM2COL_SHADER') && + x.shape[0] === 1) { + return this.conv2dWithIm2Col(x, filter, convInfo); + } + const dataId = this.write(null /*values*/, convInfo.outShape, x.dtype); const output = engine().makeTensorFromDataId(dataId, convInfo.outShape, x.dtype, this);