diff --git a/crates/burn-autodiff/src/tests/conv2d.rs b/crates/burn-autodiff/src/tests/conv2d.rs
index ea4350bc76..f6aca2df44 100644
--- a/crates/burn-autodiff/src/tests/conv2d.rs
+++ b/crates/burn-autodiff/src/tests/conv2d.rs
@@ -629,6 +629,69 @@ mod tests {
         test.assert_grads(grads);
     }
 
+    #[test]
+    fn test_conv2d_groups_stride_2() {
+        let test = Conv2dTestCase {
+            batch_size: 1,
+            channels_in: 4,
+            channels_out: 4,
+            kernel_size_1: 3,
+            kernel_size_2: 3,
+            padding_1: 1,
+            padding_2: 1,
+            stride_1: 2,
+            stride_2: 2,
+            dilation_1: 1,
+            dilation_2: 1,
+            groups: 4,
+            height: 4,
+            width: 4,
+        };
+        let device = Default::default();
+        let grads = Grads {
+            x: TestTensor::from_floats(
+                [[
+                    [
+                        [4., 8., 4., 5.],
+                        [8., 16., 8., 10.],
+                        [4., 8., 4., 5.],
+                        [7., 14., 7., 8.],
+                    ],
+                    [
+                        [13., 26., 13., 14.],
+                        [26., 52., 26., 28.],
+                        [13., 26., 13., 14.],
+                        [16., 32., 16., 17.],
+                    ],
+                    [
+                        [22., 44., 22., 23.],
+                        [44., 88., 44., 46.],
+                        [22., 44., 22., 23.],
+                        [25., 50., 25., 26.],
+                    ],
+                    [
+                        [31., 62., 31., 32.],
+                        [62., 124., 62., 64.],
+                        [31., 62., 31., 32.],
+                        [34., 68., 34., 35.],
+                    ],
+                ]],
+                &device,
+            ),
+            weight: TestTensor::from_floats(
+                [
+                    [[[5., 10., 12.], [10., 20., 24.], [18., 36., 40.]]],
+                    [[[21., 42., 44.], [42., 84., 88.], [50., 100., 104.]]],
+                    [[[37., 74., 76.], [74., 148., 152.], [82., 164., 168.]]],
+                    [[[53., 106., 108.], [106., 212., 216.], [114., 228., 232.]]],
+                ],
+                &device,
+            ),
+            bias: TestTensor::from_floats([4., 4., 4., 4.], &device),
+        };
+        test.assert_grads(grads);
+    }
+
     #[test]
     fn test_conv2d_groups_different_channels() {
         let test = Conv2dTestCase {
diff --git a/crates/burn-tensor/src/tensor/ops/modules/conv.rs b/crates/burn-tensor/src/tensor/ops/modules/conv.rs
index 9f8f4ca22c..73694b73e6 100644
--- a/crates/burn-tensor/src/tensor/ops/modules/conv.rs
+++ b/crates/burn-tensor/src/tensor/ops/modules/conv.rs
@@ -427,6 +427,20 @@ fn conv2d_weight_grad_groups(
             ConvOptions::new(options.dilation, options.padding, options.stride, 1),
         );
         weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);
+        let [_, _, kernel_size_1_tmp, kernel_size_2_tmp] = B::float_shape(&weight_grad_tmp).dims;
+
+        if kernel_size_1_tmp != kernel_size_1 || kernel_size_2_tmp != kernel_size_2 {
+            weight_grad_tmp = B::float_slice(
+                weight_grad_tmp,
+                [
+                    0..increment_ci,
+                    0..increment_co,
+                    0..kernel_size_1,
+                    0..kernel_size_2,
+                ],
+            );
+        }
+
         weight_grad = B::float_slice_assign(
             weight_grad,
             [