Skip to content

Commit

Permalink
Add multi_slice_* methods (supports flat tuples only) (#717)
Browse files Browse the repository at this point in the history
* Add multi_slice_* methods

* Limit MultiSlice impls to flat tuples

* Fix formatting

* Remove unnecessary dead_code annotation

* Fix typo in docs

* Forward multi_slice_move impl for owned tuples

* Avoid final clone in multi_slice_move impl

* Remove impl of MultiSlice for tuples of SliceInfo

* Fix docs for MultiSlice::multi_slice_move

* Add more multi_slice_mut tests
  • Loading branch information
jturner314 committed Oct 15, 2019
1 parent 792e17c commit 78846da
Show file tree
Hide file tree
Showing 6 changed files with 273 additions and 2 deletions.
1 change: 0 additions & 1 deletion src/dimension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,6 @@ fn slice_min_max(axis_len: usize, slice: Slice) -> Option<(usize, usize)> {
}

/// Returns `true` iff the slices intersect.
#[allow(dead_code)]
pub fn slices_intersect<D: Dimension>(
dim: &D,
indices1: &D::SliceArg,
Expand Down
34 changes: 34 additions & 0 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use crate::iter::{
AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut, ExactChunks, ExactChunksMut,
IndexedIter, IndexedIterMut, Iter, IterMut, Lanes, LanesMut, Windows,
};
use crate::slice::MultiSlice;
use crate::stacking::stack;
use crate::{NdIndex, Slice, SliceInfo, SliceOrIndex};

Expand Down Expand Up @@ -350,6 +351,39 @@ where
self.view_mut().slice_move(info)
}

/// Return multiple disjoint, sliced, mutable views of the array.
///
/// See [*Slicing*](#slicing) for full documentation.
/// See also [`SliceInfo`] and [`D::SliceArg`].
///
/// [`SliceInfo`]: struct.SliceInfo.html
/// [`D::SliceArg`]: trait.Dimension.html#associatedtype.SliceArg
///
/// **Panics** if any of the following occur:
///
/// * if any of the views would intersect (i.e. if any element would appear in multiple slices)
/// * if an index is out of bounds or step size is zero
/// * if `D` is `IxDyn` and `info` does not match the number of array axes
///
/// # Example
///
/// ```
/// use ndarray::{arr2, s};
///
/// let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]);
/// let (mut edges, mut middle) = a.multi_slice_mut((s![.., ..;2], s![.., 1]));
/// edges.fill(1);
/// middle.fill(0);
/// assert_eq!(a, arr2(&[[1, 0, 1], [1, 0, 1]]));
/// ```
pub fn multi_slice_mut<'a, M>(&'a mut self, info: M) -> M::Output
where
M: MultiSlice<'a, A, D>,
S: DataMut,
{
info.multi_slice_move(self.view_mut())
}

/// Slice the array, possibly changing the number of dimensions.
///
/// See [*Slicing*](#slicing) for full documentation.
Expand Down
26 changes: 26 additions & 0 deletions src/impl_views/splitting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
// except according to those terms.

use crate::imp_prelude::*;
use crate::slice::MultiSlice;

/// Methods for read-only array views.
impl<'a, A, D> ArrayView<'a, A, D>
Expand Down Expand Up @@ -109,4 +110,29 @@ where
(left.deref_into_view_mut(), right.deref_into_view_mut())
}
}

