Skip to content

Commit

Permalink
Fix conv2d_weight_grad_groups (#1891)
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui committed Jun 17, 2024
1 parent a04da9a commit 8071b63
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 0 deletions.
63 changes: 63 additions & 0 deletions crates/burn-autodiff/src/tests/conv2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,69 @@ mod tests {
test.assert_grads(grads);
}

#[test]
fn test_conv2d_groups_stride_2() {
    // Regression test: depthwise-style grouped conv2d (groups == channels) with
    // stride 2. With these dimensions the transposed-conv used in the weight
    // gradient can produce an oversized kernel that must be sliced back down,
    // so this pins the expected gradients for x, weight, and bias.
    let device = Default::default();

    // One group per channel, 3x3 kernel over a 4x4 input, stride 2, padding 1.
    let test = Conv2dTestCase {
        batch_size: 1,
        channels_in: 4,
        channels_out: 4,
        kernel_size_1: 3,
        kernel_size_2: 3,
        padding_1: 1,
        padding_2: 1,
        stride_1: 2,
        stride_2: 2,
        dilation_1: 1,
        dilation_2: 1,
        groups: 4,
        height: 4,
        width: 4,
    };

    // Expected gradient w.r.t. the input, one 4x4 map per input channel.
    let x = TestTensor::from_floats(
        [[
            [
                [4., 8., 4., 5.],
                [8., 16., 8., 10.],
                [4., 8., 4., 5.],
                [7., 14., 7., 8.],
            ],
            [
                [13., 26., 13., 14.],
                [26., 52., 26., 28.],
                [13., 26., 13., 14.],
                [16., 32., 16., 17.],
            ],
            [
                [22., 44., 22., 23.],
                [44., 88., 44., 46.],
                [22., 44., 22., 23.],
                [25., 50., 25., 26.],
            ],
            [
                [31., 62., 31., 32.],
                [62., 124., 62., 64.],
                [31., 62., 31., 32.],
                [34., 68., 34., 35.],
            ],
        ]],
        &device,
    );

    // Expected gradient w.r.t. the weights: one 1-input-channel 3x3 kernel
    // per group (channels_in / groups == 1).
    let weight = TestTensor::from_floats(
        [
            [[[5., 10., 12.], [10., 20., 24.], [18., 36., 40.]]],
            [[[21., 42., 44.], [42., 84., 88.], [50., 100., 104.]]],
            [[[37., 74., 76.], [74., 148., 152.], [82., 164., 168.]]],
            [[[53., 106., 108.], [106., 212., 216.], [114., 228., 232.]]],
        ],
        &device,
    );

    // Expected gradient w.r.t. the bias, one value per output channel.
    let bias = TestTensor::from_floats([4., 4., 4., 4.], &device);

    test.assert_grads(Grads { x, weight, bias });
}

#[test]
fn test_conv2d_groups_different_channels() {
let test = Conv2dTestCase {
Expand Down
14 changes: 14 additions & 0 deletions crates/burn-tensor/src/tensor/ops/modules/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,20 @@ fn conv2d_weight_grad_groups<B: Backend>(
ConvOptions::new(options.dilation, options.padding, options.stride, 1),
);
weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);
let [_, _, kernel_size_1_tmp, kernel_size_2_tmp] = B::float_shape(&weight_grad_tmp).dims;

if kernel_size_1_tmp != kernel_size_1 || kernel_size_2_tmp != kernel_size_2 {
weight_grad_tmp = B::float_slice(
weight_grad_tmp,
[
0..increment_ci,
0..increment_co,
0..kernel_size_1,
0..kernel_size_2,
],
);
}

weight_grad = B::float_slice_assign(
weight_grad,
[
Expand Down

0 comments on commit 8071b63

Please sign in to comment.