diff --git a/tfjs-backend-wasm/src/kernels/FusedDepthwiseConv2D.ts b/tfjs-backend-wasm/src/kernels/FusedDepthwiseConv2D.ts index 8231ccc521b..a16489b1f91 100644 --- a/tfjs-backend-wasm/src/kernels/FusedDepthwiseConv2D.ts +++ b/tfjs-backend-wasm/src/kernels/FusedDepthwiseConv2D.ts @@ -70,7 +70,7 @@ function fusedDepthwiseConv2d(args: { const convInfo = backend_util.computeConv2DInfo( (x as Tensor4D).shape, (filter as Tensor4D).shape, strides, dilations, - pad, dimRoundingMode); + pad, dimRoundingMode, true /* depthwise */); const fusedActivation = FusableActivation[activation as {} as keyof typeof FusableActivation]; diff --git a/tfjs-core/src/ops/fused_depthwise_conv2d_test.ts b/tfjs-core/src/ops/fused_depthwise_conv2d_test.ts index ada8531b9c8..752fe4d7449 100644 --- a/tfjs-core/src/ops/fused_depthwise_conv2d_test.ts +++ b/tfjs-core/src/ops/fused_depthwise_conv2d_test.ts @@ -69,6 +69,29 @@ describeWithFlags('fused depthwiseConv2D', ALL_ENVS, () => { expectArraysClose(await result.data(), expected); }); + it('basic with channel-wise broadcasted bias and relu', async () => { + const strides = 1; + const pad = 'same'; + const x = tf.tensor4d( + [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7, 8, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7, 8 + ], + [1, 3, 3, 4]); + const w = tf.tensor4d( + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], [2, 2, 4, 1]); + const bias = tf.tensor1d([0, 1, 2, 3]); + + const result = tf.fused.depthwiseConv2d({x, filter: w, strides, pad, bias}); + expect(result.shape).toEqual([1, 3, 3, 4]); + const expected = [ + 124, 167, 92, 142, 112, 117, 76, 124, 16, 28, 44, 64, + 88, 134, 134, 88, 76, 120, 154, 205, 40, 58, 80, 106, + 4, 18, 36, 31, 20, 33, 50, 71, 0, 7, 16, 27 + ]; + expectArraysClose(await result.data(), expected); + }); + it('basic with broadcasted bias and relu', async () => { const fSize = 2; const pad = 'valid';