Skip to content

Commit

Permalink
Auto merge of #121204 - cuviper:flatten-one-shot, r=the8472
Browse files Browse the repository at this point in the history
Specialize flattening iterators with only one inner item

For iterators like `Once` and `option::IntoIter` that only ever have a
single item at most, the front and back iterator states in `FlatMap` and
`Flatten` are a waste, as they're always consumed already. We can use
specialization for these types to simplify the iterator methods.

It's a somewhat common pattern to use `flatten()` for options and
results, even recommended by [multiple][1] [clippy][2] [lints][3]. The
implementation is more efficient with `filter_map`, as mentioned in
[clippy#9377], but this new specialization should close some of that
gap for existing code that flattens.

[1]: https://rust-lang.github.io/rust-clippy/master/#filter_map_identity
[2]: https://rust-lang.github.io/rust-clippy/master/#option_filter_map
[3]: https://rust-lang.github.io/rust-clippy/master/#result_filter_map
[clippy#9377]: rust-lang/rust-clippy#9377
  • Loading branch information
bors committed Feb 17, 2024
2 parents cabdf3a + c36ae93 commit 6672c16
Show file tree
Hide file tree
Showing 2 changed files with 275 additions and 12 deletions.
221 changes: 209 additions & 12 deletions library/core/src/iter/adapters/flatten.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::iter::{
Cloned, Copied, Filter, FilterMap, Fuse, FusedIterator, InPlaceIterable, Map, TrustedFused,
TrustedLen,
};
use crate::iter::{Once, OnceWith};
use crate::iter::{Empty, Once, OnceWith};
use crate::num::NonZero;
use crate::ops::{ControlFlow, Try};
use crate::result;
Expand Down Expand Up @@ -593,6 +593,7 @@ where
}
}

// See also the `OneShot` specialization below.
impl<I, U> Iterator for FlattenCompat<I, U>
where
I: Iterator<Item: IntoIterator<IntoIter = U, Item = U::Item>>,
Expand All @@ -601,7 +602,7 @@ where
type Item = U::Item;

