Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More size_hint and count methods #725

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
48 changes: 46 additions & 2 deletions src/combinations.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use std::fmt;
use std::iter::FusedIterator;
use std::iter::{Fuse, FusedIterator};

use super::lazy_buffer::LazyBuffer;
use super::size_hint::{self, SizeHint};
use alloc::vec::Vec;

/// An iterator to iterate through all the `k`-length combinations in an iterator.
Expand Down Expand Up @@ -52,9 +53,15 @@ impl<I: Iterator> Combinations<I> {
#[inline]
pub fn n(&self) -> usize { self.pool.len() }

/// Pulls every remaining item of the source into the pool, then returns the pool length.
///
/// Calls `get_next` on the pool until it reports `false`.
pub(crate) fn real_n(&mut self) -> usize {
    loop {
        if !self.pool.get_next() {
            break self.pool.len();
        }
    }
}

/// Returns a reference to the source iterator.
#[inline]
pub(crate) fn src(&self) -> &I { &self.pool.it }
pub(crate) fn src(&self) -> &Fuse<I> { &self.pool.it }

/// Resets this `Combinations` back to an initial state for combinations of length
/// `k` over the same pool data source. If `k` is larger than the current length
Expand All @@ -77,6 +84,20 @@ impl<I: Iterator> Combinations<I> {
self.pool.prefill(k);
}
}

/// Counts the combinations still to be produced if the pool held `n` elements.
///
/// While `self.first` is set the answer is `binomial(n, k)`. Otherwise it is the checked
/// sum of `binomial(n - 1 - index, k - position)` over the current indices.
///
/// Returns `None` if the count does not fit in a `usize`.
fn remaining_for(&self, n: usize) -> Option<usize> {
    let k = self.k();
    if self.first {
        return binomial(n, k);
    }
    let mut total: usize = 0;
    for (position, &index) in self.indices.iter().enumerate() {
        // Bail out with `None` as soon as a term or the running sum overflows.
        let term = binomial(n - 1 - index, k - position)?;
        total = total.checked_add(term)?;
    }
    Some(total)
}
}

impl<I> Iterator for Combinations<I>
Expand Down Expand Up @@ -120,9 +141,32 @@ impl<I> Iterator for Combinations<I>
// Create result vector based on the indices
Some(self.indices.iter().map(|i| self.pool[*i].clone()).collect())
}

/// Maps both bounds of the pool's size hint through `remaining_for`.
fn size_hint(&self) -> SizeHint {
    let pool_bounds = self.pool.size_hint();
    let remaining = |n: usize| self.remaining_for(n);
    size_hint::try_map(pool_bounds, remaining)
}

/// Consumes the iterator and returns how many combinations were left.
///
/// # Panics
///
/// Panics if that number does not fit in a `usize`.
fn count(mut self) -> usize {
    let pool_len = self.real_n();
    let remaining = self.remaining_for(pool_len);
    remaining.expect("Iterator count greater than usize::MAX")
}
}

/// Marker impl: advertises that once `next` has returned `None` it keeps returning `None`.
impl<I> FusedIterator for Combinations<I>
where
    I: Iterator,
    I::Item: Clone,
{
}

/// Computes the binomial coefficient "`n` choose `k`".
///
/// Returns `Some(0)` when `n < k`, and `None` when the computation overflows a `usize`.
pub(crate) fn binomial(mut n: usize, mut k: usize) -> Option<usize> {
    if n < k {
        return Some(0);
    }
    // Equivalent to `n! / (n - k)! / k!`, but built up incrementally to avoid overflowing early.
    // By symmetry (`C(n, k) == C(n, n - k)`) the shorter of the two loops is enough.
    k = (n - k).min(k);
    let mut c: usize = 1;
    for i in 1..=k {
        // Computes `c * n / i` by splitting `c` into quotient and remainder modulo `i`, which
        // keeps the intermediate products small. Every step is checked so that an overflow
        // is reported as `None` rather than panicking or silently wrapping.
        let quotient_part = (c / i).checked_mul(n)?;
        let remainder_part = (c % i).checked_mul(n)? / i;
        c = quotient_part.checked_add(remainder_part)?;
        n -= 1;
    }
    Some(c)
}
27 changes: 27 additions & 0 deletions src/combinations_with_replacement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use alloc::vec::Vec;
use std::fmt;
use std::iter::FusedIterator;

