Skip to content

Commit

Permalink
Fix select assign backward (#1739)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed May 7, 2024
1 parent bd06b38 commit a6e3b4e
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 28 deletions.
23 changes: 5 additions & 18 deletions crates/burn-autodiff/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1058,29 +1058,22 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}

impl<B: Backend, const D: usize> Backward<B, D, 2> for IndexSelectDimAssign<D> {
type State = (usize, IntTensor<B, 1>, Shape<D>, Shape<D>, B::Device);
type State = (usize, IntTensor<B, 1>);

fn backward(
self,
ops: Ops<Self::State, 2>,
grads: &mut Gradients,
_checkpointer: &mut Checkpointer,
) {
let (dim, indices, shape_lhs, shape_rhs, device) = ops.state;
let [indices_4lhs, indices_4rhs] = duplicate(&ops.parents, Some(indices));
let (dim, indices) = ops.state;

binary::<B, D, D, D, _, _>(
ops.parents,
ops.node,
grads,
|grad| {
let zeros = B::float_zeros(shape_lhs, &device);
B::float_select_assign(grad, dim, indices_4lhs.unwrap(), zeros)
},
|grad| {
let zeros = B::float_zeros(shape_rhs, &device);
B::float_select_assign(zeros, dim, indices_4rhs.unwrap(), grad)
},
|grad| grad,
|grad| B::float_select(grad, dim, indices),
);
}
}
Expand All @@ -1098,13 +1091,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
dim,
indices.clone(),
B::float_shape(&tensor.primitive),
B::float_shape(&value.primitive),
B::float_device(&value.primitive),
),
(dim, indices.clone()),
B::float_select_assign(tensor.primitive, dim, indices, value.primitive),
),
OpsKind::UnTracked(prep) => prep.finish(B::float_select_assign(
Expand Down
19 changes: 19 additions & 0 deletions crates/burn-autodiff/src/tests/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,23 @@ mod tests {
Data::from([[64., 64., 64.], [19., 19., 19.]])
);
}

#[test]
fn test_select_assign_grad_different_shapes() {
    // Regression test for the select_assign backward pass when the assigned
    // value has a different shape than the target tensor along the selected
    // dimension (x is [1, 1], y is [2, 1]; assignment targets row 1 of y).
    let device = Default::default();

    // A single index: the value `x` is accumulated into row 1 of `y`.
    let indices: Tensor<TestAutodiffBackend, 1, Int> = Tensor::from_ints([1], &device);
    let x: Tensor<TestAutodiffBackend, 2> = Tensor::ones([1, 1], &device).require_grad();
    let y = Tensor::ones([2, 1], &device).require_grad();

    // w = select_assign(y, dim=0, [1], x) — shape [2, 1];
    // then matmul with y^T ([1, 2]) yields a [2, 2] output, making both
    // x and y contribute to the backward pass through different paths.
    let w = y.clone().select_assign(0, indices, x.clone());
    let w = w.matmul(y.clone().transpose());

    let grads = w.backward();
    let x_grad = x.grad(&grads).unwrap();
    let y_grad = y.grad(&grads).unwrap();

    // Expected values (hand-derived, assuming backward() reduces the output
    // by summation — TODO confirm against TestAutodiffBackend semantics):
    // with s = [y0, y1 + x] and sum(w) = (s0 + s1) * (y0 + y1),
    //   d/dx  = y0 + y1                  = 2
    //   d/dy0 = (y0 + y1) + (s0 + s1)    = 5, and likewise d/dy1 = 5.
    assert_eq!(x_grad.into_data(), Data::from([[2.0]]));
    assert_eq!(y_grad.into_data(), Data::from([[5.0], [5.0]]));
}
}
14 changes: 4 additions & 10 deletions crates/burn-tch/src/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,12 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
pub fn select_assign<const D: usize>(
tensor: TchTensor<E, D>,
dim: usize,
indices_tensor: TchTensor<i64, 1>,
indices: TchTensor<i64, 1>,
value: TchTensor<E, D>,
) -> TchTensor<E, D> {
let mut indices = Vec::with_capacity(D);
for _ in 0..D {
indices.push(None);
}
indices[dim] = Some(indices_tensor.tensor);

tensor.unary_ops(
|mut tensor| tensor.index_put_(&indices, &value.tensor, true),
|tensor| tensor.index_put(&indices, &value.tensor, true),
tensor.clone().unary_ops(
|mut tensor| tensor.index_add_(dim as i64, &indices.tensor, &value.tensor),
|tensor| tensor.index_add(dim as i64, &indices.tensor, &value.tensor),
)
}

Expand Down

0 comments on commit a6e3b4e

Please sign in to comment.