diff --git a/src/slice/array.rs b/src/slice/array.rs index e4577e910..f02937395 100644 --- a/src/slice/array.rs +++ b/src/slice/array.rs @@ -3,7 +3,7 @@ use crate::iter::plumbing::*; use crate::iter::*; -use super::{Iter, IterMut}; +use super::{Iter, IterMut, ParallelSlice}; /// Parallel iterator over immutable non-overlapping chunks of a slice #[derive(Debug)] @@ -172,3 +172,67 @@ impl<'data, T: Send + 'data, const N: usize> IndexedParallelIterator self.iter.with_producer(callback) } } + +/// Parallel iterator over immutable overlapping windows of a slice +#[derive(Debug)] +pub struct ArrayWindows<'data, T: Sync, const N: usize> { + slice: &'data [T], +} + +impl<'data, T: Sync, const N: usize> ArrayWindows<'data, T, N> { + pub(super) fn new(slice: &'data [T]) -> Self { + ArrayWindows { slice } + } +} + +impl<'data, T: Sync, const N: usize> Clone for ArrayWindows<'data, T, N> { + fn clone(&self) -> Self { + ArrayWindows { ..*self } + } +} + +impl<'data, T: Sync + 'data, const N: usize> ParallelIterator for ArrayWindows<'data, T, N> { + type Item = &'data [T; N]; + + fn drive_unindexed(self, consumer: C) -> C::Result + where + C: UnindexedConsumer, + { + bridge(self, consumer) + } + + fn opt_len(&self) -> Option { + Some(self.len()) + } +} + +impl<'data, T: Sync + 'data, const N: usize> IndexedParallelIterator for ArrayWindows<'data, T, N> { + fn drive(self, consumer: C) -> C::Result + where + C: Consumer, + { + bridge(self, consumer) + } + + fn len(&self) -> usize { + assert!(N >= 1); + self.slice.len().saturating_sub(N - 1) + } + + fn with_producer(self, callback: CB) -> CB::Output + where + CB: ProducerCallback, + { + fn array(slice: &[T]) -> &[T; N] { + debug_assert_eq!(slice.len(), N); + let ptr = slice.as_ptr() as *const [T; N]; + unsafe { &*ptr } + } + + // FIXME: use our own producer and the standard `array_windows`, rust-lang/rust#75027 + self.slice + .par_windows(N) + .map(array::) + .with_producer(callback) + } +} diff --git a/src/slice/mod.rs b/src/slice/mod.rs index e759f0b92..ef04d5b4a 100644 --- a/src/slice/mod.rs +++ b/src/slice/mod.rs @@ -13,8 +13,8 @@ mod rchunks; mod test; -#[cfg(min_const_generics)] -pub use self::array::{ArrayChunks, ArrayChunksMut}; +#[cfg(has_min_const_generics)] +pub use self::array::{ArrayChunks, ArrayChunksMut, ArrayWindows}; use self::mergesort::par_mergesort; use self::quicksort::par_quicksort; @@ -75,6 +75,21 @@ pub trait ParallelSlice { } } + /// Returns a parallel iterator over all contiguous array windows of + /// length `N`. The windows overlap. + /// + /// # Examples + /// + /// ``` + /// use rayon::prelude::*; + /// let windows: Vec<_> = [1, 2, 3].par_array_windows().collect(); + /// assert_eq!(vec![&[1, 2], &[2, 3]], windows); + /// ``` + #[cfg(has_min_const_generics)] + fn par_array_windows(&self) -> ArrayWindows<'_, T, N> { + ArrayWindows::new(self.as_parallel_slice()) + } + /// Returns a parallel iterator over at most `chunk_size` elements of /// `self` at a time. The chunks do not overlap. /// diff --git a/tests/clones.rs b/tests/clones.rs index 36ec0b533..e5a5a85f2 100644 --- a/tests/clones.rs +++ b/tests/clones.rs @@ -106,6 +106,7 @@ fn clone_vec() { check(v.par_rchunks_exact(42)); check(v.par_array_chunks::<42>()); check(v.par_windows(42)); + check(v.par_array_windows::<42>()); check(v.par_split(|x| x % 3 == 0)); check(v.into_par_iter()); } diff --git a/tests/debug.rs b/tests/debug.rs index f2ef462c7..650073ca9 100644 --- a/tests/debug.rs +++ b/tests/debug.rs @@ -130,6 +130,7 @@ fn debug_vec() { check(v.par_rchunks_mut(42)); check(v.par_rchunks_exact_mut(42)); check(v.par_windows(42)); + check(v.par_array_windows::<42>()); check(v.par_split(|x| x % 3 == 0)); check(v.par_split_mut(|x| x % 3 == 0)); check(v.par_drain(..)); diff --git a/tests/producer_split_at.rs b/tests/producer_split_at.rs index 592720d1f..74e6101ad 100644 --- a/tests/producer_split_at.rs +++ b/tests/producer_split_at.rs @@ -328,6 +328,15 @@ fn slice_windows() { check(&v, || s.par_windows(2)); } +#[test] +fn slice_array_windows() { + use std::convert::TryInto; + let s: Vec<_> = (0..10).collect(); + // FIXME: use the standard `array_windows`, rust-lang/rust#75027 + let v: Vec<&[_; 2]> = s.windows(2).map(|s| s.try_into().unwrap()).collect(); + check(&v, || s.par_array_windows::<2>()); +} + #[test] fn vec() { let v: Vec<_> = (0..10).collect();