From 4501f62acdfbb34d38276effedcb0728166fbec3 Mon Sep 17 00:00:00 2001 From: Josh Stone Date: Wed, 31 Aug 2022 12:04:24 -0700 Subject: [PATCH] Add const-generic IndexedParallelIterator::arrays --- src/iter/arrays.rs | 222 +++++++++++++++++++++++++++++++++++++ src/iter/mod.rs | 34 ++++++ tests/clones.rs | 1 + tests/debug.rs | 1 + tests/producer_split_at.rs | 23 ++++ 5 files changed, 281 insertions(+) create mode 100644 src/iter/arrays.rs diff --git a/src/iter/arrays.rs b/src/iter/arrays.rs new file mode 100644 index 000000000..1fac64c77 --- /dev/null +++ b/src/iter/arrays.rs @@ -0,0 +1,222 @@ +#![cfg(has_min_const_generics)] +use super::plumbing::*; +use super::*; + +/// `Arrays` is an iterator that groups elements of an underlying iterator. +/// +/// This struct is created by the [`arrays()`] method on [`IndexedParallelIterator`] +/// +/// [`arrays()`]: trait.IndexedParallelIterator.html#method.arrays +/// [`IndexedParallelIterator`]: trait.IndexedParallelIterator.html +#[must_use = "iterator adaptors are lazy and do nothing unless consumed"] +#[derive(Debug, Clone)] +pub struct Arrays +where + I: IndexedParallelIterator, +{ + iter: I, +} + +impl Arrays +where + I: IndexedParallelIterator, +{ + /// Creates a new `Arrays` iterator + pub(super) fn new(iter: I) -> Self { + Arrays { iter } + } +} + +impl ParallelIterator for Arrays +where + I: IndexedParallelIterator, +{ + type Item = [I::Item; N]; + + fn drive_unindexed(self, consumer: C) -> C::Result + where + C: Consumer, + { + bridge(self, consumer) + } + + fn opt_len(&self) -> Option { + Some(self.len()) + } +} + +impl IndexedParallelIterator for Arrays +where + I: IndexedParallelIterator, +{ + fn drive(self, consumer: C) -> C::Result + where + C: Consumer, + { + bridge(self, consumer) + } + + fn len(&self) -> usize { + self.iter.len() / N + } + + fn with_producer(self, callback: CB) -> CB::Output + where + CB: ProducerCallback, + { + let len = self.iter.len(); + return self.iter.with_producer(Callback { len, callback }); + + struct Callback { + len: usize, + callback: CB, + } + + impl ProducerCallback for Callback + where + CB: ProducerCallback<[T; N]>, + { + type Output = CB::Output; + + fn callback