/// Split the view into multiple disjoint slices.
///
/// This is similar to [`.multi_slice_mut()`], but `.multi_slice_move()`
/// consumes `self` and produces views with lifetimes matching that of
/// `self`.
///
/// See [*Slicing*](#slicing) for full documentation.
/// See also [`SliceInfo`] and [`D::SliceArg`].
///
/// [`.multi_slice_mut()`]: struct.ArrayBase.html#method.multi_slice_mut
/// [`SliceInfo`]: struct.SliceInfo.html
/// [`D::SliceArg`]: trait.Dimension.html#associatedtype.SliceArg
///
/// **Panics** if any of the following occur:
///
/// * if any of the views would intersect (i.e. if any element would appear in multiple slices)
/// * if an index is out of bounds or step size is zero
/// * if `D` is `IxDyn` and `info` does not match the number of array axes
pub fn multi_slice_move<M>(self, info: M) -> M::Output
where
M: MultiSlice<'a, A, D>,
{
info.multi_slice_move(self)
}
}
21 changes: 21 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,13 @@ pub type Ixs = isize;
/// [`.slice_move()`]: #method.slice_move
/// [`.slice_collapse()`]: #method.slice_collapse
///
/// It's possible to take multiple simultaneous *mutable* slices with
/// [`.multi_slice_mut()`] or (for [`ArrayViewMut`] only)
/// [`.multi_slice_move()`].
///
/// [`.multi_slice_mut()`]: #method.multi_slice_mut
/// [`.multi_slice_move()`]: type.ArrayViewMut.html#method.multi_slice_move
///
/// ```
/// extern crate ndarray;
///
Expand Down Expand Up @@ -525,6 +532,20 @@ pub type Ixs = isize;
/// [12, 11, 10]]);
/// assert_eq!(f, g);
/// assert_eq!(f.shape(), &[2, 3]);
///
/// // Let's take two disjoint, mutable slices of a matrix with
/// //
/// // - One containing all the even-index columns in the matrix
/// // - One containing all the odd-index columns in the matrix
/// let mut h = arr2(&[[0, 1, 2, 3],
/// [4, 5, 6, 7]]);
/// let (s0, s1) = h.multi_slice_mut((s![.., ..;2], s![.., 1..;2]));
/// let i = arr2(&[[0, 2],
/// [4, 6]]);
/// let j = arr2(&[[1, 3],
/// [5, 7]]);
/// assert_eq!(s0, i);
/// assert_eq!(s1, j);
/// }
/// ```
///
Expand Down
103 changes: 102 additions & 1 deletion src/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
use crate::dimension::slices_intersect;
use crate::error::{ErrorKind, ShapeError};
use crate::Dimension;
use crate::{ArrayViewMut, Dimension};
use std::fmt;
use std::marker::PhantomData;
use std::ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
Expand Down Expand Up @@ -629,3 +630,103 @@ macro_rules! s(
&*&$crate::s![@parse ::std::marker::PhantomData::<$crate::Ix0>, [] $($t)*]
};
);

/// Slicing information describing multiple mutable, disjoint slices.
///
/// It's unfortunate that we need `'a` and `A` to be parameters of the trait,
/// but they're necessary until Rust supports generic associated types.
pub trait MultiSlice<'a, A, D>
where
A: 'a,
D: Dimension,
{
/// The type of the slices created by `.multi_slice_move()`.
type Output;

/// Split the view into multiple disjoint slices.
///
/// **Panics** if performing any individual slice panics or if the slices
/// are not disjoint (i.e. if they intersect).
fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output;
}

