ndarray: could not broadcast array from shape: [2] to: [1] in item.loss.backward() #680

Closed
L-M-Sherlock opened this issue Aug 23, 2023 · 4 comments
Labels: bug (Something isn't working)

Comments

L-M-Sherlock commented Aug 23, 2023

Describe the bug

My model trains when the batch size is one, but a fatal error occurs when I set batch_size = 2.

=== PANIC ===
A fatal error happened, you can check the experiment logs here => './tmp/fsrs/experiment.log'
=============
thread 'training::test' panicked at 'ndarray: could not broadcast array from shape: [2] to: [1]', /Users/jarrettye/.cargo/registry/src/index.crates.io-6f17d22bba15001f/ndarray-0.15.6/src/lib.rs:1529:13
stack backtrace:
   0: rust_begin_unwind
             at /rustc/90c541806f23a127002de5b4038be731ba1458ca/library/std/src/panicking.rs:578:5
   1: core::panicking::panic_fmt
             at /rustc/90c541806f23a127002de5b4038be731ba1458ca/library/core/src/panicking.rs:67:14
   2: ndarray::ArrayBase<S,D>::broadcast_unwrap::broadcast_panic
             at /Users/jarrettye/.cargo/registry/src/index.crates.io-6f17d22bba15001f/ndarray-0.15.6/src/lib.rs:1529:13
   3: ndarray::ArrayBase<S,D>::broadcast_unwrap
             at /Users/jarrettye/.cargo/registry/src/index.crates.io-6f17d22bba15001f/ndarray-0.15.6/src/lib.rs:1538:21
   4: ndarray::impl_methods::<impl ndarray::ArrayBase<S,D>>::zip_mut_with
             at /Users/jarrettye/.cargo/registry/src/index.crates.io-6f17d22bba15001f/ndarray-0.15.6/src/impl_methods.rs:2456:33
   5: ndarray::impl_methods::<impl ndarray::ArrayBase<S,D>>::assign
             at /Users/jarrettye/.cargo/registry/src/index.crates.io-6f17d22bba15001f/ndarray-0.15.6/src/impl_methods.rs:2355:9
   6: burn_ndarray::ops::base::NdArrayOps<E>::slice_assign
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-ndarray/src/ops/base.rs:48:9
   7: burn_ndarray::ops::tensor::<impl burn_tensor::tensor::ops::tensor::TensorOps<burn_ndarray::backend::NdArrayBackend<E>> for burn_ndarray::backend::NdArrayBackend<E>>::slice_assign
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-ndarray/src/ops/tensor.rs:200:9
   8: <burn_autodiff::ops::tensor::<impl burn_tensor::tensor::ops::tensor::TensorOps<burn_autodiff::backend::ADBackendDecorator<B>> for burn_autodiff::backend::ADBackendDecorator<B>>::slice::Index<_> as burn_autodiff::ops::backward::Backward<B,_,1_usize>>::backward::{{closure}}
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-autodiff/src/ops/tensor.rs:636:21
   9: burn_autodiff::ops::backward::unary
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-autodiff/src/ops/backward.rs:78:20
  10: <burn_autodiff::ops::tensor::<impl burn_tensor::tensor::ops::tensor::TensorOps<burn_autodiff::backend::ADBackendDecorator<B>> for burn_autodiff::backend::ADBackendDecorator<B>>::slice::Index<_> as burn_autodiff::ops::backward::Backward<B,_,1_usize>>::backward
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-autodiff/src/ops/tensor.rs:634:17
  11: <burn_autodiff::ops::base::OpsStep<B,T,SB,_,_> as burn_autodiff::graph::base::Step>::step
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-autodiff/src/ops/base.rs:142:9
  12: burn_autodiff::graph::backward::execute_steps::{{closure}}::{{closure}}
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-autodiff/src/graph/backward.rs:35:61
  13: core::iter::traits::iterator::Iterator::for_each::call::{{closure}}
             at /rustc/90c541806f23a127002de5b4038be731ba1458ca/library/core/src/iter/traits/iterator.rs:854:29
  14: core::iter::traits::iterator::Iterator::fold
             at /rustc/90c541806f23a127002de5b4038be731ba1458ca/library/core/src/iter/traits/iterator.rs:2482:21
  15: core::iter::traits::iterator::Iterator::for_each
             at /rustc/90c541806f23a127002de5b4038be731ba1458ca/library/core/src/iter/traits/iterator.rs:857:9
  16: burn_autodiff::graph::backward::execute_steps::{{closure}}
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-autodiff/src/graph/backward.rs:35:27
  17: core::iter::traits::iterator::Iterator::for_each::call::{{closure}}
             at /rustc/90c541806f23a127002de5b4038be731ba1458ca/library/core/src/iter/traits/iterator.rs:854:29
  18: core::iter::traits::double_ended::DoubleEndedIterator::rfold
             at /rustc/90c541806f23a127002de5b4038be731ba1458ca/library/core/src/iter/traits/double_ended.rs:307:21
  19: <core::iter::adapters::rev::Rev<I> as core::iter::traits::iterator::Iterator>::fold
             at /rustc/90c541806f23a127002de5b4038be731ba1458ca/library/core/src/iter/adapters/rev.rs:64:9
  20: core::iter::traits::iterator::Iterator::for_each
             at /rustc/90c541806f23a127002de5b4038be731ba1458ca/library/core/src/iter/traits/iterator.rs:857:9
  21: burn_autodiff::graph::backward::execute_steps
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-autodiff/src/graph/backward.rs:33:5
  22: burn_autodiff::graph::backward::backward
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-autodiff/src/graph/backward.rs:11:5
  23: <burn_autodiff::backend::ADBackendDecorator<B> as burn_tensor::tensor::backend::base::ADBackend>::backward
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-autodiff/src/backend.rs:42:9
  24: burn_tensor::tensor::api::float::<impl burn_tensor::tensor::api::base::Tensor<B,_>>::backward
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-tensor/src/tensor/api/float.rs:287:9
  25: fsrs_optimizer_rs::training::<impl burn_train::learner::train_val::TrainStep<fsrs_optimizer_rs::dataset::FSRSBatch<B>,burn_train::learner::classification::ClassificationOutput<B>> for fsrs_optimizer_rs::model::Model<B>>::step
             at ./src/training.rs:56:32
  26: burn_train::learner::epoch::TrainEpoch<TI>::run
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-train/src/learner/epoch.rs:108:24
  27: burn_train::learner::train_val::<impl burn_train::learner::base::Learner<B,M,O,LR,TO,VO>>::fit
             at /Users/jarrettye/.cargo/git/checkouts/burn-acfbee6a141c1b41/8808ee2/burn-train/src/learner/train_val.rs:131:21
  28: fsrs_optimizer_rs::training::train
             at ./src/training.rs:131:29
  29: fsrs_optimizer_rs::training::test
             at ./src/training.rs:157:5
  30: fsrs_optimizer_rs::training::test::{{closure}}
             at ./src/training.rs:149:11
  31: core::ops::function::FnOnce::call_once
             at /rustc/90c541806f23a127002de5b4038be731ba1458ca/library/core/src/ops/function.rs:250:5
  32: core::ops::function::FnOnce::call_once
             at /rustc/90c541806f23a127002de5b4038be731ba1458ca/library/core/src/ops/function.rs:250:5
note: Some details are omitted, run with `RUST_BACKTRACE=full` for a verbose backtrace.

To Reproduce

My model is complicated, so I haven't found a minimal script that reproduces the error.

For details, please see: open-spaced-repetition/fsrs-rs#16 (comment)

The code reproducing the error: open-spaced-repetition/fsrs-rs@23b8772
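
For orientation, here is a hypothetical minimal sketch of the operation the backtrace ends in: a slice whose backward pass runs slice_assign on the ndarray backend. It is not a confirmed reproduction, and the constructor and gradient-tracking calls (Tensor::ones, require_grad) are assumptions that may differ slightly between burn revisions.

```rust
// Hypothetical sketch only, not a verified reproduction of the panic.
// Crate paths and method names (Tensor::ones, require_grad, slice) are assumed
// for the burn revision in the backtrace.
use burn_autodiff::ADBackendDecorator;
use burn_ndarray::NdArrayBackend;
use burn_tensor::Tensor;

type B = ADBackendDecorator<NdArrayBackend<f32>>;

fn slice_backward_sketch() {
    // A batch of 2 items, mirroring the failing batch_size = 2 run.
    let x: Tensor<B, 2> = Tensor::ones([2, 3]).require_grad();

    // The backward pass of `slice` writes the incoming gradient into a zero
    // tensor via `slice_assign`, the frame that panics with
    // "could not broadcast array from shape: [2] to: [1]".
    let sliced = x.clone().slice([0..2, 0..1]);

    // Reduce to a scalar loss and run backward, as in `item.loss.backward()`.
    let loss = sliced.sum();
    let gradients = loss.backward();
    let _grad_x = x.grad(&gradients);
}
```

The failing broadcast from shape [2] to [1] suggests the batch dimension of the gradient handed to slice_assign no longer matches the region it is written into once the batch size exceeds one.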

nathanielsimard added the bug (Something isn't working) label on Aug 23, 2023
nathanielsimard (Member) commented

@L-M-Sherlock does it work with other backends?

L-M-Sherlock (Author) commented Aug 23, 2023

> @L-M-Sherlock does it work with other backends?

It works with burn_wgpu. I haven't tested burn_tch because building torch-sys takes too much time.

L-M-Sherlock (Author) commented

I tested the tch backend. It got stuck while calculating the loss.
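
As context for the backend comparison above, switching backends in burn at this point comes down to changing the backend type alias used by the training code. A rough sketch follows; only ADBackendDecorator and NdArrayBackend appear in the backtrace, and the wgpu and tch aliases are assumptions that may not match this exact burn revision.

```rust
// Rough sketch of running the same training code on different backends by
// swapping the backend type alias. The wgpu and tch aliases below are assumed.
use burn_autodiff::ADBackendDecorator;
use burn_ndarray::NdArrayBackend;
// use burn_wgpu::{AutoGraphicsApi, WgpuBackend};
// use burn_tch::TchBackend;

// Panics in slice_assign during backward() once batch_size = 2:
type Backend = ADBackendDecorator<NdArrayBackend<f32>>;

// Works with batch_size = 2 (assumed alias):
// type Backend = ADBackendDecorator<WgpuBackend<AutoGraphicsApi, f32, i32>>;

// Gets stuck while computing the loss (assumed alias):
// type Backend = ADBackendDecorator<TchBackend<f32>>;
```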

nathanielsimard (Member) commented

Fixed with #686