(self, base: P) -> CB::Output + where + P: Producer, + { + self.callback.callback(ArrayProducer { + len: self.len, + base, + }) + } + } + } +} + +struct ArrayProducer +where + P: Producer, +{ + len: usize, + base: P, +} + +impl Producer for ArrayProducer +where + P: Producer, +{ + type Item = [P::Item; N]; + type IntoIter = ArraySeq; + + fn into_iter(self) -> Self::IntoIter { + // TODO: we're ignoring any remainder -- should we no-op consume it? + let remainder = self.len % N; + let len = self.len - remainder; + let inner = (len > 0).then(|| self.base.split_at(len).0); + ArraySeq { len, inner } + } + + fn split_at(self, index: usize) -> (Self, Self) { + let elem_index = index * N; + let (left, right) = self.base.split_at(elem_index); + ( + ArrayProducer { + len: elem_index, + base: left, + }, + ArrayProducer { + len: self.len - elem_index, + base: right, + }, + ) + } + + fn min_len(&self) -> usize { + self.base.min_len() / N + } + + fn max_len(&self) -> usize { + self.base.max_len() / N + } +} + +struct ArraySeq { + len: usize, + inner: Option

, +} + +impl Iterator for ArraySeq +where + P: Producer, +{ + type Item = [P::Item; N]; + + fn next(&mut self) -> Option { + let mut producer = self.inner.take()?; + debug_assert!(self.len > 0 && self.len % N == 0); + if self.len > N { + let (left, right) = producer.split_at(N); + producer = left; + self.inner = Some(right); + self.len -= N; + } else { + self.len = 0; + } + Some(collect_array(producer.into_iter())) + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.len(); + (len, Some(len)) + } +} + +impl ExactSizeIterator for ArraySeq +where + P: Producer, +{ + #[inline] + fn len(&self) -> usize { + self.len / N + } +} + +impl DoubleEndedIterator for ArraySeq +where + P: Producer, +{ + fn next_back(&mut self) -> Option { + let mut producer = self.inner.take()?; + debug_assert!(self.len > 0 && self.len % N == 0); + if self.len > N { + let (left, right) = producer.split_at(self.len - N); + producer = right; + self.inner = Some(left); + self.len -= N; + } else { + self.len = 0; + } + Some(collect_array(producer.into_iter())) + } +} + +fn collect_array(mut iter: impl ExactSizeIterator) -> [T; N] { + // TODO(MSRV-1.55): consider `[(); N].map(...)` + // TODO(MSRV-1.63): consider `std::array::from_fn` + + use std::mem::MaybeUninit; + + // TODO(MSRV): use `MaybeUninit::uninit_array` when/if it's stabilized. + // SAFETY: We can assume "init" when moving uninit wrappers inward. + let mut array: [MaybeUninit; N] = + unsafe { MaybeUninit::<[MaybeUninit; N]>::uninit().assume_init() }; + + debug_assert_eq!(iter.len(), N); + for i in 0..N { + let item = iter.next().expect("should have N items"); + array[i] = MaybeUninit::new(item); + } + debug_assert!(iter.next().is_none()); + + // TODO(MSRV): use `MaybeUninit::array_assume_init` when/if it's stabilized. + // SAFETY: We've initialized all N items in the array, so we can cast and "move" it. + unsafe { (&array as *const [MaybeUninit; N] as *const [T; N]).read() } +} diff --git a/src/iter/mod.rs b/src/iter/mod.rs index 50cd4dee7..d36047566 100644 --- a/src/iter/mod.rs +++ b/src/iter/mod.rs @@ -189,6 +189,10 @@ pub use self::{ zip_eq::ZipEq, }; +mod arrays; +#[cfg(has_min_const_generics)] +pub use self::arrays::Arrays; + mod step_by; #[cfg(has_step_by_rev)] pub use self::step_by::StepBy; @@ -2417,6 +2421,36 @@ pub trait IndexedParallelIterator: ParallelIterator { Chunks::new(self, chunk_size) } + /// Splits an iterator up into fixed-size arrays. + /// + /// Returns an iterator that returns arrays with the given number of elements. + /// If the number of elements in the iterator is not divisible by `N`, + /// the remaining items are ignored. + /// + /// See also [`par_array_chunks()`] and [`par_array_chunks_mut()`] for similar + /// behavior on slices, although they yield array references instead. + /// + /// [`par_array_chunks()`]: ../slice/trait.ParallelSlice.html#method.par_array_chunks + /// [`par_array_chunks_mut()`]: ../slice/trait.ParallelSliceMut.html#method.par_array_chunks_mut + /// + /// # Examples + /// + /// ``` + /// use rayon::prelude::*; + /// let a = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + /// let r: Vec<[i32; 3]> = a.into_par_iter().arrays().collect(); + /// assert_eq!(r, vec![[1, 2, 3], [4, 5, 6], [7, 8, 9]]); + /// ``` + /// + /// # Compatibility + /// + /// This method is only available on Rust 1.51 or greater. + #[cfg(has_min_const_generics)] + fn arrays(self) -> Arrays { + assert!(N != 0, "array length must not be zero"); + Arrays::new(self) + } + /// Lexicographically compares the elements of this `ParallelIterator` with those of /// another. /// diff --git a/tests/clones.rs b/tests/clones.rs index 2f512ca05..3823d5da0 100644 --- a/tests/clones.rs +++ b/tests/clones.rs @@ -138,6 +138,7 @@ fn clone_adaptors() { check(v.par_iter().interleave_shortest(&v)); check(v.par_iter().intersperse(&None)); check(v.par_iter().chunks(3)); + check(v.par_iter().arrays::<3>()); check(v.par_iter().map(|x| x)); check(v.par_iter().map_with(0, |_, x| x)); check(v.par_iter().map_init(|| 0, |_, x| x)); diff --git a/tests/debug.rs b/tests/debug.rs index d107b1377..b48777140 100644 --- a/tests/debug.rs +++ b/tests/debug.rs @@ -163,6 +163,7 @@ fn debug_adaptors() { check(v.par_iter().interleave_shortest(&v)); check(v.par_iter().intersperse(&-1)); check(v.par_iter().chunks(3)); + check(v.par_iter().arrays::<3>()); check(v.par_iter().map(|x| x)); check(v.par_iter().map_with(0, |_, x| x)); check(v.par_iter().map_init(|| 0, |_, x| x)); diff --git a/tests/producer_split_at.rs b/tests/producer_split_at.rs index d71050492..80ed07f82 100644 --- a/tests/producer_split_at.rs +++ b/tests/producer_split_at.rs @@ -343,6 +343,29 @@ fn chunks() { check(&v, || s.par_iter().cloned().chunks(2)); } +#[test] +fn arrays() { + use std::convert::TryInto; + fn check_len(s: &[i32]) { + let v: Vec<[_; N]> = s.chunks_exact(N).map(|c| c.try_into().unwrap()).collect(); + check(&v, || s.par_iter().copied().arrays::()); + } + + let s: Vec<_> = (0..10).collect(); + check_len::<1>(&s); + check_len::<2>(&s); + check_len::<3>(&s); + check_len::<4>(&s); + check_len::<5>(&s); + check_len::<6>(&s); + check_len::<7>(&s); + check_len::<8>(&s); + check_len::<9>(&s); + check_len::<10>(&s); + check_len::<11>(&s); + check_len::<12>(&s); +} + #[test] fn map() { let v: Vec<_> = (0..10).collect();