impl<'a, A, D> MultiSlice<'a, A, D> for ()
where
A: 'a,
D: Dimension,
{
type Output = ();

fn multi_slice_move(&self, _view: ArrayViewMut<'a, A, D>) -> Self::Output {}
}

impl<'a, A, D, Do0> MultiSlice<'a, A, D> for (&SliceInfo<D::SliceArg, Do0>,)
where
A: 'a,
D: Dimension,
Do0: Dimension,
{
type Output = (ArrayViewMut<'a, A, Do0>,);

fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output {
(view.slice_move(self.0),)
}
}

macro_rules! impl_multislice_tuple {
([$($but_last:ident)*] $last:ident) => {
impl_multislice_tuple!(@def_impl ($($but_last,)* $last,), [$($but_last)*] $last);
};
(@def_impl ($($all:ident,)*), [$($but_last:ident)*] $last:ident) => {
impl<'a, A, D, $($all,)*> MultiSlice<'a, A, D> for ($(&SliceInfo<D::SliceArg, $all>,)*)
where
A: 'a,
D: Dimension,
$($all: Dimension,)*
{
type Output = ($(ArrayViewMut<'a, A, $all>,)*);

fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output {
#[allow(non_snake_case)]
let ($($all,)*) = self;

let shape = view.raw_dim();
assert!(!impl_multislice_tuple!(@intersects_self &shape, ($($all,)*)));

let raw_view = view.into_raw_view_mut();
unsafe {
(
$(raw_view.clone().slice_move($but_last).deref_into_view_mut(),)*
raw_view.slice_move($last).deref_into_view_mut(),
)
}
}
}
};
(@intersects_self $shape:expr, ($head:expr,)) => {
false
};
(@intersects_self $shape:expr, ($head:expr, $($tail:expr,)*)) => {
$(slices_intersect($shape, $head, $tail)) ||*
|| impl_multislice_tuple!(@intersects_self $shape, ($($tail,)*))
};
}

impl_multislice_tuple!([Do0] Do1);
impl_multislice_tuple!([Do0 Do1] Do2);
impl_multislice_tuple!([Do0 Do1 Do2] Do3);
impl_multislice_tuple!([Do0 Do1 Do2 Do3] Do4);
impl_multislice_tuple!([Do0 Do1 Do2 Do3 Do4] Do5);

impl<'a, A, D, T> MultiSlice<'a, A, D> for &T
where
A: 'a,
D: Dimension,
T: MultiSlice<'a, A, D>,
{
type Output = T::Output;

fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output {
T::multi_slice_move(self, view)
}
}
90 changes: 90 additions & 0 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,20 @@ use ndarray::{arr3, rcarr2};
use ndarray::{Slice, SliceInfo, SliceOrIndex};
use std::iter::FromIterator;

macro_rules! assert_panics {
($body:expr) => {
if let Ok(v) = ::std::panic::catch_unwind(|| $body) {
panic!("assertion failed: should_panic; \
non-panicking result: {:?}", v);
}
};
($body:expr, $($arg:tt)*) => {
if let Ok(_) = ::std::panic::catch_unwind(|| $body) {
panic!($($arg)*);
}
};
}

#[test]
fn test_matmul_arcarray() {
let mut A = ArcArray::<usize, _>::zeros((2, 3));
Expand Down Expand Up @@ -328,6 +342,82 @@ fn test_slice_collapse_with_indices() {
assert_eq!(vi, Array3::from_elem((1, 1, 1), elem));
}

#[test]
fn test_multislice() {
macro_rules! do_test {
($arr:expr, $($s:expr),*) => {
{
let arr = $arr;
let copy = arr.clone();
assert_eq!(
arr.multi_slice_mut(($($s,)*)),
($(copy.clone().slice_mut($s),)*)
);
}
};
}

let mut arr = Array1::from_iter(0..48).into_shape((8, 6)).unwrap();

assert_eq!(
(arr.clone().view_mut(),),
arr.multi_slice_mut((s![.., ..],)),
);
assert_eq!(arr.multi_slice_mut(()), ());
do_test!(&mut arr, s![0, ..]);
do_test!(&mut arr, s![0, ..], s![1, ..]);
do_test!(&mut arr, s![0, ..], s![-1, ..]);
do_test!(&mut arr, s![0, ..], s![1.., ..]);
do_test!(&mut arr, s![1, ..], s![..;2, ..]);
do_test!(&mut arr, s![..2, ..], s![2.., ..]);
do_test!(&mut arr, s![1..;2, ..], s![..;2, ..]);
do_test!(&mut arr, s![..;-2, ..], s![..;2, ..]);
do_test!(&mut arr, s![..;12, ..], s![3..;3, ..]);
do_test!(&mut arr, s![3, ..], s![..-1;-2, ..]);
do_test!(&mut arr, s![0, ..], s![1, ..], s![2, ..]);
do_test!(&mut arr, s![0, ..], s![1, ..], s![2, ..], s![3, ..]);
}

#[test]
fn test_multislice_intersecting() {
assert_panics!({
let mut arr = Array2::<u8>::zeros((8, 6));
arr.multi_slice_mut((s![3, ..], s![3, ..]));
});
assert_panics!({
let mut arr = Array2::<u8>::zeros((8, 6));
arr.multi_slice_mut((s![3, ..], s![3.., ..]));
});
assert_panics!({
let mut arr = Array2::<u8>::zeros((8, 6));
arr.multi_slice_mut((s![3, ..], s![..;3, ..]));
});
assert_panics!({
let mut arr = Array2::<u8>::zeros((8, 6));
arr.multi_slice_mut((s![..;6, ..], s![3..;3, ..]));
});
assert_panics!({
let mut arr = Array2::<u8>::zeros((8, 6));
arr.multi_slice_mut((s![2, ..], s![..-1;-2, ..]));
});
assert_panics!({
let mut arr = Array2::<u8>::zeros((8, 6));
arr.multi_slice_mut((s![4, ..], s![3, ..], s![3, ..]));
});
assert_panics!({
let mut arr = Array2::<u8>::zeros((8, 6));
arr.multi_slice_mut((s![3, ..], s![4, ..], s![3, ..]));
});
assert_panics!({
let mut arr = Array2::<u8>::zeros((8, 6));
arr.multi_slice_mut((s![3, ..], s![3, ..], s![4, ..]));
});
assert_panics!({
let mut arr = Array2::<u8>::zeros((8, 6));
arr.multi_slice_mut((s![3, ..], s![3, ..], s![4, ..], s![3, ..]));
});
}

#[should_panic]
#[test]
fn index_out_of_bounds() {
Expand Down

0 comments on commit 78846da

Please sign in to comment.