use super::combinations::binomial;
use super::lazy_buffer::LazyBuffer;
use super::size_hint::{self, SizeHint};

/// An iterator to iterate through all the `n`-length combinations in an iterator, with replacement.
///
Expand Down Expand Up @@ -36,6 +38,21 @@ where
fn current(&self) -> Vec<I::Item> {
self.indices.iter().map(|i| self.pool[*i].clone()).collect()
}

/// Counts the combinations still to be produced if the pool held `n` elements.
///
/// While `self.first` is set the answer is the number of size-`k` multisets over `n` values.
/// Otherwise it is the checked sum of that multiset count for `(n - 1 - index, k - position)`
/// over the current indices.
///
/// Returns `None` if the count does not fit in a `usize`.
fn remaining_for(&self, n: usize) -> Option<usize> {
    // Size-`k` multisets over `n` values: `binomial(n + k - 1, k)`.
    // `n` comes from a size hint and may be as large as `usize::MAX`, so a plain `n + k`
    // could overflow; the position count is therefore computed with checked arithmetic.
    let multisets = |n: usize, k: usize| {
        let positions = if n == 0 {
            k.saturating_sub(1)
        } else {
            (n - 1).checked_add(k)?
        };
        binomial(positions, k)
    };
    let k = self.indices.len();
    if self.first {
        multisets(n, k)
    } else {
        self.indices
            .iter()
            .enumerate()
            .try_fold(0usize, |sum, (position, &index)| {
                sum.checked_add(multisets(n - 1 - index, k - position)?)
            })
    }
}
}

/// Create a new `CombinationsWithReplacement` from a clonable iterator.
Expand Down Expand Up @@ -100,6 +117,16 @@ where
None => None,
}
}

/// Maps both bounds of the pool's size hint through `remaining_for`.
fn size_hint(&self) -> SizeHint {
    let pool_bounds = self.pool.size_hint();
    let remaining = |n: usize| self.remaining_for(n);
    size_hint::try_map(pool_bounds, remaining)
}

/// Consumes the iterator and returns how many combinations were left.
///
/// # Panics
///
/// Panics if that number does not fit in a `usize`.
fn count(mut self) -> usize {
    // Exhaust the source first so the pool length is the full element count.
    loop {
        if !self.pool.get_next() {
            break;
        }
    }
    let pool_len = self.pool.len();
    let remaining = self.remaining_for(pool_len);
    remaining.expect("Iterator count greater than usize::MAX")
}
}

impl<I> FusedIterator for CombinationsWithReplacement<I>
Expand Down
22 changes: 10 additions & 12 deletions src/lazy_buffer.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use std::iter::Fuse;
use std::ops::Index;
use alloc::vec::Vec;

use crate::size_hint::{self, SizeHint};

#[derive(Debug, Clone)]
pub struct LazyBuffer<I: Iterator> {
pub it: I,
done: bool,
pub it: Fuse<I>,
buffer: Vec<I::Item>,
}

Expand All @@ -14,8 +16,7 @@ where
{
pub fn new(it: I) -> LazyBuffer<I> {
LazyBuffer {
it,
done: false,
it: it.fuse(),
buffer: Vec::new(),
}
}
Expand All @@ -24,27 +25,24 @@ where
self.buffer.len()
}

/// Combines the source iterator's size hint with the current buffer length
/// via `size_hint::add_scalar`.
pub fn size_hint(&self) -> SizeHint {
    let source_bounds = self.it.size_hint();
    let buffered = self.len();
    size_hint::add_scalar(source_bounds, buffered)
}

