Skip to content

Commit

Permalink
Add const-generic par_array_chunks_mut
Browse files Browse the repository at this point in the history
  • Loading branch information
cuviper committed Dec 12, 2023
1 parent caad08a commit 12d4f24
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 2 deletions.
94 changes: 93 additions & 1 deletion src/slice/array.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::iter::plumbing::*;
use crate::iter::*;

use super::Iter;
use super::{Iter, IterMut};

/// Parallel iterator over immutable non-overlapping chunks of a slice
#[derive(Debug)]
Expand Down Expand Up @@ -78,3 +78,95 @@ impl<'data, T: Sync + 'data, const N: usize> IndexedParallelIterator for ArrayCh
self.iter.with_producer(callback)
}
}

/// Parallel iterator over immutable non-overlapping chunks of a slice
#[derive(Debug)]
pub struct ArrayChunksMut<'data, T: Send, const N: usize> {
iter: IterMut<'data, [T; N]>,
rem: &'data mut [T],
}

impl<'data, T: Send, const N: usize> ArrayChunksMut<'data, T, N> {
pub(super) fn new(slice: &'data mut [T]) -> Self {
assert_ne!(N, 0);
let len = slice.len() / N;
let (fst, snd) = slice.split_at_mut(len * N);
// SAFETY: We cast a slice of `len * N` elements into
// a slice of `len` many `N` elements chunks.
let array_slice: &'data mut [[T; N]] = unsafe {
let ptr = fst.as_mut_ptr() as *mut [T; N];
::std::slice::from_raw_parts_mut(ptr, len)
};
Self {
iter: array_slice.par_iter_mut(),
rem: snd,
}
}

/// Return the remainder of the original slice that is not going to be
/// returned by the iterator. The returned slice has at most `N-1`
/// elements.
///
/// Note that this has to consume `self` to return the original lifetime of
/// the data, which prevents this from actually being used as a parallel
/// iterator since that also consumes. This method is provided for parity
/// with `std::iter::ArrayChunksMut`, but consider calling `remainder()` or
/// `take_remainder()` as alternatives.
pub fn into_remainder(self) -> &'data mut [T] {
self.rem
}

/// Return the remainder of the original slice that is not going to be
/// returned by the iterator. The returned slice has at most `N-1`
/// elements.
///
/// Consider `take_remainder()` if you need access to the data with its
/// original lifetime, rather than borrowing through `&mut self` here.
pub fn remainder(&mut self) -> &mut [T] {
self.rem
}

/// Return the remainder of the original slice that is not going to be
/// returned by the iterator. The returned slice has at most `N-1`
/// elements. Subsequent calls will return an empty slice.
pub fn take_remainder(&mut self) -> &'data mut [T] {
std::mem::replace(&mut self.rem, &mut [])
}
}

impl<'data, T: Send + 'data, const N: usize> ParallelIterator for ArrayChunksMut<'data, T, N> {
type Item = &'data mut [T; N];

fn drive_unindexed<C>(self, consumer: C) -> C::Result
where
C: UnindexedConsumer<Self::Item>,
{
bridge(self, consumer)
}

fn opt_len(&self) -> Option<usize> {
Some(self.len())
}
}

impl<'data, T: Send + 'data, const N: usize> IndexedParallelIterator
for ArrayChunksMut<'data, T, N>
{
fn drive<C>(self, consumer: C) -> C::Result
where
C: Consumer<Self::Item>,
{
bridge(self, consumer)
}

fn len(&self) -> usize {
self.iter.len()
}

fn with_producer<CB>(self, callback: CB) -> CB::Output
where
CB: ProducerCallback<Self::Item>,
{
self.iter.with_producer(callback)
}
}
22 changes: 21 additions & 1 deletion src/slice/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ mod rchunks;
mod test;

#[cfg(min_const_generics)]
pub use self::array::ArrayChunks;
pub use self::array::{ArrayChunks, ArrayChunksMut};

