diff --git a/src/combinations.rs b/src/combinations.rs index a5b34fc95..4b0742a0d 100644 --- a/src/combinations.rs +++ b/src/combinations.rs @@ -77,6 +77,12 @@ impl Combinations { self.pool.prefill(k); } } + + pub(crate) fn n_and_count(self) -> (usize, usize) { + let Self { indices, pool, first } = self; + let n = pool.count(); + (n, remaining_for(n, first, &indices).unwrap()) + } } impl Iterator for Combinations @@ -128,10 +134,9 @@ impl Iterator for Combinations (low, upp) } + #[inline] fn count(self) -> usize { - let Self { indices, pool, first } = self; - let n = pool.count(); - remaining_for(n, first, &indices).unwrap() + self.n_and_count().1 } } diff --git a/src/powerset.rs b/src/powerset.rs index 4175f013b..f84cecfbd 100644 --- a/src/powerset.rs +++ b/src/powerset.rs @@ -3,8 +3,8 @@ use std::iter::FusedIterator; use std::usize; use alloc::vec::Vec; -use super::combinations::{Combinations, combinations}; -use super::size_hint; +use super::combinations::{Combinations, checked_binomial, combinations}; +use crate::size_hint::{self, SizeHint}; /// An iterator to iterate through the powerset of the elements from an iterator. /// @@ -13,22 +13,20 @@ use super::size_hint; #[must_use = "iterator adaptors are lazy and do nothing unless consumed"] pub struct Powerset { combs: Combinations, - // Iterator `position` (equal to count of yielded elements). - pos: usize, } impl Clone for Powerset where I: Clone + Iterator, I::Item: Clone, { - clone_fields!(combs, pos); + clone_fields!(combs); } impl fmt::Debug for Powerset where I: Iterator + fmt::Debug, I::Item: fmt::Debug, { - debug_fmt_fields!(Powerset, combs, pos); + debug_fmt_fields!(Powerset, combs); } /// Create a new `Powerset` from a clonable iterator. @@ -38,7 +36,6 @@ pub fn powerset(src: I) -> Powerset { Powerset { combs: combinations(src, 0), - pos: 0, } } @@ -51,35 +48,30 @@ impl Iterator for Powerset fn next(&mut self) -> Option { if let Some(elt) = self.combs.next() { - self.pos = self.pos.saturating_add(1); Some(elt) } else if self.combs.k() < self.combs.n() || self.combs.k() == 0 { self.combs.reset(self.combs.k() + 1); - self.combs.next().map(|elt| { - self.pos = self.pos.saturating_add(1); - elt - }) + self.combs.next() } else { None } } - fn size_hint(&self) -> (usize, Option) { + fn size_hint(&self) -> SizeHint { + let k = self.combs.k(); // Total bounds for source iterator. - let src_total = self.combs.src().size_hint(); - - // Total bounds for self ( length(powerset(set) == 2 ^ length(set) ) - let self_total = size_hint::pow_scalar_base(2, src_total); + let (n_min, n_max) = self.combs.src().size_hint(); + let low = remaining_for(n_min, k).unwrap_or(usize::MAX); + let upp = n_max.and_then(|n| remaining_for(n, k)); + size_hint::add(self.combs.size_hint(), (low, upp)) + } - if self.pos < usize::MAX { - // Subtract count of elements already yielded from total. - size_hint::sub_scalar(self_total, self.pos) - } else { - // Fallback: self.pos is saturated and no longer reliable. - (0, self_total.1) - } + fn count(self) -> usize { + let k = self.combs.k(); + let (n, combs_count) = self.combs.n_and_count(); + combs_count + remaining_for(n, k).unwrap() } } @@ -88,3 +80,9 @@ impl FusedIterator for Powerset I: Iterator, I::Item: Clone, {} + +fn remaining_for(n: usize, k: usize) -> Option { + (k + 1..=n).fold(Some(0), |sum, i| { + sum.and_then(|s| s.checked_add(checked_binomial(n, i)?)) + }) +} diff --git a/src/size_hint.rs b/src/size_hint.rs index 71ea1412b..f7278aec9 100644 --- a/src/size_hint.rs +++ b/src/size_hint.rs @@ -3,7 +3,6 @@ use std::usize; use std::cmp; -use std::u32; /// `SizeHint` is the return type of `Iterator::size_hint()`. pub type SizeHint = (usize, Option); @@ -75,20 +74,6 @@ pub fn mul_scalar(sh: SizeHint, x: usize) -> SizeHint { (low, hi) } -/// Raise `base` correctly by a `SizeHint` exponent. -#[inline] -pub fn pow_scalar_base(base: usize, exp: SizeHint) -> SizeHint { - let exp_low = cmp::min(exp.0, u32::MAX as usize) as u32; - let low = base.saturating_pow(exp_low); - - let hi = exp.1.and_then(|exp| { - let exp_hi = cmp::min(exp, u32::MAX as usize) as u32; - base.checked_pow(exp_hi) - }); - - (low, hi) -} - /// Return the maximum #[inline] pub fn max(a: SizeHint, b: SizeHint) -> SizeHint { diff --git a/tests/test_std.rs b/tests/test_std.rs index 3d6d66cf1..8ea992183 100644 --- a/tests/test_std.rs +++ b/tests/test_std.rs @@ -909,15 +909,19 @@ fn combinations_zero() { it::assert_equal((0..0).combinations(0), vec![vec![]]); } +fn binomial(n: usize, k: usize) -> usize { + if k > n { + 0 + } else { + (n - k + 1..=n).product::() / (1..=k).product::() + } +} + #[test] fn combinations_range_count() { for n in 0..=10 { for k in 0..=10 { - let len = if k<=n { - (n - k + 1..=n).product::() / (1..=k).product::() - } else { - 0 - }; + let len = binomial(n, k); let mut it = (0..n).combinations(k); assert_eq!(len, it.clone().count()); assert_eq!(len, it.size_hint().0); @@ -935,6 +939,47 @@ fn combinations_range_count() { } } +#[test] +fn combinations_inexact_size_hints() { + for k in 0..=10 { + let mut numbers = (0..18).filter(|i| i % 2 == 0); // 9 elements + let mut it = numbers.clone().combinations(k); + let real_n = numbers.clone().count(); + let len = binomial(real_n, k); + assert_eq!(len, it.clone().count()); + + let mut nb_loaded = numbers.by_ref().take(k).count(); // because of `LazyBuffer::prefill(k)` + let sh = numbers.size_hint(); + assert_eq!(binomial(sh.0 + nb_loaded, k), it.size_hint().0); + assert_eq!(sh.1.map(|n| binomial(n + nb_loaded, k)), it.size_hint().1); + + for next_count in 1..=len { + let elem = it.next(); + assert!(elem.is_some()); + assert_eq!(len - next_count, it.clone().count()); + // It does not load anything more the very first time (it's prefilled). + if next_count > 1 { + // Then it loads one item each time until exhausted. + let nb = numbers.next(); + if nb.is_some() { + nb_loaded += 1; + } + } + let sh = numbers.size_hint(); + if next_count > real_n - k + 1 { + assert_eq!(0, sh.0); + assert_eq!(Some(0), sh.1); + assert_eq!(real_n, nb_loaded); + // Once it's fully loaded, size hints of `it` are exacts. + } + assert_eq!(binomial(sh.0 + nb_loaded, k) - next_count, it.size_hint().0); + assert_eq!(sh.1.map(|n| binomial(n + nb_loaded, k) - next_count), it.size_hint().1); + } + let should_be_none = it.next(); + assert!(should_be_none.is_none()); + } +} + #[test] fn permutations_zero() { it::assert_equal((1..3).permutations(0), vec![vec![]]); @@ -989,6 +1034,23 @@ fn powerset() { assert_eq!((0..4).powerset().count(), 1 << 4); assert_eq!((0..8).powerset().count(), 1 << 8); assert_eq!((0..16).powerset().count(), 1 << 16); + + for n in 0..=10 { + let mut it = (0..n).powerset(); + let len = 2_usize.pow(n); + assert_eq!(len, it.clone().count()); + assert_eq!(len, it.size_hint().0); + assert_eq!(Some(len), it.size_hint().1); + for count in (0..len).rev() { + let elem = it.next(); + assert!(elem.is_some()); + assert_eq!(count, it.clone().count()); + assert_eq!(count, it.size_hint().0); + assert_eq!(Some(count), it.size_hint().1); + } + let should_be_none = it.next(); + assert!(should_be_none.is_none()); + } } #[test]