Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add const-generic IndexedParallelIterator::arrays #974

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 205 additions & 0 deletions src/iter/arrays.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
use super::plumbing::*;
use super::*;

/// `Arrays` is an iterator that groups elements of an underlying iterator.
///
/// This struct is created by the [`arrays()`] method on [`IndexedParallelIterator`]
///
/// [`arrays()`]: trait.IndexedParallelIterator.html#method.arrays
/// [`IndexedParallelIterator`]: trait.IndexedParallelIterator.html
#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
#[derive(Debug, Clone)]
pub struct Arrays<I, const N: usize>
where
I: IndexedParallelIterator,
{
iter: I,
}

impl<I, const N: usize> Arrays<I, N>
where
I: IndexedParallelIterator,
{
/// Creates a new `Arrays` iterator
pub(super) fn new(iter: I) -> Self {
Arrays { iter }
}
}

impl<I, const N: usize> ParallelIterator for Arrays<I, N>
where
I: IndexedParallelIterator,
{
type Item = [I::Item; N];

fn drive_unindexed<C>(self, consumer: C) -> C::Result
where
C: Consumer<Self::Item>,
{
bridge(self, consumer)
}

fn opt_len(&self) -> Option<usize> {
Some(self.len())
}
}

impl<I, const N: usize> IndexedParallelIterator for Arrays<I, N>
where
I: IndexedParallelIterator,
{
fn drive<C>(self, consumer: C) -> C::Result
where
C: Consumer<Self::Item>,
{
bridge(self, consumer)
}

fn len(&self) -> usize {
self.iter.len() / N
}

fn with_producer<CB>(self, callback: CB) -> CB::Output
where
CB: ProducerCallback<Self::Item>,
{
let len = self.iter.len();
return self.iter.with_producer(Callback { len, callback });

struct Callback<CB, const N: usize> {
len: usize,
callback: CB,
}

impl<T, CB, const N: usize> ProducerCallback<T> for Callback<CB, N>
where
CB: ProducerCallback<[T; N]>,
{
type Output = CB::Output;

fn callback<P>(self, base: P) -> CB::Output
where
P: Producer<Item = T>,
{
self.callback.callback(ArrayProducer {
len: self.len,
base,
})
}
}
}
}

struct ArrayProducer<P, const N: usize>
where
P: Producer,
{
len: usize,
base: P,
}

impl<P, const N: usize> Producer for ArrayProducer<P, N>
where
P: Producer,
{
type Item = [P::Item; N];
type IntoIter = ArraySeq<P, N>;

fn into_iter(self) -> Self::IntoIter {
// TODO: we're ignoring any remainder -- should we no-op consume it?
let remainder = self.len % N;
let len = self.len - remainder;
let inner = (len > 0).then(|| self.base.split_at(len).0);
ArraySeq { len, inner }
}

fn split_at(self, index: usize) -> (Self, Self) {
let elem_index = index * N;
let (left, right) = self.base.split_at(elem_index);
(
ArrayProducer {
len: elem_index,
base: left,
},
ArrayProducer {
len: self.len - elem_index,
base: right,
},
)
}

fn min_len(&self) -> usize {
self.base.min_len() / N
}

fn max_len(&self) -> usize {
self.base.max_len() / N
}
}

struct ArraySeq<P, const N: usize> {
len: usize,
inner: Option<P>,
}

impl<P, const N: usize> Iterator for ArraySeq<P, N>
where
P: Producer,
{
type Item = [P::Item; N];

fn next(&mut self) -> Option<Self::Item> {
let mut producer = self.inner.take()?;
debug_assert!(self.len > 0 && self.len % N == 0);
if self.len > N {
let (left, right) = producer.split_at(N);
producer = left;
self.inner = Some(right);
self.len -= N;
} else {
self.len = 0;
}
Some(collect_array(producer.into_iter()))
}

fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.len();
(len, Some(len))
}
}

impl<P, const N: usize> ExactSizeIterator for ArraySeq<P, N>
where
P: Producer,
{
#[inline]
fn len(&self) -> usize {
self.len / N
}
}