pub fn get_next(&mut self) -> bool {
if self.done {
return false;
}
if let Some(x) = self.it.next() {
self.buffer.push(x);
true
} else {
self.done = true;
false
}
}

pub fn prefill(&mut self, len: usize) {
let buffer_len = self.buffer.len();

if !self.done && len > buffer_len {
if len > buffer_len {
let delta = len - buffer_len;

self.buffer.extend(self.it.by_ref().take(delta));
self.done = self.buffer.len() < len;
}
}
}
Expand Down
102 changes: 41 additions & 61 deletions src/permutations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::fmt;
use std::iter::once;

use super::lazy_buffer::LazyBuffer;
use super::size_hint::{self, SizeHint};

/// An iterator adaptor that iterates through all the `k`-permutations of the
/// elements from an iterator.
Expand Down Expand Up @@ -47,11 +48,6 @@ enum CompleteState {
}
}

enum CompleteStateRemaining {
Known(usize),
Overflow,
}

impl<I> fmt::Debug for Permutations<I>
where I: Iterator + fmt::Debug,
I::Item: fmt::Debug,
Expand All @@ -72,14 +68,8 @@ pub fn permutations<I: Iterator>(iter: I, k: usize) -> Permutations<I> {
};
}

let mut enough_vals = true;

while vals.len() < k {
if !vals.get_next() {
enough_vals = false;
break;
}
}
vals.prefill(k);
let enough_vals = vals.len() == k;

let state = if enough_vals {
PermutationState::StartUnknownLen { k }
Expand Down Expand Up @@ -122,42 +112,42 @@ where
}

fn count(self) -> usize {
fn from_complete(complete_state: CompleteState) -> usize {
match complete_state.remaining() {
CompleteStateRemaining::Known(count) => count,
CompleteStateRemaining::Overflow => {
panic!("Iterator count greater than usize::MAX");
}
}
}

let Permutations { vals, state } = self;
match state {
PermutationState::StartUnknownLen { k } => {
let n = vals.len() + vals.it.count();
let complete_state = CompleteState::Start { n, k };

from_complete(complete_state)
CompleteState::Start { n, k }.count()
}
PermutationState::OngoingUnknownLen { k, min_n } => {
let prev_iteration_count = min_n - k + 1;
let n = vals.len() + vals.it.count();
let complete_state = CompleteState::Start { n, k };

from_complete(complete_state) - prev_iteration_count
CompleteState::Start { n, k }.count() - prev_iteration_count
},
PermutationState::Complete(state) => from_complete(state),
PermutationState::Complete(state) => state.count(),
PermutationState::Empty => 0
}
}

