From e9deaf3806cf40cd852605afacca9b28ab75a625 Mon Sep 17 00:00:00 2001 From: Lazaro Hurtado Date: Tue, 27 Dec 2022 15:39:49 -0500 Subject: [PATCH] added stride support to windows --- src/impl_methods.rs | 52 +++++++++++++++++++++++++++++ src/iterators/windows.rs | 59 ++++++++++++++++++++++++++++++++ tests/windows.rs | 72 +++++++++++++++++++++++++++++++++++++++- 3 files changed, 182 insertions(+), 1 deletion(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 115cd2d71..5ed5bc807 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -1457,6 +1457,58 @@ where Windows::new(self.view(), window_size) } + /// Return a window producer and iterable. + /// + /// The windows are all distinct views of size `window_size` + /// that fit into the array's shape. + /// + /// The stride is ordered by the inner most axis.
+ /// Hence, a (x₀, x₁, ..., xₙ) stride will be applied to + /// (Aₙ, ..., A₁, A₀) where Aₓ stands for `Axis(x)`. + /// + /// This produces all windows that fit within the array for the given stride, + /// assuming the window size is not larger than the array size. + /// + /// The produced element is an `ArrayView` with exactly the dimension + /// `window_size`. + /// + /// Note that passing a stride of only ones is similar to + /// calling [`ArrayBase::windows()`]. + /// + /// **Panics** if any dimension of `window_size` or `stride` is zero.
+ /// (**Panics** if `D` is `IxDyn` and `window_size` or `stride` does not match the + /// number of array axes.) + /// + /// This is the same illustration found in [`ArrayBase::windows()`], + /// 2×2 windows in a 3×4 array, but now with a 2x1 stride: + /// + /// ```text + /// ──▶ Axis(1) + /// + /// │ ┏━━━━━┳━━━━━┱─────┬─────┐ ┌─────┬─────┲━━━━━┳━━━━━┓ + /// ▼ ┃ a₀₀ ┃ a₀₁ ┃ │ │ │ │ ┃ a₀₂ ┃ a₀₃ ┃ + /// Axis(0) ┣━━━━━╋━━━━━╉─────┼─────┤ ├─────┼─────╊━━━━━╋━━━━━┫ + /// ┃ a₁₀ ┃ a₁₁ ┃ │ │ │ │ ┃ a₁₂ ┃ a₁₃ ┃ + /// ┡━━━━━╇━━━━━╃─────┼─────┤ ├─────┼─────╄━━━━━╇━━━━━┩ + /// │ │ │ │ │ │ │ │ │ │ + /// └─────┴─────┴─────┴─────┘ └─────┴─────┴─────┴─────┘ + /// + /// ┌─────┬─────┬─────┬─────┐ ┌─────┬─────┬─────┬─────┐ + /// │ │ │ │ │ │ │ │ │ │ + /// ┢━━━━━╈━━━━━╅─────┼─────┤ ├─────┼─────╆━━━━━╈━━━━━┪ + /// ┃ a₁₀ ┃ a₁₁ ┃ │ │ │ │ ┃ a₁₂ ┃ a₁₃ ┃ + /// ┣━━━━━╋━━━━━╉─────┼─────┤ ├─────┼─────╊━━━━━╋━━━━━┫ + /// ┃ a₂₀ ┃ a₂₁ ┃ │ │ │ │ ┃ a₂₂ ┃ a₂₃ ┃ + /// ┗━━━━━┻━━━━━┹─────┴─────┘ └─────┴─────┺━━━━━┻━━━━━┛ + /// ``` + pub fn windows_with_stride(&self, window_size: E, stride: E) -> Windows<'_, A, D> + where + E: IntoDimension, + S: Data, + { + Windows::new_with_stride(self.view(), window_size, stride) + } + /// Returns a producer which traverses over all windows of a given length along an axis. /// /// The windows are all distinct, possibly-overlapping views. The shape of each window diff --git a/src/iterators/windows.rs b/src/iterators/windows.rs index c47bfecec..eb8cba77c 100644 --- a/src/iterators/windows.rs +++ b/src/iterators/windows.rs @@ -47,6 +47,65 @@ impl<'a, A, D: Dimension> Windows<'a, A, D> { } } } + + pub(crate) fn new_with_stride(a: ArrayView<'a, A, D>, window_size: E, strides: E) -> Self + where + E: IntoDimension, + { + let window = window_size.into_dimension(); + let strides_d = strides.into_dimension(); + ndassert!( + a.ndim() == window.ndim(), + concat!( + "Window dimension {} does not match array dimension {} ", + "(with array of shape {:?})" + ), + window.ndim(), + a.ndim(), + a.shape() + ); + ndassert!( + a.ndim() == strides_d.ndim(), + concat!( + "Stride dimension {} does not match array dimension {} ", + "(with array of shape {:?})" + ), + strides_d.ndim(), + a.ndim(), + a.shape() + ); + let mut size = a.dim; + for ((sz, &ws), &stride) in size + .slice_mut() + .iter_mut() + .zip(window.slice()) + .zip(strides_d.slice().iter().rev()) + { + assert_ne!(ws, 0, "window-size must not be zero!"); + assert_ne!(stride, 0, "stride cannot have a dimension as zero!"); + // cannot use std::cmp::max(0, ..) since arithmetic underflow panics + *sz = if *sz < ws { + 0 + } else { + ((*sz - (ws - 1) - 1) / stride) + 1 + }; + } + let window_strides = a.strides.clone(); + + let mut array_strides = a.strides; + let stride_ndim = array_strides.ndim(); + for (ix, ix_stride) in strides_d.slice().iter().enumerate() { + array_strides[stride_ndim - ix - 1] *= ix_stride; + } + + unsafe { + Windows { + base: ArrayView::new(a.ptr, size, array_strides), + window, + strides: window_strides, + } + } + } } impl_ndproducer! { diff --git a/tests/windows.rs b/tests/windows.rs index 432be5e41..d36e8e0eb 100644 --- a/tests/windows.rs +++ b/tests/windows.rs @@ -30,7 +30,7 @@ fn windows_iterator_zero_size() { a.windows(Dim((0, 0, 0))); } -/// Test that verifites that no windows are yielded on oversized window sizes. +/// Test that verifies that no windows are yielded on oversized window sizes. #[test] fn windows_iterator_oversized() { let a = Array::from_iter(10..37).into_shape((3, 3, 3)).unwrap(); @@ -95,6 +95,76 @@ fn windows_iterator_3d() { ); } +/// Test that verifies the `Windows` iterator panics when stride has an axis equal to zero. +#[test] +#[should_panic] +fn windows_iterator_stride_axis_zero() { + let a = Array::from_iter(10..37).into_shape((3, 3, 3)).unwrap(); + a.windows_with_stride(Dim((2, 2, 2)), Dim((2,2,0))); +} + +/// Test that verifies that only first window is yielded when stride is oversized on every axis. +#[test] +fn windows_iterator_only_one_valid_window_for_oversized_stride() { + let a = Array::from_iter(10..135).into_shape((5, 5, 5)).unwrap(); + let mut iter = a.windows_with_stride((2, 2, 2), (8, 8, 8)).into_iter(); // (4,3,2) doesn't fit into (3,3,3) => oversized! + itertools::assert_equal( + iter.next(), + Some(arr3(&[[[10, 11], [15, 16]],[[35, 36], [40, 41]]])) + ); +} + +/// Simple test for iterating 1d-arrays via `Windows` with stride. +#[test] +fn windows_iterator_1d_with_stride() { + let a = Array::from_iter(10..20).into_shape(10).unwrap(); + itertools::assert_equal( + a.windows_with_stride(Dim(4), Dim(2)), + vec![ + arr1(&[10, 11, 12, 13]), + arr1(&[12, 13, 14, 15]), + arr1(&[14, 15, 16, 17]), + arr1(&[16, 17, 18, 19]), + ], + ); +} + +/// Simple test for iterating 2d-arrays via `Windows` with stride. +#[test] +fn windows_iterator_2d_with_stride() { + let a = Array::from_iter(10..30).into_shape((5, 4)).unwrap(); + itertools::assert_equal( + a.windows_with_stride(Dim((3, 2)), Dim((1,2))), + vec![ + arr2(&[[10, 11], [14, 15], [18, 19]]), + arr2(&[[11, 12], [15, 16], [19, 20]]), + arr2(&[[12, 13], [16, 17], [20, 21]]), + arr2(&[[18, 19], [22, 23], [26, 27]]), + arr2(&[[19, 20], [23, 24], [27, 28]]), + arr2(&[[20, 21], [24, 25], [28, 29]]), + ], + ); +} + +/// Simple test for iterating 3d-arrays via `Windows` with stride. +#[test] +fn windows_iterator_3d_with_stride() { + let a = Array::from_iter(10..74).into_shape((4, 4, 4)).unwrap(); + itertools::assert_equal( + a.windows_with_stride(Dim((2, 2, 2)), Dim((2,2,2))), + vec![ + arr3(&[[[10, 11], [14, 15]], [[26, 27], [30, 31]]]), + arr3(&[[[12, 13], [16, 17]], [[28, 29], [32, 33]]]), + arr3(&[[[18, 19], [22, 23]], [[34, 35], [38, 39]]]), + arr3(&[[[20, 21], [24, 25]], [[36, 37], [40, 41]]]), + arr3(&[[[42, 43], [46, 47]], [[58, 59], [62, 63]]]), + arr3(&[[[44, 45], [48, 49]], [[60, 61], [64, 65]]]), + arr3(&[[[50, 51], [54, 55]], [[66, 67], [70, 71]]]), + arr3(&[[[52, 53], [56, 57]], [[68, 69], [72, 73]]]), + ], + ); +} + #[test] fn test_window_zip() { let a = Array::from_iter(0..64).into_shape((4, 4, 4)).unwrap();