From 5e16c5e20009bc9fa46c568f725fdd70abab121c Mon Sep 17 00:00:00 2001 From: Josh Stone Date: Wed, 31 Aug 2022 12:04:24 -0700 Subject: [PATCH 1/2] Add const-generic IndexedParallelIterator::arrays --- src/iter/arrays.rs | 221 +++++++++++++++++++++++++++++++++++++ src/iter/mod.rs | 28 +++++ tests/clones.rs | 1 + tests/debug.rs | 1 + tests/producer_split_at.rs | 23 ++++ 5 files changed, 274 insertions(+) create mode 100644 src/iter/arrays.rs diff --git a/src/iter/arrays.rs b/src/iter/arrays.rs new file mode 100644 index 000000000..d0adc951e --- /dev/null +++ b/src/iter/arrays.rs @@ -0,0 +1,221 @@ +use super::plumbing::*; +use super::*; + +/// `Arrays` is an iterator that groups elements of an underlying iterator. +/// +/// This struct is created by the [`arrays()`] method on [`IndexedParallelIterator`] +/// +/// [`arrays()`]: trait.IndexedParallelIterator.html#method.arrays +/// [`IndexedParallelIterator`]: trait.IndexedParallelIterator.html +#[must_use = "iterator adaptors are lazy and do nothing unless consumed"] +#[derive(Debug, Clone)] +pub struct Arrays +where + I: IndexedParallelIterator, +{ + iter: I, +} + +impl Arrays +where + I: IndexedParallelIterator, +{ + /// Creates a new `Arrays` iterator + pub(super) fn new(iter: I) -> Self { + Arrays { iter } + } +} + +impl ParallelIterator for Arrays +where + I: IndexedParallelIterator, +{ + type Item = [I::Item; N]; + + fn drive_unindexed(self, consumer: C) -> C::Result + where + C: Consumer, + { + bridge(self, consumer) + } + + fn opt_len(&self) -> Option { + Some(self.len()) + } +} + +impl IndexedParallelIterator for Arrays +where + I: IndexedParallelIterator, +{ + fn drive(self, consumer: C) -> C::Result + where + C: Consumer, + { + bridge(self, consumer) + } + + fn len(&self) -> usize { + self.iter.len() / N + } + + fn with_producer(self, callback: CB) -> CB::Output + where + CB: ProducerCallback, + { + let len = self.iter.len(); + return self.iter.with_producer(Callback { len, callback }); + + struct Callback { + len: usize, + callback: CB, + } + + impl ProducerCallback for Callback + where + CB: ProducerCallback<[T; N]>, + { + type Output = CB::Output; + + fn callback

