Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 71 additions & 9 deletions examples/sort-axis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,21 +102,27 @@ where
assert_eq!(axis_len, perm.indices.len());
debug_assert!(perm.correct());

if self.is_empty() {
return self;
}

let mut result = Array::uninit(self.dim());

unsafe {
// logically move ownership of all elements from self into result
// the result realizes this ownership at .assume_init() further down
let mut moved_elements = 0;
for i in 0..axis_len {
let perm_i = perm.indices[i];
Zip::from(result.index_axis_mut(axis, perm_i))
.and(self.index_axis(axis, i))
.for_each(|to, from| {
copy_nonoverlapping(from, to.as_mut_ptr(), 1);
moved_elements += 1;
});
}
Zip::from(&perm.indices)
Copy link
Member

@bluss bluss Feb 9, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tip, we can use slices just fine in Zip.

All that was needed for the fix was to swap perm_i and i, but I saw this change as an improvement.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the tip, I haven't used Zip much but should start. I noticed this change being beyond the index swap.

.and(result.axis_iter_mut(axis))
.for_each(|&perm_i, result_pane| {
// possible improvement: use unchecked indexing for `index_axis`
Zip::from(result_pane)
.and(self.index_axis(axis, perm_i))
.for_each(|to, from| {
copy_nonoverlapping(from, to.as_mut_ptr(), 1);
moved_elements += 1;
});
});
debug_assert_eq!(result.len(), moved_elements);
// panic-critical begin: we must not panic
// forget moved array elements but not its vec
Expand All @@ -129,6 +135,7 @@ where
}
}
}

#[cfg(feature = "std")]
fn main() {
let a = Array::linspace(0., 63., 64).into_shape((8, 8)).unwrap();
Expand All @@ -143,5 +150,60 @@ fn main() {
let c = strings.permute_axis(Axis(1), &perm);
println!("{:?}", c);
}

#[cfg(not(feature = "std"))]
fn main() {}

#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_permute_axis() {
let a = array![
[107998.96, 1.],
[107999.08, 2.],
[107999.20, 3.],
[108000.33, 4.],
[107999.45, 5.],
[107999.57, 6.],
[108010.69, 7.],
[107999.81, 8.],
[107999.94, 9.],
[75600.09, 10.],
[75600.21, 11.],
[75601.33, 12.],
[75600.45, 13.],
[75600.58, 14.],
[109000.70, 15.],
[75600.82, 16.],
[75600.94, 17.],
[75601.06, 18.],
];

let perm = a.sort_axis_by(Axis(0), |i, j| a[[i, 0]] < a[[j, 0]]);
let b = a.permute_axis(Axis(0), &perm);
assert_eq!(
b,
array![
[75600.09, 10.],
[75600.21, 11.],
[75600.45, 13.],
[75600.58, 14.],
[75600.82, 16.],
[75600.94, 17.],
[75601.06, 18.],
[75601.33, 12.],
[107998.96, 1.],
[107999.08, 2.],
[107999.20, 3.],
[107999.45, 5.],
[107999.57, 6.],
[107999.81, 8.],
[107999.94, 9.],
[108000.33, 4.],
[108010.69, 7.],
[109000.70, 15.],
]
);
}
}
1 change: 1 addition & 0 deletions scripts/all-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ cargo test --manifest-path=ndarray-rand/Cargo.toml --no-default-features --verbo
cargo test --manifest-path=ndarray-rand/Cargo.toml --features quickcheck --verbose
cargo test --manifest-path=serialization-tests/Cargo.toml --verbose
cargo test --manifest-path=blas-tests/Cargo.toml --verbose
cargo test --examples
CARGO_TARGET_DIR=target/ cargo test --manifest-path=numeric-tests/Cargo.toml --verbose
([ "$CHANNEL" != "nightly" ] || cargo bench --no-run --verbose --features "$FEATURES")