Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stack and concatenate #735

Closed
wants to merge 13 commits into from
3 changes: 3 additions & 0 deletions src/doc/ndarray_for_numpy_users/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,7 @@
//! `a[:] = 3.` | [`a.fill(3.)`][.fill()] | set all array elements to the same scalar value
//! `a[:] = b` | [`a.assign(&b)`][.assign()] | copy the data from array `b` into array `a`
//! `np.concatenate((a,b), axis=1)` | [`stack![Axis(1), a, b]`][stack!] or [`stack(Axis(1), &[a.view(), b.view()])`][stack()] | concatenate arrays `a` and `b` along axis 1
//! `np.stack((a,b), axis=1)` | [`stack_new_axis![Axis(1), a, b]`][stack_new_axis!] or [`stack_new_axis(Axis(1), vec![a.view(), b.view()])`][stack_new_axis()] | stack arrays `a` and `b` along axis 1
//! `a[:,np.newaxis]` or `np.expand_dims(a, axis=1)` | [`a.insert_axis(Axis(1))`][.insert_axis()] | create an array from `a`, inserting a new axis 1
//! `a.transpose()` or `a.T` | [`a.t()`][.t()] or [`a.reversed_axes()`][.reversed_axes()] | transpose of array `a` (view for `.t()` or by-move for `.reversed_axes()`)
//! `np.diag(a)` | [`a.diag()`][.diag()] | view the diagonal of `a`
Expand Down Expand Up @@ -642,6 +643,8 @@
//! [.shape()]: ../../struct.ArrayBase.html#method.shape
//! [stack!]: ../../macro.stack.html
//! [stack()]: ../../fn.stack.html
//! [stack_new_axis!]: ../../macro.stack_new_axis.html
//! [stack_new_axis()]: ../../fn.stack_new_axis.html
//! [.strides()]: ../../struct.ArrayBase.html#method.strides
//! [.index_axis()]: ../../struct.ArrayBase.html#method.index_axis
//! [.sum_axis()]: ../../struct.ArrayBase.html#method.sum_axis
Expand Down
4 changes: 2 additions & 2 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use crate::iter::{
IndexedIter, IndexedIterMut, Iter, IterMut, Lanes, LanesMut, Windows,
};
use crate::slice::MultiSlice;
use crate::stacking::stack;
use crate::stacking::concatenate;
use crate::{NdIndex, Slice, SliceInfo, SliceOrIndex};

/// # Methods For All Array Types
Expand Down Expand Up @@ -841,7 +841,7 @@ where
dim.set_axis(axis, 0);
unsafe { Array::from_shape_vec_unchecked(dim, vec![]) }
} else {
stack(axis, &subs).unwrap()
concatenate(axis, &subs).unwrap()
}
}

Expand Down
4 changes: 3 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@ use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes, Lane

pub use crate::arraytraits::AsArray;
pub use crate::linalg_traits::{LinalgScalar, NdFloat};
pub use crate::stacking::stack;

#[allow(deprecated)]
pub use crate::stacking::{concatenate, stack, stack_new_axis};

