forked from rust-ndarray/ndarray
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsort-axis.rs
140 lines (125 loc) · 3.63 KB
/
sort-axis.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
use ndarray::prelude::*;
use ndarray::{Data, RemoveAxis, Zip};
use std::cmp::Ordering;
use std::ptr::copy_nonoverlapping;
/// A rearrangement of the indices `0..n` of an array axis.
///
/// Type invariant: every index in `0..indices.len()` appears exactly once.
#[derive(Clone, Debug)]
pub struct Permutation {
    indices: Vec<usize>,
}

impl Permutation {
    /// Wraps `v` in a `Permutation` after validating the type invariant.
    ///
    /// Returns `Err(())` if `v` contains an index that is `>= v.len()` or
    /// contains the same index more than once.
    pub fn from_indices(v: Vec<usize>) -> Result<Self, ()> {
        let candidate = Permutation { indices: v };
        if !candidate.correct() {
            return Err(());
        }
        Ok(candidate)
    }

    /// Returns `true` when `indices` holds each value of `0..indices.len()`
    /// exactly once.
    fn correct(&self) -> bool {
        // One flag per possible index; a flag flips to `true` the first time
        // its index is encountered.
        let mut visited = vec![false; self.indices.len()];
        self.indices.iter().all(|&idx| {
            visited
                .get_mut(idx)
                // `None` means the index is out of range; a slot that was
                // already `true` means the index is a duplicate.
                .map_or(false, |slot| !std::mem::replace(slot, true))
        })
    }
}
/// Extension trait for computing index permutations along one axis of an array.
///
/// The methods only read the array and hand back a `Permutation`; they do not
/// rearrange any array data themselves.
pub trait SortArray {
/// Returns the identity permutation for `axis`, i.e. the indices of that
/// axis in their existing order.
///
/// ***Panics*** if `axis` is out of bounds.
fn identity(&self, axis: Axis) -> Permutation;
/// Returns a permutation of the indices along `axis`, ordered by `less_than`.
///
/// `less_than(i, j)` is called with two indices along `axis` and should
/// return `true` when index `i` must be ordered before index `j`. It may be
/// called in both argument orders for the same pair of indices.
fn sort_axis_by<F>(&self, axis: Axis, less_than: F) -> Permutation
where
F: FnMut(usize, usize) -> bool;
}
/// Extension trait for rearranging an owned array along one axis according to
/// a `Permutation`.
pub trait PermuteArray {
/// Element type of the array being permuted.
type Elem;
/// Dimension type of the array being permuted.
type Dim;
/// Consumes the array and returns a new array of the same element and
/// dimension types, rearranged along `axis` as described by `perm`.
///
/// The implementation for `Array` in this file panics if the length of
/// `axis` differs from the number of indices stored in `perm`.
///
/// NOTE(review): the `Self::Elem: Clone` bound does not appear to be needed
/// by the implementation visible in this file - confirm before relaxing it.
fn permute_axis(self, axis: Axis, perm: &Permutation) -> Array<Self::Elem, Self::Dim>
where
Self::Elem: Clone,
Self::Dim: RemoveAxis;
}
impl<A, S, D> SortArray for ArrayBase<S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    /// Builds the permutation `0, 1, ..., len_of(axis) - 1`.
    ///
    /// ***Panics*** if `axis` is out of bounds.
    fn identity(&self, axis: Axis) -> Permutation {
        let indices: Vec<usize> = (0..self.len_of(axis)).collect();
        Permutation { indices }
    }

    /// Sorts the indices of `axis` with the strict "less than" predicate
    /// `less_than` and returns them as a `Permutation`.
    ///
    /// The predicate is turned into a three-way comparison: `less_than(a, b)`
    /// is tried first, and `less_than(b, a)` is only evaluated when the first
    /// call returned `false`. Indices for which neither call holds compare as
    /// equal; `sort_by` is a stable sort, so they keep their relative order.
    fn sort_axis_by<F>(&self, axis: Axis, mut less_than: F) -> Permutation
    where
        F: FnMut(usize, usize) -> bool,
    {
        let mut order = self.identity(axis);
        order.indices.sort_by(move |&lhs, &rhs| {
            if less_than(lhs, rhs) {
                return Ordering::Less;
            }
            if less_than(rhs, lhs) {
                Ordering::Greater
            } else {
                Ordering::Equal
            }
        });
        order
    }
}
impl<A, D> PermuteArray for Array<A, D>
where
    D: Dimension,
{
    type Elem = A;
    type Dim = D;

    /// Consumes `self` and returns an array of the same shape whose `i`-th
    /// subview along `axis` is subview `perm.indices[i]` of `self`.
    ///
    /// With a permutation produced by `sort_axis_by`, whose indices are listed
    /// in sorted order, this yields the array sorted along `axis`.
    ///
    /// Elements are moved bitwise into the result; they are neither cloned nor
    /// dropped.
    ///
    /// ***Panics*** if `axis` is out of bounds or if the length of `axis`
    /// differs from the number of indices in `perm`.
    fn permute_axis(self, axis: Axis, perm: &Permutation) -> Array<A, D>
    where
        D: RemoveAxis,
    {
        let axis_len = self.len_of(axis);
        assert_eq!(axis_len, perm.indices.len());
        debug_assert!(perm.correct());

        let mut v = Vec::with_capacity(self.len());
        let mut result;

        // panic-critical begin: we must not panic
        //
        // SAFETY: `result` is backed by a buffer whose elements are
        // uninitialized until the loop below has written them. Because `perm`
        // holds every index of `axis` exactly once (type invariant of
        // `Permutation`) and has `axis_len` entries (asserted above), each
        // subview of `result` is written exactly once and each element of
        // `self` is read exactly once. A panic inside this region would drop
        // `result` while it still holds uninitialized elements.
        unsafe {
            v.set_len(self.len());
            result = Array::from_shape_vec_unchecked(self.dim(), v);
            // Destination `i` receives source `perm_i`. Writing to
            // destination `perm_i` from source `i` instead would apply the
            // inverse permutation and would not sort the axis.
            for (i, &perm_i) in perm.indices.iter().enumerate() {
                Zip::from(result.index_axis_mut(axis, i))
                    .and(self.index_axis(axis, perm_i))
                    .apply(|to, from| copy_nonoverlapping(from, to, 1));
            }
            // Ownership of every element now lies with `result`. Forget the
            // moved elements of the old array, but still free its allocation.
            let mut old_storage = self.into_raw_vec();
            old_storage.set_len(0);
            // old_storage drops empty
        }
        // panic-critical end
        result
    }
}
/// Demonstrates sorting the rows of an 8x8 array in descending order of their
/// first column, then reusing the same permutation on another array.
fn main() {
    let a = Array::linspace(0., 63., 64).into_shape((8, 8)).unwrap();
    let strings = a.map(|x| x.to_string());

    // The comparator reads `a[[i, 0]]`, so `i` and `j` are row indices and the
    // permutation has to be built over axis 0. Building it over axis 1 only
    // works while the array happens to be square.
    let perm = a.sort_axis_by(Axis(0), |i, j| a[[i, 0]] > a[[j, 0]]);
    println!("{:?}", perm);

    // Reorder the rows of `a` with the permutation computed above.
    let b = a.permute_axis(Axis(0), &perm);
    println!("{:?}", b);

    println!("{:?}", strings);
    // The same permutation applied to the columns of the string array; both
    // axes have length 8, so the permutation fits either one.
    let c = strings.permute_axis(Axis(1), &perm);
    println!("{:?}", c);
}