impl<P, const N: usize> DoubleEndedIterator for ArraySeq<P, N>
where
P: Producer,
{
fn next_back(&mut self) -> Option<Self::Item> {
let mut producer = self.inner.take()?;
debug_assert!(self.len > 0 && self.len % N == 0);
if self.len > N {
let (left, right) = producer.split_at(self.len - N);
producer = right;
self.inner = Some(left);
self.len -= N;
} else {
self.len = 0;
}
Some(collect_array(producer.into_iter()))
}
}

fn collect_array<T, const N: usize>(mut iter: impl ExactSizeIterator<Item = T>) -> [T; N] {
debug_assert_eq!(iter.len(), N);
let array = std::array::from_fn(|_| iter.next().expect("should have N items"));
debug_assert!(iter.next().is_none());
array
}
28 changes: 28 additions & 0 deletions src/iter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ mod test;
// e.g. `find::find()`, are always used **prefixed**, so that they
// can be readily distinguished.

mod arrays;
mod chain;
mod chunks;
mod cloned;
Expand Down Expand Up @@ -159,6 +160,7 @@ mod zip;
mod zip_eq;

pub use self::{
arrays::Arrays,
chain::Chain,
chunks::Chunks,
cloned::Cloned,
Expand Down Expand Up @@ -2544,6 +2546,32 @@ pub trait IndexedParallelIterator: ParallelIterator {
InterleaveShortest::new(self, other.into_par_iter())
}

/// Splits an iterator up into fixed-size arrays.
///
/// Returns an iterator that returns arrays with the given number of elements.
/// If the number of elements in the iterator is not divisible by `N`,
/// the remaining items are ignored.
///
/// See also [`par_array_chunks()`] and [`par_array_chunks_mut()`] for similar
/// behavior on slices, although they yield array references instead.
///
/// [`par_array_chunks()`]: ../slice/trait.ParallelSlice.html#method.par_array_chunks
/// [`par_array_chunks_mut()`]: ../slice/trait.ParallelSliceMut.html#method.par_array_chunks_mut
///
/// # Examples
///
/// ```
/// use rayon::prelude::*;
/// let a = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
/// let r: Vec<[i32; 3]> = a.into_par_iter().arrays().collect();
/// assert_eq!(r, vec![[1, 2, 3], [4, 5, 6], [7, 8, 9]]);
/// ```
#[track_caller]
fn arrays<const N: usize>(self) -> Arrays<Self, N> {
assert!(N != 0, "array length must not be zero");
Arrays::new(self)
}

/// Splits an iterator up into fixed-size chunks.
///
/// Returns an iterator that returns `Vec`s of the given number of elements.
Expand Down
1 change: 1 addition & 0 deletions tests/clones.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ fn clone_adaptors() {
check(v.par_iter().interleave_shortest(&v));
check(v.par_iter().intersperse(&None));
check(v.par_iter().chunks(3));
check(v.par_iter().arrays::<3>());
check(v.par_iter().map(|x| x));
check(v.par_iter().map_with(0, |_, x| x));
check(v.par_iter().map_init(|| 0, |_, x| x));
Expand Down
1 change: 1 addition & 0 deletions tests/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ fn debug_adaptors() {
check(v.par_iter().interleave_shortest(&v));
check(v.par_iter().intersperse(&-1));
check(v.par_iter().chunks(3));
check(v.par_iter().arrays::<3>());
check(v.par_iter().map(|x| x));
check(v.par_iter().map_with(0, |_, x| x));
check(v.par_iter().map_init(|| 0, |_, x| x));
Expand Down
23 changes: 23 additions & 0 deletions tests/producer_split_at.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,29 @@ fn chunks() {
check(&v, || s.par_iter().cloned().chunks(2));
}

#[test]
fn arrays() {
use std::convert::TryInto;
fn check_len<const N: usize>(s: &[i32]) {
let v: Vec<[_; N]> = s.chunks_exact(N).map(|c| c.try_into().unwrap()).collect();
check(&v, || s.par_iter().copied().arrays::<N>());
}

let s: Vec<_> = (0..10).collect();
check_len::<1>(&s);
check_len::<2>(&s);
check_len::<3>(&s);
check_len::<4>(&s);
check_len::<5>(&s);
check_len::<6>(&s);
check_len::<7>(&s);
check_len::<8>(&s);
check_len::<9>(&s);
check_len::<10>(&s);
check_len::<11>(&s);
check_len::<12>(&s);
}

#[test]
fn map() {
let v: Vec<_> = (0..10).collect();
Expand Down