-
Notifications
You must be signed in to change notification settings - Fork 291
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Stack and concatenate #735
Changes from all commits
b39a0bb
efd686d
8cd5d5a
0e9d65d
6525f6e
2d04320
ced7ed9
443fa86
68f0e95
d997c18
8e12f82
ef3d5b6
1e91fe4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,7 +9,7 @@ | |
use crate::error::{from_kind, ErrorKind, ShapeError}; | ||
use crate::imp_prelude::*; | ||
|
||
/// Stack arrays along the given axis. | ||
/// Concatenate arrays along the given axis. | ||
/// | ||
/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`. | ||
/// (may be made more flexible in the future).<br> | ||
|
@@ -29,10 +29,11 @@ use crate::imp_prelude::*; | |
/// [3., 3.]])) | ||
/// ); | ||
/// ``` | ||
pub fn stack<'a, A, D>( | ||
axis: Axis, | ||
arrays: &[ArrayView<'a, A, D>], | ||
) -> Result<Array<A, D>, ShapeError> | ||
#[deprecated( | ||
since = "0.13.1", | ||
note = "Please use the `concatenate` function instead" | ||
)] | ||
pub fn stack<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError> | ||
where | ||
A: Copy, | ||
D: RemoveAxis, | ||
|
@@ -76,7 +77,79 @@ where | |
Ok(res) | ||
} | ||
|
||
/// Stack arrays along the given axis. | ||
/// Concatenate arrays along the given axis. | ||
/// | ||
/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`. | ||
/// (may be made more flexible in the future).<br> | ||
/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds, | ||
/// if the result is larger than is possible to represent. | ||
/// | ||
/// ``` | ||
/// use ndarray::{arr2, Axis, concatenate}; | ||
/// | ||
/// let a = arr2(&[[2., 2.], | ||
/// [3., 3.]]); | ||
/// assert!( | ||
/// concatenate(Axis(0), &[a.view(), a.view()]) | ||
/// == Ok(arr2(&[[2., 2.], | ||
/// [3., 3.], | ||
/// [2., 2.], | ||
/// [3., 3.]])) | ||
/// ); | ||
/// ``` | ||
#[allow(deprecated)] | ||
pub fn concatenate<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError> | ||
where | ||
A: Copy, | ||
D: RemoveAxis, | ||
{ | ||
stack(axis, arrays) | ||
} | ||
|
||
pub fn stack_new_axis<A, D>( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing doc comment :) |
||
axis: Axis, | ||
arrays: Vec<ArrayView<A, D>>, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
) -> Result<Array<A, D::Larger>, ShapeError> | ||
where | ||
A: Copy, | ||
D: Dimension, | ||
D::Larger: RemoveAxis, | ||
{ | ||
if arrays.is_empty() { | ||
return Err(from_kind(ErrorKind::Unsupported)); | ||
} | ||
let common_dim = arrays[0].raw_dim(); | ||
// Avoid panic on `insert_axis` call, return an Err instead of it. | ||
if axis.index() > common_dim.ndim() { | ||
return Err(from_kind(ErrorKind::OutOfBounds)); | ||
} | ||
let mut res_dim = common_dim.insert_axis(axis); | ||
|
||
if arrays.iter().any(|a| a.raw_dim() != common_dim) { | ||
return Err(from_kind(ErrorKind::IncompatibleShape)); | ||
} | ||
|
||
res_dim.set_axis(axis, arrays.len()); | ||
|
||
// we can safely use uninitialized values here because they are Copy | ||
// and we will only ever write to them | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks good, it follows what we do in stack. We'll revise handling of uninit data, but it's not in scope for this PR. (#796 ) |
||
let size = res_dim.size(); | ||
let mut v = Vec::with_capacity(size); | ||
unsafe { | ||
v.set_len(size); | ||
} | ||
let mut res = Array::from_shape_vec(res_dim, v)?; | ||
|
||
res.axis_iter_mut(axis) | ||
.zip(arrays.into_iter()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this zip seems to show that we would be fine with just using a slice of |
||
.for_each(|(mut assign_view, array)| { | ||
assign_view.assign(&array); | ||
}); | ||
|
||
Ok(res) | ||
} | ||
|
||
/// Concatenate arrays along the given axis. | ||
/// | ||
/// Uses the [`stack`][1] function, calling `ArrayView::from(&a)` on each | ||
/// argument `a`. | ||
|
@@ -109,3 +182,71 @@ macro_rules! stack { | |
$crate::stack($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The stack macro should also be deprecated |
||
} | ||
} | ||
|
||
/// Concatenate arrays along the given axis. | ||
/// | ||
/// Uses the [`concatenate`][1] function, calling `ArrayView::from(&a)` on each | ||
/// argument `a`. | ||
/// | ||
/// [1]: fn.concatenate.html | ||
/// | ||
/// ***Panics*** if the `concatenate` function would return an error. | ||
/// | ||
/// ``` | ||
/// extern crate ndarray; | ||
/// | ||
/// use ndarray::{arr2, concatenate, Axis}; | ||
/// | ||
/// # fn main() { | ||
/// | ||
/// let a = arr2(&[[2., 2.], | ||
/// [3., 3.]]); | ||
/// assert!( | ||
/// concatenate![Axis(0), a, a] | ||
/// == arr2(&[[2., 2.], | ||
/// [3., 3.], | ||
/// [2., 2.], | ||
/// [3., 3.]]) | ||
/// ); | ||
/// # } | ||
/// ``` | ||
#[macro_export] | ||
macro_rules! concatenate { | ||
($axis:expr, $( $array:expr ),+ ) => { | ||
$crate::concatenate($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap() | ||
} | ||
} | ||
|
||
/// Stack arrays along the new axis. | ||
/// | ||
/// Uses the [`stack_new_axis`][1] function, calling `ArrayView::from(&a)` on each | ||
/// argument `a`. | ||
/// | ||
/// [1]: fn.stack_new_axis.html | ||
/// | ||
/// ***Panics*** if the `stack` function would return an error. | ||
/// | ||
/// ``` | ||
/// extern crate ndarray; | ||
/// | ||
/// use ndarray::{arr2, arr3, stack_new_axis, Axis}; | ||
/// | ||
/// # fn main() { | ||
/// | ||
/// let a = arr2(&[[2., 2.], | ||
/// [3., 3.]]); | ||
/// assert!( | ||
/// stack_new_axis![Axis(0), a, a] | ||
/// == arr3(&[[[2., 2.], | ||
/// [3., 3.]], | ||
/// [[2., 2.], | ||
/// [3., 3.]]]) | ||
/// ); | ||
/// # } | ||
/// ``` | ||
#[macro_export] | ||
macro_rules! stack_new_axis { | ||
($axis:expr, $( $array:expr ),+ ) => { | ||
$crate::stack_new_axis($axis, vec![ $($crate::ArrayView::from(&$array) ),* ]).unwrap() | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This rename of the variable inside the macro is an incidental mistake, and should not be in the PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh I removed this in my rebase-fix of the branch. No other changes in that branch. :)