From c27626a6a755be160907c4fab747f36ad6f933f9 Mon Sep 17 00:00:00 2001 From: Dave McDaniel Date: Mon, 8 Feb 2021 22:01:13 -0500 Subject: [PATCH] add test to sort-axis example to catch issue of swapped perm_i and i also replaced outer iteration with zip directly since it takes slices. --- examples/sort-axis.rs | 80 ++++++++++++++++++++++++++++++++++++++----- scripts/all-tests.sh | 1 + 2 files changed, 72 insertions(+), 9 deletions(-) diff --git a/examples/sort-axis.rs b/examples/sort-axis.rs index 09410e819..7ae67fb07 100644 --- a/examples/sort-axis.rs +++ b/examples/sort-axis.rs @@ -102,21 +102,27 @@ where assert_eq!(axis_len, perm.indices.len()); debug_assert!(perm.correct()); + if self.is_empty() { + return self; + } + let mut result = Array::uninit(self.dim()); unsafe { // logically move ownership of all elements from self into result // the result realizes this ownership at .assume_init() further down let mut moved_elements = 0; - for i in 0..axis_len { - let perm_i = perm.indices[i]; - Zip::from(result.index_axis_mut(axis, perm_i)) - .and(self.index_axis(axis, i)) - .for_each(|to, from| { - copy_nonoverlapping(from, to.as_mut_ptr(), 1); - moved_elements += 1; - }); - } + Zip::from(&perm.indices) + .and(result.axis_iter_mut(axis)) + .for_each(|&perm_i, result_pane| { + // possible improvement: use unchecked indexing for `index_axis` + Zip::from(result_pane) + .and(self.index_axis(axis, perm_i)) + .for_each(|to, from| { + copy_nonoverlapping(from, to.as_mut_ptr(), 1); + moved_elements += 1; + }); + }); debug_assert_eq!(result.len(), moved_elements); // panic-critical begin: we must not panic // forget moved array elements but not its vec @@ -129,6 +135,7 @@ where } } } + #[cfg(feature = "std")] fn main() { let a = Array::linspace(0., 63., 64).into_shape((8, 8)).unwrap(); @@ -143,5 +150,60 @@ fn main() { let c = strings.permute_axis(Axis(1), &perm); println!("{:?}", c); } + #[cfg(not(feature = "std"))] fn main() {} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_permute_axis() { + let a = array![ + [107998.96, 1.], + [107999.08, 2.], + [107999.20, 3.], + [108000.33, 4.], + [107999.45, 5.], + [107999.57, 6.], + [108010.69, 7.], + [107999.81, 8.], + [107999.94, 9.], + [75600.09, 10.], + [75600.21, 11.], + [75601.33, 12.], + [75600.45, 13.], + [75600.58, 14.], + [109000.70, 15.], + [75600.82, 16.], + [75600.94, 17.], + [75601.06, 18.], + ]; + + let perm = a.sort_axis_by(Axis(0), |i, j| a[[i, 0]] < a[[j, 0]]); + let b = a.permute_axis(Axis(0), &perm); + assert_eq!( + b, + array![ + [75600.09, 10.], + [75600.21, 11.], + [75600.45, 13.], + [75600.58, 14.], + [75600.82, 16.], + [75600.94, 17.], + [75601.06, 18.], + [75601.33, 12.], + [107998.96, 1.], + [107999.08, 2.], + [107999.20, 3.], + [107999.45, 5.], + [107999.57, 6.], + [107999.81, 8.], + [107999.94, 9.], + [108000.33, 4.], + [108010.69, 7.], + [109000.70, 15.], + ] + ); + } +} diff --git a/scripts/all-tests.sh b/scripts/all-tests.sh index 22f1d1f94..61dfc56dc 100755 --- a/scripts/all-tests.sh +++ b/scripts/all-tests.sh @@ -16,5 +16,6 @@ cargo test --manifest-path=ndarray-rand/Cargo.toml --no-default-features --verbo cargo test --manifest-path=ndarray-rand/Cargo.toml --features quickcheck --verbose cargo test --manifest-path=serialization-tests/Cargo.toml --verbose cargo test --manifest-path=blas-tests/Cargo.toml --verbose +cargo test --examples CARGO_TARGET_DIR=target/ cargo test --manifest-path=numeric-tests/Cargo.toml --verbose ([ "$CHANNEL" != "nightly" ] || cargo bench --no-run --verbose --features "$FEATURES")