
Wrong result when calling reshape() after Tensor::stack #1053

Closed
wcshds opened this issue Dec 7, 2023 · 4 comments · Fixed by #1058
Labels
bug Something isn't working

Comments


wcshds commented Dec 7, 2023

Here is the code.

type Backend = NdArray;

let tensor = Tensor::<Backend, 1, Int>::arange(1..25).reshape([4, 6]);
let zeros: Tensor<Backend, 2, Int> = Tensor::zeros([4, 6]);
let intersperse =
    Tensor::stack::<3>(vec![tensor.clone(), zeros.clone()], 2).reshape([4, 12]);

println!("{}", intersperse);

The result is:

Tensor {
  data:
[[1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0],
 [7, 8, 9, 10, 11, 12, 0, 0, 0, 0, 0, 0],
 [13, 14, 15, 16, 17, 18, 0, 0, 0, 0, 0, 0],
 [19, 20, 21, 22, 23, 24, 0, 0, 0, 0, 0, 0]],
  shape:  [4, 12],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Int",
  dtype:  "i64",
}

But the expected result should be:

[[1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0],
 [7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0],
 [13, 0, 14, 0, 15, 0, 16, 0, 17, 0, 18, 0],
 [19, 0, 20, 0, 21, 0, 22, 0, 23, 0, 24, 0]],

I get the correct result with the LibTorch and Wgpu backends.
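For reference, the expected interleaving can be reproduced without any tensor library: stacking along the last axis and then flattening in row-major order is equivalent to interspersing each row with zeros. A minimal plain-Rust sketch of that semantics (`interleave_with_zeros` is a hypothetical helper, not part of burn):

```rust
// Interleave each element of a row with a zero, mimicking
// stack(dim = 2) followed by reshape([rows, 2 * cols]) in row-major order.
fn interleave_with_zeros(row: &[i64]) -> Vec<i64> {
    row.iter().flat_map(|&x| [x, 0]).collect()
}

fn main() {
    let row = [1i64, 2, 3, 4, 5, 6];
    // Prints [1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0] -- the expected row above.
    println!("{:?}", interleave_with_zeros(&row));
}
```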


AuruTus commented Dec 8, 2023

This seems to be a bug in the ndarray crate: the strides of the newly concatenated tensor are incorrect. The strides come from their stride-creating function fortran_strides.

The cat(...) called inside stack(...) in the test code above eventually calls fortran_strides(...), which produces a tensor with strides {1, 4, 24, 0}, while the correct strides would be {12, 2, 1, 0}. This field is not shown in the fmt::Display impl, so it can only be inspected while debugging.
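For context, the two stride orders mentioned above follow from the standard formulas. A plain-Rust sketch (not burn or ndarray code) that reproduces the {12, 2, 1} row-major and {1, 4, 24} column-major strides for shape [4, 6, 2], ignoring the trailing 0 from the quoted 4-element representation:

```rust
// Row-major (C) strides: the rightmost axis varies fastest.
fn c_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}

// Column-major (Fortran) strides: the leftmost axis varies fastest.
fn f_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; shape.len()];
    for i in 1..shape.len() {
        strides[i] = strides[i - 1] * shape[i - 1];
    }
    strides
}

fn main() {
    let shape = [4usize, 6, 2];
    println!("{:?}", c_strides(&shape)); // [12, 2, 1]
    println!("{:?}", f_strides(&shape)); // [1, 4, 24]
}
```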

Print result of the test code with cat
// after cat
Tensor {
  data:
[[[1, 0],
  [2, 0],
  [3, 0],
  [4, 0],
  [5, 0],
  [6, 0]],
 [[7, 0],
  [8, 0],
  [9, 0],
  [10, 0],
  [11, 0],
  [12, 0]],
 [[13, 0],
  [14, 0],
  [15, 0],
  [16, 0],
  [17, 0],
  [18, 0]],
 [[19, 0],
  [20, 0],
  [21, 0],
  [22, 0],
  [23, 0],
  [24, 0]]],
  shape:  [4, 6, 2],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Int",
  dtype:  "i64",
}

// after calling reshape
Tensor {
  data:
[[1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0],
 [7, 8, 9, 10, 11, 12, 0, 0, 0, 0, 0, 0],
 [13, 14, 15, 16, 17, 18, 0, 0, 0, 0, 0, 0],
 [19, 20, 21, 22, 23, 24, 0, 0, 0, 0, 0, 0]],
  shape:  [4, 12],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Int",
  dtype:  "i64",
}
Result from the dummy reshape code
// created from from_ints(...)
Tensor {
  data:
[[[1, 0],
  [2, 0],
  [3, 0],
  [4, 0],
  [5, 0],
  [6, 0]],
 [[7, 0],
  [8, 0],
  [9, 0],
  [10, 0],
  [11, 0],
  [12, 0]],
 [[13, 0],
  [14, 0],
  [15, 0],
  [16, 0],
  [17, 0],
  [18, 0]],
 [[19, 0],
  [20, 0],
  [21, 0],
  [22, 0],
  [23, 0],
  [24, 0]]],
  shape:  [4, 6, 2],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Int",
  dtype:  "i64",
}

// after calling reshape
Tensor {
  data:
[[1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0],
 [7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0],
 [13, 0, 14, 0, 15, 0, 16, 0, 17, 0, 18, 0],
 [19, 0, 20, 0, 21, 0, 22, 0, 23, 0, 24, 0]],
  shape:  [4, 12],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Int",
  dtype:  "i64",
}
Dummy reshape code
    #[test]
    fn test_bug_2() {
        use burn_tensor::{Int, Tensor};
        type Backend = crate::NdArray;

        let intersperse: Tensor<Backend, 3, Int> = Tensor::from_ints([
            [[1, 0], [2, 0], [3, 0], [4, 0], [5, 0], [6, 0]],
            [[7, 0], [8, 0], [9, 0], [10, 0], [11, 0], [12, 0]],
            [[13, 0], [14, 0], [15, 0], [16, 0], [17, 0], [18, 0]],
            [[19, 0], [20, 0], [21, 0], [22, 0], [23, 0], [24, 0]],
        ]);
        println!("{}", intersperse);
        let intersperse = intersperse.reshape([4, 12]);
        println!("{}", intersperse);
    }

@antimora antimora added the bug Something isn't working label Dec 8, 2023
wcshds commented Dec 10, 2023

@AuruTus Thanks for your investigation! However, this bug is likely not caused by fortran_strides(). It might instead be due to the strange behavior of into_shape() in ndarray: see rust-ndarray/ndarray#1309 (comment) and rust-ndarray/ndarray#1310 (comment).

For burn, I believe using to_shape() instead of into_shape() in the NdArray backend's reshape() method should resolve this issue.

AuruTus commented Dec 11, 2023

> @AuruTus Thanks for your investigation! However, this bug is not likely caused by fortran_strides(). It might be due to strange behavior of into_shape() in ndarray. rust-ndarray/ndarray#1309 (comment), rust-ndarray/ndarray#1310 (comment).
>
> For burn, I believe using to_shape() instead of into_shape() for NdArray backend's reshape() method should resolve this issue.

Hi @wcshds, thanks for the reply. Now I understand why this weird behavior happens.

into_shape(...) preserves the tensor's memory layout. In ndarray's append(...), which builds the new tensor, a column-major layout is chosen when the append axis is the outermost one; producing those strides is fortran_strides(...)'s job. This is also why to_shape(...) is the recommended way to reshape a tensor into a given layout.

So fortran_strides(...) itself has no bug; it is the default layout ndarray picks for the new empty tensor that may be a bit controversial. I'll investigate more later, since to_shape(...) has a different return type that cannot be used directly in place of into_shape(...), and simply swapping it in would only hide the unexpected behavior.
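To see concretely how a layout-preserving reshape produces the buggy rows shown above, here is a plain-Rust sketch (no ndarray involved): it stores the stacked [4, 6, 2] data in column-major order and then reads a row of a [4, 12] view that keeps the column-major strides [1, 4], which is what reinterpreting the shape without reordering memory amounts to:

```rust
// Build the stacked [4, 6, 2] data (values 1..=24 plus a zero plane)
// in column-major order (strides [1, 4, 24]), then read one row of a
// [4, 12] view that keeps column-major strides [1, 4].
fn buggy_row(row: usize) -> Vec<i64> {
    let (d0, d1) = (4usize, 6);
    let mut buf = vec![0i64; d0 * d1 * 2];
    for i in 0..d0 {
        for j in 0..d1 {
            // Column-major offset of (i, j, k = 0); the k = 1 plane stays 0.
            buf[i + j * d0] = (i * d1 + j + 1) as i64;
        }
    }
    // Row `row` of the layout-preserving [4, 12] reshape.
    (0..12).map(|j| buf[row + j * d0]).collect()
}

fn main() {
    // Prints [1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0] -- the buggy row.
    println!("{:?}", buggy_row(0));
    // A logical row-major reshape of the same data would instead give
    // the interleaved [1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0].
}
```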

AuruTus commented Dec 11, 2023

I investigated the code again and found that we actually do have a row-major layout check in the reshape! macro. However, it also checks the reversed axes, which lets a column-major layout pass as well. If that reversed check is removed, reshape! goes to the reshape(...) arm and the correct result is produced.

I ran cargo test under burn-ndarray and triggered a test action on my own repo, and all tests passed. I think it's fine to remove the reversed-axes check unless there's a specific need for it.
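A strict row-major check, as suggested above, would compare the strides against the C-contiguous strides only, without also accepting the reversed (column-major) ones. A hypothetical sketch of such a check (not the actual reshape! macro code):

```rust
// Row-major (C-contiguous) strides for a given shape.
fn c_strides(shape: &[usize]) -> Vec<usize> {
    let mut s = vec![1usize; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        s[i] = s[i + 1] * shape[i + 1];
    }
    s
}

// Accept only arrays whose strides are exactly the row-major ones;
// a column-major [4, 6, 2] array (strides [1, 4, 24]) must fail.
fn is_row_major(shape: &[usize], strides: &[usize]) -> bool {
    strides == c_strides(shape).as_slice()
}

fn main() {
    let shape = [4usize, 6, 2];
    println!("{}", is_row_major(&shape, &[12, 2, 1])); // true
    println!("{}", is_row_major(&shape, &[1, 4, 24])); // false
}
```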

cc @antimora @nathanielsimard


3 participants