(self, base: P) -> CB::Output + where + P: Producer, + { + self.callback.callback(ArrayProducer { + len: self.len, + base, + }) + } + } + } +} + +struct ArrayProducer +where + P: Producer, +{ + len: usize, + base: P, +} + +impl Producer for ArrayProducer +where + P: Producer, +{ + type Item = [P::Item; N]; + type IntoIter = ArraySeq; + + fn into_iter(self) -> Self::IntoIter { + // TODO: we're ignoring any remainder -- should we no-op consume it? + let remainder = self.len % N; + let len = self.len - remainder; + let inner = (len > 0).then(|| self.base.split_at(len).0); + ArraySeq { len, inner } + } + + fn split_at(self, index: usize) -> (Self, Self) { + let elem_index = index * N; + let (left, right) = self.base.split_at(elem_index); + ( + ArrayProducer { + len: elem_index, + base: left, + }, + ArrayProducer { + len: self.len - elem_index, + base: right, + }, + ) + } + + fn min_len(&self) -> usize { + self.base.min_len() / N + } + + fn max_len(&self) -> usize { + self.base.max_len() / N + } +} + +struct ArraySeq { + len: usize, + inner: Option

, +} + +impl Iterator for ArraySeq +where + P: Producer, +{ + type Item = [P::Item; N]; + + fn next(&mut self) -> Option { + let mut producer = self.inner.take()?; + debug_assert!(self.len > 0 && self.len % N == 0); + if self.len > N { + let (left, right) = producer.split_at(N); + producer = left; + self.inner = Some(right); + self.len -= N; + } else { + self.len = 0; + } + Some(collect_array(producer.into_iter())) + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.len(); + (len, Some(len)) + } +} + +impl ExactSizeIterator for ArraySeq +where + P: Producer, +{ + #[inline] + fn len(&self) -> usize { + self.len / N + } +} + +impl DoubleEndedIterator for ArraySeq +where + P: Producer, +{ + fn next_back(&mut self) -> Option { + let mut producer = self.inner.take()?; + debug_assert!(self.len > 0 && self.len % N == 0); + if self.len > N { + let (left, right) = producer.split_at(self.len - N); + producer = right; + self.inner = Some(left); + self.len -= N; + } else { + self.len = 0; + } + Some(collect_array(producer.into_iter())) + } +} + +fn collect_array(mut iter: impl ExactSizeIterator) -> [T; N] { + // TODO(MSRV-1.55): consider `[(); N].map(...)` + // TODO(MSRV-1.63): consider `std::array::from_fn` + + use std::mem::MaybeUninit; + + // TODO(MSRV): use `MaybeUninit::uninit_array` when/if it's stabilized. + // SAFETY: We can assume "init" when moving uninit wrappers inward. + let mut array: [MaybeUninit; N] = + unsafe { MaybeUninit::<[MaybeUninit; N]>::uninit().assume_init() }; + + debug_assert_eq!(iter.len(), N); + for i in 0..N { + let item = iter.next().expect("should have N items"); + array[i] = MaybeUninit::new(item); + } + debug_assert!(iter.next().is_none()); + + // TODO(MSRV): use `MaybeUninit::array_assume_init` when/if it's stabilized. + // SAFETY: We've initialized all N items in the array, so we can cast and "move" it. + unsafe { (&array as *const [MaybeUninit; N] as *const [T; N]).read() } +} diff --git a/src/iter/mod.rs b/src/iter/mod.rs index 7b5a29aeb..75509bd63 100644 --- a/src/iter/mod.rs +++ b/src/iter/mod.rs @@ -102,6 +102,7 @@ mod test; // e.g. `find::find()`, are always used **prefixed**, so that they // can be readily distinguished. +mod arrays; mod chain; mod chunks; mod cloned; @@ -159,6 +160,7 @@ mod zip; mod zip_eq; pub use self::{ + arrays::Arrays, chain::Chain, chunks::Chunks, cloned::Cloned, @@ -2544,6 +2546,32 @@ pub trait IndexedParallelIterator: ParallelIterator { InterleaveShortest::new(self, other.into_par_iter()) } + /// Splits an iterator up into fixed-size arrays. + /// + /// Returns an iterator that returns arrays with the given number of elements. + /// If the number of elements in the iterator is not divisible by `N`, + /// the remaining items are ignored. + /// + /// See also [`par_array_chunks()`] and [`par_array_chunks_mut()`] for similar + /// behavior on slices, although they yield array references instead. + /// + /// [`par_array_chunks()`]: ../slice/trait.ParallelSlice.html#method.par_array_chunks + /// [`par_array_chunks_mut()`]: ../slice/trait.ParallelSliceMut.html#method.par_array_chunks_mut + /// + /// # Examples + /// + /// ``` + /// use rayon::prelude::*; + /// let a = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + /// let r: Vec<[i32; 3]> = a.into_par_iter().arrays().collect(); + /// assert_eq!(r, vec![[1, 2, 3], [4, 5, 6], [7, 8, 9]]); + /// ``` + #[track_caller] + fn arrays(self) -> Arrays { + assert!(N != 0, "array length must not be zero"); + Arrays::new(self) + } + /// Splits an iterator up into fixed-size chunks. /// /// Returns an iterator that returns `Vec`s of the given number of elements. diff --git a/tests/clones.rs b/tests/clones.rs index 0d6c86487..da0cd1476 100644 --- a/tests/clones.rs +++ b/tests/clones.rs @@ -151,6 +151,7 @@ fn clone_adaptors() { check(v.par_iter().interleave_shortest(&v)); check(v.par_iter().intersperse(&None)); check(v.par_iter().chunks(3)); + check(v.par_iter().arrays::<3>()); check(v.par_iter().map(|x| x)); check(v.par_iter().map_with(0, |_, x| x)); check(v.par_iter().map_init(|| 0, |_, x| x)); diff --git a/tests/debug.rs b/tests/debug.rs index 14f37917b..2705543eb 100644 --- a/tests/debug.rs +++ b/tests/debug.rs @@ -165,6 +165,7 @@ fn debug_adaptors() { check(v.par_iter().interleave_shortest(&v)); check(v.par_iter().intersperse(&-1)); check(v.par_iter().chunks(3)); + check(v.par_iter().arrays::<3>()); check(v.par_iter().map(|x| x)); check(v.par_iter().map_with(0, |_, x| x)); check(v.par_iter().map_init(|| 0, |_, x| x)); diff --git a/tests/producer_split_at.rs b/tests/producer_split_at.rs index d71050492..80ed07f82 100644 --- a/tests/producer_split_at.rs +++ b/tests/producer_split_at.rs @@ -343,6 +343,29 @@ fn chunks() { check(&v, || s.par_iter().cloned().chunks(2)); } +#[test] +fn arrays() { + use std::convert::TryInto; + fn check_len(s: &[i32]) { + let v: Vec<[_; N]> = s.chunks_exact(N).map(|c| c.try_into().unwrap()).collect(); + check(&v, || s.par_iter().copied().arrays::()); + } + + let s: Vec<_> = (0..10).collect(); + check_len::<1>(&s); + check_len::<2>(&s); + check_len::<3>(&s); + check_len::<4>(&s); + check_len::<5>(&s); + check_len::<6>(&s); + check_len::<7>(&s); + check_len::<8>(&s); + check_len::<9>(&s); + check_len::<10>(&s); + check_len::<11>(&s); + check_len::<12>(&s); +} + #[test] fn map() { let v: Vec<_> = (0..10).collect(); From 77d6f87e8436964542b6e290635bfceb48db257d Mon Sep 17 00:00:00 2001 From: Josh Stone Date: Tue, 12 Dec 2023 16:15:23 -0800 Subject: [PATCH 2/2] Use array::from_fn --- src/iter/arrays.rs | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/src/iter/arrays.rs b/src/iter/arrays.rs index d0adc951e..a60f2ba74 100644 --- a/src/iter/arrays.rs +++ b/src/iter/arrays.rs @@ -198,24 +198,8 @@ where } fn collect_array(mut iter: impl ExactSizeIterator) -> [T; N] { - // TODO(MSRV-1.55): consider `[(); N].map(...)` - // TODO(MSRV-1.63): consider `std::array::from_fn` - - use std::mem::MaybeUninit; - - // TODO(MSRV): use `MaybeUninit::uninit_array` when/if it's stabilized. - // SAFETY: We can assume "init" when moving uninit wrappers inward. - let mut array: [MaybeUninit; N] = - unsafe { MaybeUninit::<[MaybeUninit; N]>::uninit().assume_init() }; - debug_assert_eq!(iter.len(), N); - for i in 0..N { - let item = iter.next().expect("should have N items"); - array[i] = MaybeUninit::new(item); - } + let array = std::array::from_fn(|_| iter.next().expect("should have N items")); debug_assert!(iter.next().is_none()); - - // TODO(MSRV): use `MaybeUninit::array_assume_init` when/if it's stabilized. - // SAFETY: We've initialized all N items in the array, so we can cast and "move" it. - unsafe { (&array as *const [MaybeUninit; N] as *const [T; N]).read() } + array }