Skip to content

Commit

Permalink
Merge pull request #1297 from LazaroHurtado/fix/window_stride
Browse files Browse the repository at this point in the history
Updated Windows `base` Computations to be Safer
  • Loading branch information
Nil Goyette committed Jun 18, 2023
2 parents 17a8d25 + 5bcc73e commit 9447328
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 31 deletions.
56 changes: 25 additions & 31 deletions src/iterators/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::imp_prelude::*;
use crate::IntoDimension;
use crate::Layout;
use crate::NdProducer;
use crate::Slice;

/// Window producer and iterable
///
Expand All @@ -24,16 +25,19 @@ impl<'a, A, D: Dimension> Windows<'a, A, D> {

let mut unit_stride = D::zeros(ndim);
unit_stride.slice_mut().fill(1);

Windows::new_with_stride(a, window, unit_stride)
}

pub(crate) fn new_with_stride<E>(a: ArrayView<'a, A, D>, window_size: E, strides: E) -> Self
pub(crate) fn new_with_stride<E>(a: ArrayView<'a, A, D>, window_size: E, axis_strides: E) -> Self
where
E: IntoDimension<Dim = D>,
{
let window = window_size.into_dimension();
let strides_d = strides.into_dimension();

let strides = axis_strides.into_dimension();
let window_strides = a.strides.clone();

ndassert!(
a.ndim() == window.ndim(),
concat!(
Expand All @@ -44,45 +48,35 @@ impl<'a, A, D: Dimension> Windows<'a, A, D> {
a.ndim(),
a.shape()
);

ndassert!(
a.ndim() == strides_d.ndim(),
a.ndim() == strides.ndim(),
concat!(
"Stride dimension {} does not match array dimension {} ",
"(with array of shape {:?})"
),
strides_d.ndim(),
strides.ndim(),
a.ndim(),
a.shape()
);
let mut size = a.dim;
for ((sz, &ws), &stride) in size
.slice_mut()
.iter_mut()
.zip(window.slice())
.zip(strides_d.slice())
{
assert_ne!(ws, 0, "window-size must not be zero!");
assert_ne!(stride, 0, "stride cannot have a dimension as zero!");
// cannot use std::cmp::max(0, ..) since arithmetic underflow panics
*sz = if *sz < ws {
0
} else {
((*sz - (ws - 1) - 1) / stride) + 1
};
}
let window_strides = a.strides.clone();

let mut array_strides = a.strides.clone();
for (arr_stride, ix_stride) in array_strides.slice_mut().iter_mut().zip(strides_d.slice()) {
*arr_stride *= ix_stride;
}
let mut base = a;
base.slice_each_axis_inplace(|ax_desc| {
let len = ax_desc.len;
let wsz = window[ax_desc.axis.index()];
let stride = strides[ax_desc.axis.index()];

unsafe {
Windows {
base: ArrayView::new(a.ptr, size, array_strides),
window,
strides: window_strides,
if len < wsz {
Slice::new(0, Some(0), 1)
} else {
Slice::new(0, Some((len - wsz + 1) as isize), stride as isize)
}
});

Windows {
base,
window,
strides: window_strides,
}
}
}
Expand Down
28 changes: 28 additions & 0 deletions tests/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,31 @@ fn test_window_neg_stride() {
answer.iter()
);
}

#[test]
fn test_windows_with_stride_on_inverted_axis() {
let mut array = Array::from_iter(1..17).into_shape((4, 4)).unwrap();

// inverting axis results in negative stride
array.invert_axis(Axis(0));
itertools::assert_equal(
array.windows_with_stride((2, 2), (2,2)),
vec![
arr2(&[[13, 14], [9, 10]]),
arr2(&[[15, 16], [11, 12]]),
arr2(&[[5, 6], [1, 2]]),
arr2(&[[7, 8], [3, 4]]),
],
);

array.invert_axis(Axis(1));
itertools::assert_equal(
array.windows_with_stride((2, 2), (2,2)),
vec![
arr2(&[[16, 15], [12, 11]]),
arr2(&[[14, 13], [10, 9]]),
arr2(&[[8, 7], [4, 3]]),
arr2(&[[6, 5], [2, 1]]),
],
);
}

0 comments on commit 9447328

Please sign in to comment.