pub use crate::impl_views::IndexLonger;
pub use crate::shape_builder::ShapeBuilder;
Expand Down
24 changes: 12 additions & 12 deletions src/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -544,62 +544,62 @@ impl_slicenextdim_larger!((), Slice);
#[macro_export]
macro_rules! s(
// convert a..b;c into @convert(a..b, c), final item
(@parse $dim:expr, [$($stack:tt)*] $r:expr;$s:expr) => {
(@parse $dim:expr, [$($concatenate:tt)*] $r:expr;$s:expr) => {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This rename of the variable inside the macro is an incidental mistake, and should not be in the PR

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I removed this in my rebase-fix of the branch. No other changes in that branch. :)

match $r {
r => {
let out_dim = $crate::SliceNextDim::next_dim(&r, $dim);
#[allow(unsafe_code)]
unsafe {
$crate::SliceInfo::new_unchecked(
[$($stack)* $crate::s!(@convert r, $s)],
[$($concatenate)* $crate::s!(@convert r, $s)],
out_dim,
)
}
}
}
};
// convert a..b into @convert(a..b), final item
(@parse $dim:expr, [$($stack:tt)*] $r:expr) => {
(@parse $dim:expr, [$($concatenate:tt)*] $r:expr) => {
match $r {
r => {
let out_dim = $crate::SliceNextDim::next_dim(&r, $dim);
#[allow(unsafe_code)]
unsafe {
$crate::SliceInfo::new_unchecked(
[$($stack)* $crate::s!(@convert r)],
[$($concatenate)* $crate::s!(@convert r)],
out_dim,
)
}
}
}
};
// convert a..b;c into @convert(a..b, c), final item, trailing comma
(@parse $dim:expr, [$($stack:tt)*] $r:expr;$s:expr ,) => {
$crate::s![@parse $dim, [$($stack)*] $r;$s]
(@parse $dim:expr, [$($concatenate:tt)*] $r:expr;$s:expr ,) => {
$crate::s![@parse $dim, [$($concatenate)*] $r;$s]
};
// convert a..b into @convert(a..b), final item, trailing comma
(@parse $dim:expr, [$($stack:tt)*] $r:expr ,) => {
$crate::s![@parse $dim, [$($stack)*] $r]
(@parse $dim:expr, [$($concatenate:tt)*] $r:expr ,) => {
$crate::s![@parse $dim, [$($concatenate)*] $r]
};
// convert a..b;c into @convert(a..b, c)
(@parse $dim:expr, [$($stack:tt)*] $r:expr;$s:expr, $($t:tt)*) => {
(@parse $dim:expr, [$($concatenate:tt)*] $r:expr;$s:expr, $($t:tt)*) => {
match $r {
r => {
$crate::s![@parse
$crate::SliceNextDim::next_dim(&r, $dim),
[$($stack)* $crate::s!(@convert r, $s),]
[$($concatenate)* $crate::s!(@convert r, $s),]
$($t)*
]
}
}
};
// convert a..b into @convert(a..b)
(@parse $dim:expr, [$($stack:tt)*] $r:expr, $($t:tt)*) => {
(@parse $dim:expr, [$($concatenate:tt)*] $r:expr, $($t:tt)*) => {
match $r {
r => {
$crate::s![@parse
$crate::SliceNextDim::next_dim(&r, $dim),
[$($stack)* $crate::s!(@convert r),]
[$($concatenate)* $crate::s!(@convert r),]
$($t)*
]
}
Expand Down
153 changes: 147 additions & 6 deletions src/stacking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
use crate::error::{from_kind, ErrorKind, ShapeError};
use crate::imp_prelude::*;

/// Stack arrays along the given axis.
/// Concatenate arrays along the given axis.
///
/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`.
/// (may be made more flexible in the future).<br>
Expand All @@ -29,10 +29,11 @@ use crate::imp_prelude::*;
/// [3., 3.]]))
/// );
/// ```
pub fn stack<'a, A, D>(
axis: Axis,
arrays: &[ArrayView<'a, A, D>],
) -> Result<Array<A, D>, ShapeError>
#[deprecated(
since = "0.13.1",
note = "Please use the `concatenate` function instead"
)]
pub fn stack<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
where
A: Copy,
D: RemoveAxis,
Expand Down Expand Up @@ -76,7 +77,79 @@ where
Ok(res)
}

/// Stack arrays along the given axis.
/// Concatenate arrays along the given axis.
///
/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`.
/// (may be made more flexible in the future).<br>
/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
/// if the result is larger than is possible to represent.
///
/// ```
/// use ndarray::{arr2, Axis, concatenate};
///
/// let a = arr2(&[[2., 2.],
/// [3., 3.]]);
/// assert!(
/// concatenate(Axis(0), &[a.view(), a.view()])
/// == Ok(arr2(&[[2., 2.],
/// [3., 3.],
/// [2., 2.],
/// [3., 3.]]))
/// );
/// ```
#[allow(deprecated)]
pub fn concatenate<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
where
A: Copy,
D: RemoveAxis,
{
stack(axis, arrays)
}

pub fn stack_new_axis<A, D>(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing doc comment :)

axis: Axis,
arrays: Vec<ArrayView<A, D>>,
Copy link
Member

@bluss bluss Apr 18, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stack_new_axis and concatenate should agree on if they take the arrays as a slice or a Vec, they should have the same interface if they can. Avoiding cloning of the array views seems like a minor issue, so it can be decided either way (slice or vec)

) -> Result<Array<A, D::Larger>, ShapeError>
where
A: Copy,
D: Dimension,
D::Larger: RemoveAxis,
{
if arrays.is_empty() {
return Err(from_kind(ErrorKind::Unsupported));
}
let common_dim = arrays[0].raw_dim();
// Avoid panic on `insert_axis` call, return an Err instead of it.
if axis.index() > common_dim.ndim() {
return Err(from_kind(ErrorKind::OutOfBounds));
}
let mut res_dim = common_dim.insert_axis(axis);

if arrays.iter().any(|a| a.raw_dim() != common_dim) {
return Err(from_kind(ErrorKind::IncompatibleShape));
}

res_dim.set_axis(axis, arrays.len());

// we can safely use uninitialized values here because they are Copy
// and we will only ever write to them
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good, it follows what we do in stack. We'll revise handling of uninit data, but it's not in scope for this PR. (#796 )

let size = res_dim.size();
let mut v = Vec::with_capacity(size);
unsafe {
v.set_len(size);
}
let mut res = Array::from_shape_vec(res_dim, v)?;

res.axis_iter_mut(axis)
.zip(arrays.into_iter())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this zip seems to show that we would be fine with just using a slice of ArrayView as input, we don't need to use them by value

.for_each(|(mut assign_view, array)| {
assign_view.assign(&array);
});

Ok(res)
}

/// Concatenate arrays along the given axis.
///
/// Uses the [`stack`][1] function, calling `ArrayView::from(&a)` on each
/// argument `a`.
Expand Down Expand Up @@ -109,3 +182,71 @@ macro_rules! stack {
$crate::stack($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The stack macro should also be deprecated

}
}

/// Concatenate arrays along the given axis.
///
/// Uses the [`concatenate`][1] function, calling `ArrayView::from(&a)` on each
/// argument `a`.
///
/// [1]: fn.concatenate.html
///
/// ***Panics*** if the `concatenate` function would return an error.
///
/// ```
/// extern crate ndarray;
///
/// use ndarray::{arr2, concatenate, Axis};
///
/// # fn main() {
///
/// let a = arr2(&[[2., 2.],
/// [3., 3.]]);
/// assert!(
/// concatenate![Axis(0), a, a]
/// == arr2(&[[2., 2.],
/// [3., 3.],
/// [2., 2.],
/// [3., 3.]])
/// );
/// # }
/// ```
#[macro_export]
macro_rules! concatenate {
($axis:expr, $( $array:expr ),+ ) => {
$crate::concatenate($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
}
}

/// Stack arrays along the new axis.
///
/// Uses the [`stack_new_axis`][1] function, calling `ArrayView::from(&a)` on each
/// argument `a`.
///
/// [1]: fn.stack_new_axis.html
///
/// ***Panics*** if the `stack` function would return an error.
///
/// ```
/// extern crate ndarray;
///
/// use ndarray::{arr2, arr3, stack_new_axis, Axis};
///
/// # fn main() {
///
/// let a = arr2(&[[2., 2.],
/// [3., 3.]]);
/// assert!(
/// stack_new_axis![Axis(0), a, a]
/// == arr3(&[[[2., 2.],
/// [3., 3.]],
/// [[2., 2.],
/// [3., 3.]]])
/// );
/// # }
/// ```
#[macro_export]
macro_rules! stack_new_axis {
($axis:expr, $( $array:expr ),+ ) => {
$crate::stack_new_axis($axis, vec![ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
}
}
45 changes: 43 additions & 2 deletions tests/stacking.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use ndarray::{arr2, aview1, stack, Array2, Axis, ErrorKind};
#![allow(deprecated)]

use ndarray::{arr2, arr3, aview1, concatenate, stack, Array2, Axis, ErrorKind, Ix1};

#[test]
fn stacking() {
fn concatenating() {
let a = arr2(&[[2., 2.], [3., 3.]]);
let b = ndarray::stack(Axis(0), &[a.view(), a.view()]).unwrap();
assert_eq!(b, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.]]));
Expand All @@ -23,4 +25,43 @@ fn stacking() {

let res: Result<Array2<f64>, _> = ndarray::stack(Axis(0), &[]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);

let a = arr2(&[[2., 2.], [3., 3.]]);
let b = ndarray::concatenate(Axis(0), &[a.view(), a.view()]).unwrap();
assert_eq!(b, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.]]));

let c = concatenate![Axis(0), a, b];
assert_eq!(
c,
arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.], [2., 2.], [3., 3.]])
);

let d = concatenate![Axis(0), a.row(0), &[9., 9.]];
assert_eq!(d, aview1(&[2., 2., 9., 9.]));

let res = ndarray::concatenate(Axis(1), &[a.view(), c.view()]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape);

let res = ndarray::concatenate(Axis(2), &[a.view(), c.view()]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds);

let res: Result<Array2<f64>, _> = ndarray::concatenate(Axis(0), &[]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);
}

#[test]
fn stacking() {
let a = arr2(&[[2., 2.], [3., 3.]]);
let b = ndarray::stack_new_axis(Axis(0), vec![a.view(), a.view()]).unwrap();
assert_eq!(b, arr3(&[[[2., 2.], [3., 3.]], [[2., 2.], [3., 3.]]]));

let c = arr2(&[[3., 2., 3.], [2., 3., 2.]]);
let res = ndarray::stack_new_axis(Axis(1), vec![a.view(), c.view()]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape);

let res = ndarray::stack_new_axis(Axis(3), vec![a.view(), a.view()]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds);

let res: Result<Array2<f64>, _> = ndarray::stack_new_axis::<_, Ix1>(Axis(0), vec![]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);
}