use self::mergesort::par_mergesort;
use self::quicksort::par_quicksort;
Expand Down Expand Up @@ -297,6 +297,26 @@ pub trait ParallelSliceMut<T: Send> {
RChunksExactMut::new(chunk_size, self.as_parallel_slice_mut())
}

/// Returns a parallel iterator over `N`-element chunks of
/// `self` at a time. The chunks are mutable and do not overlap.
///
/// If `N` does not divide the length of the slice, then the
/// last up to `N-1` elements will be omitted and can be
/// retrieved from the remainder function of the iterator.
///
/// # Examples
///
/// ```
/// use rayon::prelude::*;
/// let mut array = [1, 2, 3, 4, 5];
/// array.par_array_chunks_mut()
/// .for_each(|[a, _, b]| std::mem::swap(a, b));
/// assert_eq!(array, [3, 2, 1, 4, 5]);
/// ```
fn par_array_chunks_mut<const N: usize>(&mut self) -> ArrayChunksMut<'_, T, N> {
ArrayChunksMut::new(self.as_parallel_slice_mut())
}

/// Sorts the slice in parallel.
///
/// This sort is stable (i.e., does not reorder equal elements) and *O*(*n* \* log(*n*)) worst-case.
Expand Down
14 changes: 14 additions & 0 deletions src/slice/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,17 @@ fn test_par_array_chunks_remainder() {
assert_eq!(c.remainder(), &[4]);
assert_eq!(c.len(), 2);
}

#[test]
fn test_par_array_chunks_mut_remainder() {
let v: &mut [i32] = &mut [0, 1, 2, 3, 4];
let mut c = v.par_array_chunks_mut::<2>();
assert_eq!(c.remainder(), &[4]);
assert_eq!(c.len(), 2);
assert_eq!(c.into_remainder(), &[4]);

let mut c = v.par_array_chunks_mut::<2>();
assert_eq!(c.take_remainder(), &[4]);
assert_eq!(c.take_remainder(), &[]);
assert_eq!(c.len(), 2);
}
1 change: 1 addition & 0 deletions tests/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ fn debug_vec() {
check(v.par_array_chunks::<42>());
check(v.par_chunks_mut(42));
check(v.par_chunks_exact_mut(42));
check(v.par_array_chunks_mut::<42>());
check(v.par_rchunks(42));
check(v.par_rchunks_exact(42));
check(v.par_rchunks_mut(42));
Expand Down
34 changes: 34 additions & 0 deletions tests/producer_split_at.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,40 @@ fn slice_chunks_exact_mut() {
}
}

#[test]
fn slice_array_chunks_mut() {
use std::convert::{TryFrom, TryInto};
fn check_len<const N: usize>(s: &mut [i32], v: &mut [i32])
where
for<'a> &'a mut [i32; N]: PartialEq + TryFrom<&'a mut [i32]> + std::fmt::Debug,
{
// TODO: use https://github.com/rust-lang/rust/pull/74373 instead.
let expected: Vec<_> = v
.chunks_exact_mut(N)
.map(|s| s.try_into().ok().unwrap())
.collect();
map_triples(expected.len() + 1, |i, j, k| {
Split::forward(s.par_array_chunks_mut::<N>(), i, j, k, &expected);
Split::reverse(s.par_array_chunks_mut::<N>(), i, j, k, &expected);
});
}

let mut s: Vec<_> = (0..10).collect();
let mut v: Vec<_> = s.clone();
check_len::<1>(&mut s, &mut v);
check_len::<2>(&mut s, &mut v);
check_len::<3>(&mut s, &mut v);
check_len::<4>(&mut s, &mut v);
check_len::<5>(&mut s, &mut v);
check_len::<6>(&mut s, &mut v);
check_len::<7>(&mut s, &mut v);
check_len::<8>(&mut s, &mut v);
check_len::<9>(&mut s, &mut v);
check_len::<10>(&mut s, &mut v);
check_len::<11>(&mut s, &mut v);
check_len::<12>(&mut s, &mut v);
}

#[test]
fn slice_rchunks() {
let s: Vec<_> = (0..10).collect();
Expand Down

0 comments on commit 12d4f24

Please sign in to comment.