fn size_hint(&self) -> (usize, Option<usize>) {
fn size_hint(&self) -> SizeHint {
match self.state {
PermutationState::StartUnknownLen { .. } |
PermutationState::OngoingUnknownLen { .. } => (0, None), // TODO can we improve this lower bound?
// Note: the product for `CompleteState::Start` in `remaining` increases with `n`.
PermutationState::StartUnknownLen { k } => {
size_hint::try_map(
self.vals.size_hint(),
|n| CompleteState::Start { n, k }.remaining(),
)
}
PermutationState::OngoingUnknownLen { k, min_n } => {
let prev_iteration_count = min_n - k + 1;
size_hint::try_map(self.vals.size_hint(), |n| {
CompleteState::Start { n, k }
.remaining()
.and_then(|count| count.checked_sub(prev_iteration_count))
})
}
PermutationState::Complete(ref state) => match state.remaining() {
CompleteStateRemaining::Known(count) => (count, Some(count)),
CompleteStateRemaining::Overflow => (::std::usize::MAX, None)
Some(count) => (count, Some(count)),
None => (::std::usize::MAX, None)
}
PermutationState::Empty => (0, Some(0))
}
Expand Down Expand Up @@ -185,7 +175,7 @@ where
let mut complete_state = CompleteState::Start { n, k };

// Advance the complete-state iterator to the correct point
for _ in 0..(prev_iteration_count + 1) {
for _ in 0..=prev_iteration_count {
complete_state.advance();
}

Expand Down Expand Up @@ -238,40 +228,30 @@ impl CompleteState {
}
}

fn remaining(&self) -> CompleteStateRemaining {
use self::CompleteStateRemaining::{Known, Overflow};

/// The remaining count of elements, if it does not overflow.
fn remaining(&self) -> Option<usize> {
match *self {
CompleteState::Start { n, k } => {
if n < k {
return Known(0);
return Some(0);
}

let count: Option<usize> = (n - k + 1..n + 1).fold(Some(1), |acc, i| {
(n - k + 1..n + 1).fold(Some(1), |acc, i| {
acc.and_then(|acc| acc.checked_mul(i))
});

match count {
Some(count) => Known(count),
None => Overflow
}
})
}
CompleteState::Ongoing { ref indices, ref cycles } => {
let mut count: usize = 0;

for (i, &c) in cycles.iter().enumerate() {
let radix = indices.len() - i;
let next_count = count.checked_mul(radix)
.and_then(|count| count.checked_add(c));

count = match next_count {
Some(count) => count,
None => { return Overflow; }
};
}

Known(count)
cycles.iter().enumerate().fold(Some(0), |acc, (i, c)| {
acc.and_then(|count| {
let radix = indices.len() - i;
count.checked_mul(radix)?.checked_add(*c)
})
})
}
}
}

/// Returns the value of `remaining()`.
///
/// # Panics
///
/// Panics when `remaining()` is `None`, i.e. the count overflowed a `usize`.
fn count(&self) -> usize {
    let remaining = self.remaining();
    remaining.expect("Iterator count greater than usize::MAX")
}
}
9 changes: 8 additions & 1 deletion src/powerset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::iter::FusedIterator;
use std::usize;
use alloc::vec::Vec;

use super::combinations::{Combinations, combinations};
use super::combinations::{Combinations, binomial, combinations};
use super::size_hint;

/// An iterator to iterate through the powerset of the elements from an iterator.
Expand Down Expand Up @@ -81,6 +81,13 @@ impl<I> Iterator for Powerset<I>
(0, self_total.1)
}
}

/// Consumes the iterator and returns how many subsets were left.
///
/// The result is the number of combinations remaining at the current size `k`, plus the
/// number of subsets of every larger size up to the pool length `n`.
///
/// # Panics
///
/// Panics if the count does not fit in a `usize`.
fn count(mut self) -> usize {
    const OVERFLOW: &str = "Iterator count greater than usize::MAX";
    let k = self.combs.k();
    let n = self.combs.real_n();
    // It could be `(1 << n) - self.pos` but `1 << n` might overflow.
    // Checked arithmetic throughout, so an overflow panics with the same message as the
    // other `count` implementations instead of wrapping.
    let larger_subsets = (k + 1..=n)
        .try_fold(0usize, |sum, size| sum.checked_add(binomial(n, size)?))
        .expect(OVERFLOW);
    self.combs
        .count()
        .checked_add(larger_subsets)
        .expect(OVERFLOW)
}
}

impl<I> FusedIterator for Powerset<I>
Expand Down
14 changes: 14 additions & 0 deletions src/size_hint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,17 @@ pub fn min(a: SizeHint, b: SizeHint) -> SizeHint {
};
(lower, upper)
}

/// Applies `f` to both bounds of a `SizeHint`; `f` returning `None` signals overflow.
///
/// An overflowing lower bound becomes `usize::MAX`, an overflowing upper bound becomes `None`.
///
/// For the resulting size hint to be correct, `f` must be increasing.
#[inline]
pub fn try_map<F>(sh: SizeHint, mut f: F) -> SizeHint
where
    F: FnMut(usize) -> Option<usize>,
{
    let (lower, upper) = sh;
    let mapped_lower = match f(lower) {
        Some(value) => value,
        None => usize::MAX,
    };
    let mapped_upper = match upper {
        Some(bound) => f(bound),
        None => None,
    };
    (mapped_lower, mapped_upper)
}