Skip to content

Commit edea1e2

Browse files
authored
[webgpu] Use conv2dByMatMul by default when filter size is 1x1 (#2693)
PERF
1 parent 8b64deb commit edea1e2

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

tfjs-backend-webgpu/src/backend_webgpu.ts

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -700,11 +700,6 @@ export class WebGPUBackend extends KernelBackend {
700700

701701
conv2d(x: Tensor4D, filter: Tensor4D, convInfo: backend_util.Conv2DInfo):
702702
Tensor4D {
703-
if (env().getBool('WEBGPU_CONV_SEPARATE_IM2COL_SHADER') &&
704-
x.shape[0] === 1) {
705-
return this.conv2dWithIm2Col(x, filter, convInfo);
706-
}
707-
708703
if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
709704
convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
710705
convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
@@ -713,6 +708,11 @@ export class WebGPUBackend extends KernelBackend {
713708
return this.conv2dByMatMul(x, filter, convInfo);
714709
}
715710

711+
if (env().getBool('WEBGPU_CONV_SEPARATE_IM2COL_SHADER') &&
712+
x.shape[0] === 1) {
713+
return this.conv2dWithIm2Col(x, filter, convInfo);
714+
}
715+
716716
const dataId = this.write(null /*values*/, convInfo.outShape, x.dtype);
717717
const output =
718718
engine().makeTensorFromDataId(dataId, convInfo.outShape, x.dtype, this);

0 commit comments

Comments
 (0)