Skip to content

Commit

Permalink
Merge pull request #915 from jturner314/make-axisdescription-non-tuple
Browse files Browse the repository at this point in the history
Make AxisDescription a non-tuple struct
  • Loading branch information
bluss committed Feb 11, 2021
2 parents a6fe82f + c264ecf commit 13c1999
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 34 deletions.
14 changes: 7 additions & 7 deletions examples/axis_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,21 @@ where
{
println!("Regularize:\n{:?}", a);
// reverse all neg axes
while let Some(ax) = a.axes().find(|ax| ax.stride() <= 0) {
if ax.stride() == 0 {
while let Some(ax) = a.axes().find(|ax| ax.stride <= 0) {
if ax.stride == 0 {
return Err(());
}
// reverse ax
println!("Reverse {:?}", ax.axis());
a.invert_axis(ax.axis());
println!("Reverse {:?}", ax.axis);
a.invert_axis(ax.axis);
}

// sort by least stride
let mut i = 0;
let n = a.ndim();
while let Some(ax) = a.axes().rev().skip(i).min_by_key(|ax| ax.stride().abs()) {
a.swap_axes(n - 1 - i, ax.axis().index());
println!("Swap {:?} <=> {}", ax.axis(), n - 1 - i);
while let Some(ax) = a.axes().rev().skip(i).min_by_key(|ax| ax.stride.abs()) {
a.swap_axes(n - 1 - i, ax.axis.index());
println!("Swap {:?} <=> {}", ax.axis, n - 1 - i);
i += 1;
}

Expand Down
41 changes: 26 additions & 15 deletions src/dimension/axes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ pub struct Axes<'a, D> {

/// Description of the axis, its length and its stride.
#[derive(Debug)]
pub struct AxisDescription(pub Axis, pub Ix, pub Ixs);
pub struct AxisDescription {
pub axis: Axis,
pub len: usize,
pub stride: isize,
}

copy_and_clone!(AxisDescription);

Expand All @@ -51,19 +55,22 @@ copy_and_clone!(AxisDescription);
#[allow(clippy::len_without_is_empty)]
impl AxisDescription {
/// Return axis
#[deprecated(note = "Use .axis field instead", since = "0.15.0")]
#[inline(always)]
pub fn axis(self) -> Axis {
self.0
self.axis
}
/// Return length
#[deprecated(note = "Use .len field instead", since = "0.15.0")]
#[inline(always)]
pub fn len(self) -> Ix {
self.1
self.len
}
/// Return stride
#[deprecated(note = "Use .stride field instead", since = "0.15.0")]
#[inline(always)]
pub fn stride(self) -> Ixs {
self.2
self.stride
}
}

Expand All @@ -79,11 +86,11 @@ where
fn next(&mut self) -> Option<Self::Item> {
if self.start < self.end {
let i = self.start.post_inc();
Some(AxisDescription(
Axis(i),
self.dim[i],
self.strides[i] as Ixs,
))
Some(AxisDescription {
axis: Axis(i),
len: self.dim[i],
stride: self.strides[i] as Ixs,
})
} else {
None
}
Expand All @@ -94,7 +101,11 @@ where
F: FnMut(B, AxisDescription) -> B,
{
(self.start..self.end)
.map(move |i| AxisDescription(Axis(i), self.dim[i], self.strides[i] as isize))
.map(move |i| AxisDescription {
axis: Axis(i),
len: self.dim[i],
stride: self.strides[i] as isize,
})
.fold(init, f)
}

Expand All @@ -111,11 +122,11 @@ where
fn next_back(&mut self) -> Option<Self::Item> {
if self.start < self.end {
let i = self.end.pre_dec();
Some(AxisDescription(
Axis(i),
self.dim[i],
self.strides[i] as Ixs,
))
Some(AxisDescription {
axis: Axis(i),
len: self.dim[i],
stride: self.strides[i] as Ixs,
})
} else {
None
}
Expand Down
10 changes: 5 additions & 5 deletions src/dimension/dimension_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,8 @@ pub trait Dimension:
};
axes_of(self, strides)
.rev()
.min_by_key(|ax| ax.stride().abs())
.map_or(Axis(n - 1), |ax| ax.axis())
.min_by_key(|ax| ax.stride.abs())
.map_or(Axis(n - 1), |ax| ax.axis)
}

/// Compute the maximum stride axis (absolute value), under the constraint
Expand All @@ -346,9 +346,9 @@ pub trait Dimension:
_ => {}
}
axes_of(self, strides)
.filter(|ax| ax.len() > 1)
.max_by_key(|ax| ax.stride().abs())
.map_or(Axis(0), |ax| ax.axis())
.filter(|ax| ax.len > 1)
.max_by_key(|ax| ax.stride.abs())
.map_or(Axis(0), |ax| ax.axis)
}

/// Convert the dimensional into a dynamic dimensional (IxDyn).
Expand Down
14 changes: 7 additions & 7 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -557,16 +557,16 @@ where
where
F: FnMut(AxisDescription) -> Slice,
{
(0..self.ndim()).for_each(|ax| {
for ax in 0..self.ndim() {
self.slice_axis_inplace(
Axis(ax),
f(AxisDescription(
Axis(ax),
self.dim[ax],
self.strides[ax] as isize,
)),
f(AxisDescription {
axis: Axis(ax),
len: self.dim[ax],
stride: self.strides[ax] as isize,
}),
)
})
}
}

/// Return a reference to the element at `index`, or return `None`
Expand Down

0 comments on commit 13c1999

Please sign in to comment.