#[inline]
fn next(&mut self) -> Option<U::Item> {
default fn next(&mut self) -> Option<U::Item> {
loop {
if let elt @ Some(_) = and_then_or_clear(&mut self.frontiter, Iterator::next) {
return elt;
Expand All @@ -614,7 +615,7 @@ where
}

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
default fn size_hint(&self) -> (usize, Option<usize>) {
let (flo, fhi) = self.frontiter.as_ref().map_or((0, Some(0)), U::size_hint);
let (blo, bhi) = self.backiter.as_ref().map_or((0, Some(0)), U::size_hint);
let lo = flo.saturating_add(blo);
Expand All @@ -636,7 +637,7 @@ where
}

#[inline]
fn try_fold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
default fn try_fold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
where
Self: Sized,
Fold: FnMut(Acc, Self::Item) -> R,
Expand All @@ -653,7 +654,7 @@ where
}

#[inline]
fn fold<Acc, Fold>(self, init: Acc, fold: Fold) -> Acc
default fn fold<Acc, Fold>(self, init: Acc, fold: Fold) -> Acc
where
Fold: FnMut(Acc, Self::Item) -> Acc,
{
Expand All @@ -669,7 +670,7 @@ where

#[inline]
#[rustc_inherit_overflow_checks]
fn advance_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
default fn advance_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
#[inline]
#[rustc_inherit_overflow_checks]
fn advance<U: Iterator>(n: usize, iter: &mut U) -> ControlFlow<(), usize> {
Expand All @@ -686,7 +687,7 @@ where
}

#[inline]
fn count(self) -> usize {
default fn count(self) -> usize {
#[inline]
#[rustc_inherit_overflow_checks]
fn count<U: Iterator>(acc: usize, iter: U) -> usize {
Expand All @@ -697,7 +698,7 @@ where
}

#[inline]
fn last(self) -> Option<Self::Item> {
default fn last(self) -> Option<Self::Item> {
#[inline]
fn last<U: Iterator>(last: Option<U::Item>, iter: U) -> Option<U::Item> {
iter.last().or(last)
Expand All @@ -707,13 +708,14 @@ where
}
}

// See also the `OneShot` specialization below.
impl<I, U> DoubleEndedIterator for FlattenCompat<I, U>
where
I: DoubleEndedIterator<Item: IntoIterator<IntoIter = U, Item = U::Item>>,
U: DoubleEndedIterator,
{
#[inline]
fn next_back(&mut self) -> Option<U::Item> {
default fn next_back(&mut self) -> Option<U::Item> {
loop {
if let elt @ Some(_) = and_then_or_clear(&mut self.backiter, |b| b.next_back()) {
return elt;
Expand All @@ -726,7 +728,7 @@ where
}

#[inline]
fn try_rfold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
default fn try_rfold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
where
Self: Sized,
Fold: FnMut(Acc, Self::Item) -> R,
Expand All @@ -743,7 +745,7 @@ where
}

#[inline]
fn rfold<Acc, Fold>(self, init: Acc, fold: Fold) -> Acc
default fn rfold<Acc, Fold>(self, init: Acc, fold: Fold) -> Acc
where
Fold: FnMut(Acc, Self::Item) -> Acc,
{
Expand All @@ -759,7 +761,7 @@ where

#[inline]
#[rustc_inherit_overflow_checks]
fn advance_back_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
default fn advance_back_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
#[inline]
#[rustc_inherit_overflow_checks]
fn advance<U: DoubleEndedIterator>(n: usize, iter: &mut U) -> ControlFlow<(), usize> {
Expand Down Expand Up @@ -841,3 +843,198 @@ fn and_then_or_clear<T, U>(opt: &mut Option<T>, f: impl FnOnce(&mut T) -> Option
}
x
}

/// Specialization trait for iterator types that never return more than one item.
///
/// Note that we still have to deal with the possibility that the iterator was
/// already exhausted before it came into our control.
#[rustc_specialization_trait]
trait OneShot {}

// These all have exactly one item, if not already consumed.
impl<T> OneShot for Once<T> {}
impl<F> OneShot for OnceWith<F> {}
impl<T> OneShot for array::IntoIter<T, 1> {}
impl<T> OneShot for option::IntoIter<T> {}
impl<T> OneShot for option::Iter<'_, T> {}
impl<T> OneShot for option::IterMut<'_, T> {}
impl<T> OneShot for result::IntoIter<T> {}
impl<T> OneShot for result::Iter<'_, T> {}
impl<T> OneShot for result::IterMut<'_, T> {}

// These are always empty, which is fine to optimize too.
impl<T> OneShot for Empty<T> {}
impl<T> OneShot for array::IntoIter<T, 0> {}

// These adaptors never increase the number of items.
// (There are more possible, but for now this matches BoundedSize above.)
impl<I: OneShot> OneShot for Cloned<I> {}
impl<I: OneShot> OneShot for Copied<I> {}
impl<I: OneShot, P> OneShot for Filter<I, P> {}
impl<I: OneShot, P> OneShot for FilterMap<I, P> {}
impl<I: OneShot, F> OneShot for Map<I, F> {}

// Blanket impls pass this property through as well
// (but we can't do `Box<I>` unless we expose this trait to alloc)
impl<I: OneShot> OneShot for &mut I {}

#[inline]
fn into_item<I>(inner: I) -> Option<I::Item>
where
I: IntoIterator<IntoIter: OneShot>,
{
inner.into_iter().next()
}

#[inline]
fn flatten_one<I: IntoIterator<IntoIter: OneShot>, Acc>(
mut fold: impl FnMut(Acc, I::Item) -> Acc,
) -> impl FnMut(Acc, I) -> Acc {
move |acc, inner| match inner.into_iter().next() {
Some(item) => fold(acc, item),
None => acc,
}
}

#[inline]
fn try_flatten_one<I: IntoIterator<IntoIter: OneShot>, Acc, R: Try<Output = Acc>>(
mut fold: impl FnMut(Acc, I::Item) -> R,
) -> impl FnMut(Acc, I) -> R {
move |acc, inner| match inner.into_iter().next() {
Some(item) => fold(acc, item),
None => try { acc },
}
}

#[inline]
fn advance_by_one<I>(n: NonZero<usize>, inner: I) -> Option<NonZero<usize>>
where
I: IntoIterator<IntoIter: OneShot>,
{
match inner.into_iter().next() {
Some(_) => NonZero::new(n.get() - 1),
None => Some(n),
}
}

// Specialization: When the inner iterator `U` never returns more than one item, the `frontiter` and
// `backiter` states are a waste, because they'll always have already consumed their item. So in
// this impl, we completely ignore them and just focus on `self.iter`, and we only call the inner
// `U::next()` one time.
//
// It's mostly fine if we accidentally mix this with the more generic impls, e.g. by forgetting to
// specialize one of the methods. If the other impl did set the front or back, we wouldn't see it
// here, but it would be empty anyway; and if the other impl looked for a front or back that we
// didn't bother setting, it would just see `None` (or a previous empty) and move on.
//
// An exception to that is `advance_by(0)` and `advance_back_by(0)`, where the generic impls may set
// `frontiter` or `backiter` without consuming the item, so we **must** override those.
impl<I, U> Iterator for FlattenCompat<I, U>
where
I: Iterator<Item: IntoIterator<IntoIter = U, Item = U::Item>>,
U: Iterator + OneShot,
{
#[inline]
fn next(&mut self) -> Option<U::Item> {
while let Some(inner) = self.iter.next() {
if let item @ Some(_) = inner.into_iter().next() {
return item;
}
}
None
}

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let (lower, upper) = self.iter.size_hint();
match <I::Item as ConstSizeIntoIterator>::size() {
Some(0) => (0, Some(0)),
Some(1) => (lower, upper),
_ => (0, upper),
}
}

#[inline]
fn try_fold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
where
Self: Sized,
Fold: FnMut(Acc, Self::Item) -> R,
R: Try<Output = Acc>,
{
self.iter.try_fold(init, try_flatten_one(fold))
}

#[inline]
fn fold<Acc, Fold>(self, init: Acc, fold: Fold) -> Acc
where
Fold: FnMut(Acc, Self::Item) -> Acc,
{
self.iter.fold(init, flatten_one(fold))
}

#[inline]
fn advance_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
if let Some(n) = NonZero::new(n) {
self.iter.try_fold(n, advance_by_one).map_or(Ok(()), Err)
} else {
// Just advance the outer iterator
self.iter.advance_by(0)
}
}

#[inline]
fn count(self) -> usize {
self.iter.filter_map(into_item).count()
}

#[inline]
fn last(self) -> Option<Self::Item> {
self.iter.filter_map(into_item).last()
}
}

// Note: We don't actually care about `U: DoubleEndedIterator`, since forward and backward are the
// same for a one-shot iterator, but we have to keep that to match the default specialization.
impl<I, U> DoubleEndedIterator for FlattenCompat<I, U>
where
I: DoubleEndedIterator<Item: IntoIterator<IntoIter = U, Item = U::Item>>,
U: DoubleEndedIterator + OneShot,
{
#[inline]
fn next_back(&mut self) -> Option<U::Item> {
while let Some(inner) = self.iter.next_back() {
if let item @ Some(_) = inner.into_iter().next() {
return item;
}
}
None
}

#[inline]
fn try_rfold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
where
Self: Sized,
Fold: FnMut(Acc, Self::Item) -> R,
R: Try<Output = Acc>,
{
self.iter.try_rfold(init, try_flatten_one(fold))
}

#[inline]
fn rfold<Acc, Fold>(self, init: Acc, fold: Fold) -> Acc
where
Fold: FnMut(Acc, Self::Item) -> Acc,
{
self.iter.rfold(init, flatten_one(fold))
}

#[inline]
fn advance_back_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
if let Some(n) = NonZero::new(n) {
self.iter.try_rfold(n, advance_by_one).map_or(Ok(()), Err)
} else {
// Just advance the outer iterator
self.iter.advance_back_by(0)
}
}
}
66 changes: 66 additions & 0 deletions library/core/tests/iter/adapters/flatten.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,69 @@ fn test_flatten_last() {
assert_eq!(it.advance_by(3), Ok(())); // 22..22
assert_eq!(it.clone().last(), None);
}

#[test]
fn test_flatten_one_shot() {
// This could be `filter_map`, but people often do flatten options.
let mut it = (0i8..10).flat_map(|i| NonZero::new(i % 7));
assert_eq!(it.size_hint(), (0, Some(10)));
assert_eq!(it.clone().count(), 8);
assert_eq!(it.clone().last(), NonZero::new(2));

// sum -> fold
let sum: i8 = it.clone().map(|n| n.get()).sum();
assert_eq!(sum, 24);

// the product overflows at 6, remaining are 7,8,9 -> 1,2
let one = NonZero::new(1i8).unwrap();
let product = it.try_fold(one, |acc, x| acc.checked_mul(x));
assert_eq!(product, None);
assert_eq!(it.size_hint(), (0, Some(3)));
assert_eq!(it.clone().count(), 2);

assert_eq!(it.advance_by(0), Ok(()));
assert_eq!(it.clone().next(), NonZero::new(1));
assert_eq!(it.advance_by(1), Ok(()));
assert_eq!(it.clone().next(), NonZero::new(2));
assert_eq!(it.advance_by(100), Err(NonZero::new(99).unwrap()));
assert_eq!(it.next(), None);
}

#[test]
fn test_flatten_one_shot_rev() {
let mut it = (0i8..10).flat_map(|i| NonZero::new(i % 7)).rev();
assert_eq!(it.size_hint(), (0, Some(10)));
assert_eq!(it.clone().count(), 8);
assert_eq!(it.clone().last(), NonZero::new(1));

// sum -> Rev fold -> rfold
let sum: i8 = it.clone().map(|n| n.get()).sum();
assert_eq!(sum, 24);

// Rev try_fold -> try_rfold
// the product overflows at 4, remaining are 3,2,1,0 -> 3,2,1
let one = NonZero::new(1i8).unwrap();
let product = it.try_fold(one, |acc, x| acc.checked_mul(x));
assert_eq!(product, None);
assert_eq!(it.size_hint(), (0, Some(4)));
assert_eq!(it.clone().count(), 3);

// Rev advance_by -> advance_back_by
assert_eq!(it.advance_by(0), Ok(()));
assert_eq!(it.clone().next(), NonZero::new(3));
assert_eq!(it.advance_by(1), Ok(()));
assert_eq!(it.clone().next(), NonZero::new(2));
assert_eq!(it.advance_by(100), Err(NonZero::new(98).unwrap()));
assert_eq!(it.next(), None);
}

#[test]
fn test_flatten_one_shot_arrays() {
let it = (0..10).flat_map(|i| [i]);
assert_eq!(it.size_hint(), (10, Some(10)));
assert_eq!(it.sum::<i32>(), 45);

let mut it = (0..10).flat_map(|_| -> [i32; 0] { [] });
assert_eq!(it.size_hint(), (0, Some(0)));
assert_eq!(it.next(), None);
}

0 comments on commit 6672c16

Please sign in to comment.