Backward pass has mismatched dimensions #1877
Comments
Are you using the version of burn from crates.io or the version on main? Is the model you used hosted somewhere? If not, do you mind if I take a look at it?
@skewballfox I'm using version 0.13.2 from crates.io and I've been using some torchvision models such as resnet18, efficientnet_v2_s and mobilenet_v2 each modified for binary classification (see below).
@galenoshea This is a run time error right? Does If it's not too much trouble, and you would rather not send us the onnx file. Could you try recreating with just one of the overloaded models? trying to narrow down the search a bit |
Cargo builds successfully, and it is a runtime error. I just tried reproducing and found that Resnet18 works while Mobilenet and Efficientnet have issues. I'm using images of size 224, but I've seen similar issues before when using awkward input sizes. Specifically, this happens when, at a given layer, an odd number of channels is halved (note the error from above,
Can you navigate to the specific location pointed to in the traceback (
I apologize, I'm not sure how to get var info, but these are the 2 functions that break when using the 2 backends.
LibTorch Backend
NdArray Backend
You're good. I was hoping that the traceback pointed to something in the generated model (so then we could figure out what step in burn-import needs work), but that's not the case here.
Sounds like it's a bug in the backward pass of one of Burn's ops. CCing @nathanielsimard, @louisfd, and @laggui; maybe they have some idea.
Could you share one of the ONNX models so we can try to reproduce this issue?
@laggui It doesn't allow me to drop the file here. Where do I share the model?
You could upload it somewhere (e.g., Google Drive) and share the link. Or, if it's a torchvision model, you could share the script you used to generate the ONNX model with PyTorch. Edit: ah right, as pointed out below, GitHub supports the zip format, so you can zip the ONNX file to upload it here.
You need to zip it.
Here's the zip, and you'll find the function for creating the model above.
I can reproduce the issue with the provided ONNX model on both the ndarray and torch backends.
The issue happens in the conv2d backward with groups, specifically this line.
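The shapes in the reported panic (`[1, 1, 4, 4]` versus `[1, 1, 3, 3]`) can be reproduced with plain shape arithmetic. The sketch below is only an illustration, not Burn's actual code: it assumes the weight gradient is obtained by correlating the padded input with the output gradient, using the forward stride as the dilation, which is a common formulation. The function names `conv_out_extent` and `weight_grad_extent` are hypothetical, and the 3x3 kernel, stride 2, padding 1 layer is an assumed configuration typical of the strided depthwise convolutions in MobileNet and EfficientNet.

```rust
/// Output extent of a convolution along one spatial axis (floor division),
/// using the usual formula: (in + 2p - d * (k - 1) - 1) / s + 1.
fn conv_out_extent(
    input: usize,
    kernel: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> usize {
    (input + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
}

/// Extent of the weight gradient when it is computed (ASSUMPTION) as a
/// stride-1 correlation of the padded input with the output gradient,
/// where the output gradient acts as a kernel dilated by the forward stride.
/// Assumes the forward dilation is 1.
fn weight_grad_extent(input: usize, out: usize, stride: usize, padding: usize) -> usize {
    conv_out_extent(input, out, 1, padding, stride)
}
```

Under these assumptions, a 224-wide input gives a forward output extent of 112, and the weight gradient comes out 4 wide instead of the expected 3, because the floor in the forward pass discards one input column that the backward formula still counts. A 225-wide input gives 113 and a weight gradient of exactly 3. One way to reconcile the shapes in this formulation is to crop the computed gradient to the true kernel size; whether PR #1891 does exactly that is not stated in this thread.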
Found the bug! Thanks a lot for filing the issue. Fixed with PR #1891.
Describe the bug
I am trying to load an ONNX model and get the gradient with respect to the input tensor. The forward pass works fine, but it breaks at the backward pass.
`
type B = Autodiff<NdArray>; // swap in LibTorch for the LibTorch run
let device = <B as Backend>::Device::default();
`
NdArray Output:
y: Shape { dims: [1, 1] }
thread 'main' panicked at /home/goshea/.cargo/registry/src/index.crates.io-6f17d22bba15001f/ndarray-0.15.6/src/lib.rs:1529:13:
ndarray: could not broadcast array from shape: [1, 1, 4, 4] to: [1, 1, 3, 3]
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
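The NdArray message follows from the standard broadcasting rule: aligned from the trailing axis, each source dimension must either equal the target dimension or be 1. The helper below is a minimal sketch of that rule; `can_broadcast` is a hypothetical name used for illustration and is not part of the ndarray API.

```rust
/// Returns true if `from` can be broadcast to `to` under the usual rule:
/// compare axes from the last one backwards; each source extent must be
/// equal to the target extent or be 1. Missing leading axes are allowed.
fn can_broadcast(from: &[usize], to: &[usize]) -> bool {
    if from.len() > to.len() {
        return false;
    }
    from.iter()
        .rev()
        .zip(to.iter().rev())
        .all(|(&f, &t)| f == t || f == 1)
}
```

Here `can_broadcast(&[1, 1, 4, 4], &[1, 1, 3, 3])` is false because 4 is neither 3 nor 1, which is the same mismatch the LibTorch backend reports as "tensor a (3) must match the size of tensor b (4)".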
LibTorch Output:
Finished `dev` profile [unoptimized + debuginfo] target(s) in 9.08s
Running `target/debug/burn_test`
y: Shape { dims: [1, 1] }
thread 'main' panicked at /home/goshea/.cargo/registry/src/index.crates.io-6f17d22bba15001f/tch-0.15.0/src/wrappers/tensor.rs:535:27:
called `Result::unwrap()` on an `Err` value: Torch("The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 3\nException raised from infer_size_impl at ../aten/src/ATen/ExpandUtils.cpp:31 (most recent call first):\nframe #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x6c (0x7f45930e0a0c in /home/goshea/libtorch/lib/libc10.so)\nframe #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xfa (0x7f459308a8bc in /home/goshea/libtorch/lib/libc10.so)\nframe #2: at::infer_size_dimvector(c10::ArrayRef<long>, c10::ArrayRef<long>) + 0x3d4 (0x7f45949acfa4 in /home/goshea/libtorch/lib/libtorch_cpu.so)\nframe #3: at::TensorIteratorBase::compute_shape(at::TensorIteratorConfig const&) + 0xb8 (0x7f4594a5f3a8 in /home/goshea/libtorch/lib/libtorch_cpu.so)\nframe #4: at::TensorIteratorBase::build(at::TensorIteratorConfig&) + 0x6d (0x7f4594a602ad in /home/goshea/libtorch/lib/libtorch_cpu.so)\nframe #5: <unknown function> + 0x1cf43ba (0x7f4594e353ba in /home/goshea/libtorch/lib/libtorch_cpu.so)\nframe #6: at::native::copy_(at::Tensor&, at::Tensor const&, bool) + 0x7a (0x7f4594e36daa in /home/goshea/libtorch/lib/libtorch_cpu.so)\nframe #7: at::_ops::copy_::redispatch(c10::DispatchKeySet, at::Tensor&, at::Tensor const&, bool) + 0x8f (0x7f4595bf600f in /home/goshea/libtorch/lib/libtorch_cpu.so)\nframe #8: <unknown function> + 0x5d82a25 (0x7f4598ec3a25 in /home/goshea/libtorch/lib/libtorch_cpu.so)\nframe #9: at::_ops::copy_::redispatch(c10::DispatchKeySet, at::Tensor&, at::Tensor const&, bool) + 0x8f (0x7f4595bf600f in /home/goshea/libtorch/lib/libtorch_cpu.so)\nframe #10: <unknown function> + 0x5d84b14 (0x7f4598ec5b14 in /home/goshea/libtorch/lib/libtorch_cpu.so)\nframe #11: at::_ops::copy_::call(at::Tensor&, at::Tensor const&, bool) + 0x16f (0x7f4595c6668f in
/home/goshea/libtorch/lib/libtorch_cpu.so)\nframe #12: <unknown function> + 0x592cd2 (0x558b6da0ecd2 in target/debug/burn_test)\nframe #13: <unknown function> + 0x599f87 (0x558b6da15f87 in target/debug/burn_test)\nframe #14: <unknown function> + 0x57825c (0x558b6d9f425c in target/debug/burn_test)\nframe #15: <unknown function> + 0x578310 (0x558b6d9f4310 in target/debug/burn_test)\nframe #16: <unknown function> + 0x8fcfd (0x558b6d50bcfd in target/debug/burn_test)\nframe #17: <unknown function> + 0x27791d (0x558b6d6f391d in target/debug/burn_test)\nframe #18: <unknown function> + 0x144ea7 (0x558b6d5c0ea7 in target/debug/burn_test)\nframe #19: <unknown function> + 0x1440fe (0x558b6d5c00fe in target/debug/burn_test)\nframe #20: <unknown function> + 0x2757cd (0x558b6d6f17cd in target/debug/burn_test)\nframe #21: <unknown function> + 0x237075 (0x558b6d6b3075 in target/debug/burn_test)\nframe #22: <unknown function> + 0x1d0ff3 (0x558b6d64cff3 in target/debug/burn_test)\nframe #23: <unknown function> + 0x5e737d (0x558b6da6337d in target/debug/burn_test)\nframe #24: <unknown function> + 0x5d87cc (0x558b6da547cc in target/debug/burn_test)\nframe #25: <unknown function> + 0x5e84fb (0x558b6da644fb in target/debug/burn_test)\nframe #26: <unknown function> + 0x5e8578 (0x558b6da64578 in target/debug/burn_test)\nframe #27: <unknown function> + 0x5e7344 (0x558b6da63344 in target/debug/burn_test)\nframe #28: <unknown function> + 0x5d879a (0x558b6da5479a in target/debug/burn_test)\nframe #29: <unknown function> + 0x5e83b1 (0x558b6da643b1 in target/debug/burn_test)\nframe #30: <unknown function> + 0x5e1bc9 (0x558b6da5dbc9 in target/debug/burn_test)\nframe #31: <unknown function> + 0x5e0c98 (0x558b6da5cc98 in target/debug/burn_test)\nframe #32: <unknown function> + 0x5e7279 (0x558b6da63279 in target/debug/burn_test)\nframe #33: <unknown function> + 0x5e6b73 (0x558b6da62b73 in target/debug/burn_test)\nframe #34: <unknown function> + 0x2b3caf (0x558b6d72fcaf in 
target/debug/burn_test)\nframe #35: <unknown function> + 0x148936 (0x558b6d5c4936 in target/debug/burn_test)\nframe #36: <unknown function> + 0x2ab05c (0x558b6d72705c in target/debug/burn_test)\nframe #37: <unknown function> + 0x1aedc4 (0x558b6d62adc4 in target/debug/burn_test)\nframe #38: <unknown function> + 0x27bfdb (0x558b6d6f7fdb in target/debug/burn_test)\nframe #39: <unknown function> + 0x185e4e (0x558b6d601e4e in target/debug/burn_test)\nframe #40: <unknown function> + 0x181ce1 (0x558b6d5fdce1 in target/debug/burn_test)\nframe #41: <unknown function> + 0x6a3553 (0x558b6db1f553 in target/debug/burn_test)\nframe #42: <unknown function> + 0x181cba (0x558b6d5fdcba in target/debug/burn_test)\nframe #43: <unknown function> + 0x1aeeee (0x558b6d62aeee in target/debug/burn_test)\nframe #44: <unknown function> + 0x2a1ca (0x7f4592d0b1ca in /lib/x86_64-linux-gnu/libc.so.6)\nframe #45: __libc_start_main + 0x8b (0x7f4592d0b28b in /lib/x86_64-linux-gnu/libc.so.6)\nframe #46: <unknown function> + 0x7e705 (0x558b6d4fa705 in target/debug/burn_test)\n") note: run with
`RUST_BACKTRACE=1` environment variable to display a backtrace