Adding _by, by_key, largest variants of k_smallest #654

Merged (1 commit, Feb 26, 2024)
src/k_smallest.rs (89 additions, 15 deletions)
```diff
@@ -1,22 +1,96 @@
-use alloc::collections::BinaryHeap;
-use core::cmp::Ord;
+use alloc::vec::Vec;
+use core::cmp::Ordering;
 
+/// Consumes a given iterator, returning the minimum elements in **ascending** order.
+pub(crate) fn k_smallest_general<I, F>(mut iter: I, k: usize, mut comparator: F) -> Vec<I::Item>
+where
+    I: Iterator,
+    F: FnMut(&I::Item, &I::Item) -> Ordering,
+{
+    /// Sift the element currently at `origin` away from the root until it is properly ordered.
+    ///
+    /// This will leave **larger** elements closer to the root of the heap.
+    fn sift_down<T, F>(heap: &mut [T], is_less_than: &mut F, mut origin: usize)
+    where
+        F: FnMut(&T, &T) -> bool,
+    {
+        #[inline]
+        fn children_of(n: usize) -> (usize, usize) {
+            (2 * n + 1, 2 * n + 2)
+        }
+
+        while origin < heap.len() {
+            let (left_idx, right_idx) = children_of(origin);
+            if left_idx >= heap.len() {
+                return;
+            }
+
+            let replacement_idx =
+                if right_idx < heap.len() && is_less_than(&heap[left_idx], &heap[right_idx]) {
+                    right_idx
+                } else {
+                    left_idx
+                };
+
+            if is_less_than(&heap[origin], &heap[replacement_idx]) {
+                heap.swap(origin, replacement_idx);
+                origin = replacement_idx;
+            } else {
+                return;
+            }
+        }
+    }
+
-pub(crate) fn k_smallest<T: Ord, I: Iterator<Item = T>>(mut iter: I, k: usize) -> BinaryHeap<T> {
     if k == 0 {
-        return BinaryHeap::new();
+        return Vec::new();
     }
+    let mut storage: Vec<I::Item> = iter.by_ref().take(k).collect();
 
-    let mut heap = iter.by_ref().take(k).collect::<BinaryHeap<_>>();
+    let mut is_less_than = move |a: &_, b: &_| comparator(a, b) == Ordering::Less;
 
-    iter.for_each(|i| {
-        debug_assert_eq!(heap.len(), k);
-        // Equivalent to heap.push(min(i, heap.pop())) but more efficient.
-        // This should be done with a single `.peek_mut().unwrap()` but
-        // `PeekMut` sifts-down unconditionally on Rust 1.46.0 and prior.
-        if *heap.peek().unwrap() > i {
-            *heap.peek_mut().unwrap() = i;
-        }
-    });
+    // Rearrange the storage into a valid heap by reordering from the second-bottom-most layer up to the root.
+    // Slightly faster than ordering on each insert, but only by a factor of lg(k).
+    // The resulting heap has the **largest** item on top.
+    for i in (0..=(storage.len() / 2)).rev() {
+        sift_down(&mut storage, &mut is_less_than, i);
+    }
+
+    if k == storage.len() {
+        // If we fill the storage, there may still be iterator elements left so feed them into the heap.
+        // Also avoids unexpected behaviour with restartable iterators.
+        iter.for_each(|val| {
+            if is_less_than(&val, &storage[0]) {
+                // Treating this as a push-and-pop saves having to write a sift-up implementation.
+                // https://en.wikipedia.org/wiki/Binary_heap#Insert_then_extract
+                storage[0] = val;
+                // We retain the smallest items we've seen so far, but ordered largest first so we can drop the largest efficiently.
+                sift_down(&mut storage, &mut is_less_than, 0);
+            }
+        });
+    }
 
-    heap
+    // Ultimately the items need to be in least-first, strict order, but the heap is currently largest-first.
+    // To achieve this, repeatedly,
+    // 1) "pop" the largest item off the heap into the tail slot of the underlying storage,
+    // 2) shrink the logical size of the heap by 1,
+    // 3) restore the heap property over the remaining items.
+    let mut heap = &mut storage[..];
+    while heap.len() > 1 {
+        let last_idx = heap.len() - 1;
+        heap.swap(0, last_idx);
+        // Sifting over a truncated slice means that the sifting will not disturb already popped elements.
+        heap = &mut heap[..last_idx];
+        sift_down(heap, &mut is_less_than, 0);
+    }
+
+    storage
 }
+
+#[inline]
+pub(crate) fn key_to_cmp<T, K, F>(key: F) -> impl Fn(&T, &T) -> Ordering
+where
+    F: Fn(&T) -> K,
+    K: Ord,
+{
+    move |a, b| key(a).cmp(&key(b))
+}
```
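The core idea in both the old and new code is a bounded max-heap of the k best candidates seen so far: the worst retained element sits at the root, so each remaining iterator item needs only one comparison against it. As a rough illustration (my own sketch, not code from this PR), the same technique looks like this with the standard library's `BinaryHeap`; the function name `k_smallest_sketch` and the `u32` element type are arbitrary choices for the example:

```rust
use std::collections::BinaryHeap;

fn k_smallest_sketch<I: Iterator<Item = u32>>(mut iter: I, k: usize) -> Vec<u32> {
    if k == 0 {
        return Vec::new();
    }
    // Fill a max-heap with the first k items; the largest of them sits on top.
    let mut heap: BinaryHeap<u32> = iter.by_ref().take(k).collect();
    // For every remaining item, replace the current maximum when the new item
    // is smaller ("insert then extract", fused into a single peek_mut step).
    for item in iter {
        if let Some(mut top) = heap.peek_mut() {
            if item < *top {
                *top = item; // PeekMut restores the heap property on drop.
            }
        }
    }
    // Drain in ascending order, matching k_smallest's output contract.
    heap.into_sorted_vec()
}

fn main() {
    let data = [6, 9, 1, 14, 0, 4, 8, 7, 11, 2];
    assert_eq!(k_smallest_sketch(data.into_iter(), 3), vec![0, 1, 2]);
}
```

The PR's `k_smallest_general` hand-rolls the heap instead so that a caller-supplied comparator can be threaded through `sift_down` without cloning it.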
src/lib.rs (98 additions, 4 deletions)
```diff
@@ -2950,14 +2950,108 @@ pub trait Itertools: Iterator {
     /// itertools::assert_equal(five_smallest, 0..5);
     /// ```
     #[cfg(feature = "use_alloc")]
-    fn k_smallest(self, k: usize) -> VecIntoIter<Self::Item>
+    fn k_smallest(mut self, k: usize) -> VecIntoIter<Self::Item>
     where
         Self: Sized,
         Self::Item: Ord,
     {
-        crate::k_smallest::k_smallest(self, k)
-            .into_sorted_vec()
-            .into_iter()
+        // The stdlib heap has optimised handling of "holes", which is not included in our heap implementation in k_smallest_general.
+        // While the difference is unlikely to have practical impact unless `Self::Item` is very large, this method uses the stdlib structure
+        // to maintain performance compared to previous versions of the crate.
+        use alloc::collections::BinaryHeap;
+
+        if k == 0 {
+            return Vec::new().into_iter();
+        }
+
+        let mut heap = self.by_ref().take(k).collect::<BinaryHeap<_>>();
+
+        self.for_each(|i| {
+            debug_assert_eq!(heap.len(), k);
+            // Equivalent to heap.push(min(i, heap.pop())) but more efficient.
+            // This should be done with a single `.peek_mut().unwrap()` but
+            // `PeekMut` sifts-down unconditionally on Rust 1.46.0 and prior.
+            if *heap.peek().unwrap() > i {
+                *heap.peek_mut().unwrap() = i;
+            }
+        });
+
+        heap.into_sorted_vec().into_iter()
     }
 
+    /// Sort the k smallest elements into a new iterator using the provided comparison.
+    ///
+    /// This corresponds to `self.sorted_by(cmp).take(k)` in the same way that
+    /// [Itertools::k_smallest] corresponds to `self.sorted().take(k)`, in both semantics and complexity.
+    /// In particular, a custom heap implementation ensures the comparison is not cloned.
+    #[cfg(feature = "use_alloc")]
+    fn k_smallest_by<F>(self, k: usize, cmp: F) -> VecIntoIter<Self::Item>
+    where
+        Self: Sized,
+        F: Fn(&Self::Item, &Self::Item) -> Ordering,
+    {
+        k_smallest::k_smallest_general(self, k, cmp).into_iter()
+    }
+
+    /// Return the elements producing the k smallest outputs of the provided function.
+    ///
+    /// This corresponds to `self.sorted_by_key(key).take(k)` in the same way that
+    /// [Itertools::k_smallest] corresponds to `self.sorted().take(k)`, in both semantics and time complexity.
+    #[cfg(feature = "use_alloc")]
+    fn k_smallest_by_key<F, K>(self, k: usize, key: F) -> VecIntoIter<Self::Item>
+    where
+        Self: Sized,
+        F: Fn(&Self::Item) -> K,
+        K: Ord,
+    {
+        self.k_smallest_by(k, k_smallest::key_to_cmp(key))
+    }
+
+    /// Sort the k largest elements into a new iterator, in descending order.
+    ///
+    /// Semantically equivalent to `k_smallest` with a reversed `Ord`.
+    /// However, this is implemented by way of a custom binary heap
+    /// which does not have the same performance characteristics for very large `Self::Item`.
+    /// ```
+    /// use itertools::Itertools;
+    ///
+    /// // A random permutation of 0..15
+    /// let numbers = vec![6, 9, 1, 14, 0, 4, 8, 7, 11, 2, 10, 3, 13, 12, 5];
+    ///
+    /// let five_largest = numbers
+    ///     .into_iter()
+    ///     .k_largest(5);
+    ///
+    /// itertools::assert_equal(five_largest, vec![14, 13, 12, 11, 10]);
+    /// ```
+    #[cfg(feature = "use_alloc")]
+    fn k_largest(self, k: usize) -> VecIntoIter<Self::Item>
+    where
+        Self: Sized,
+        Self::Item: Ord,
+    {
+        self.k_largest_by(k, Self::Item::cmp)
+    }
+
+    /// Sort the k largest elements into a new iterator using the provided comparison.
+    ///
+    /// Functionally equivalent to `k_smallest_by` with a reversed `Ord`.
+    #[cfg(feature = "use_alloc")]
+    fn k_largest_by<F>(self, k: usize, cmp: F) -> VecIntoIter<Self::Item>
+    where
+        Self: Sized,
+        F: Fn(&Self::Item, &Self::Item) -> Ordering,
+    {
+        self.k_smallest_by(k, move |a, b| cmp(b, a))
+    }
+
+    /// Return the elements producing the k largest outputs of the provided function.
+    #[cfg(feature = "use_alloc")]
+    fn k_largest_by_key<F, K>(self, k: usize, key: F) -> VecIntoIter<Self::Item>
+    where
+        Self: Sized,
+        F: Fn(&Self::Item) -> K,
+        K: Ord,
+    {
+        self.k_largest_by(k, k_smallest::key_to_cmp(key))
+    }
+
     /// Collect all iterator elements into one of two
```
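Assuming a build of itertools that includes this PR, the new adaptors compose as follows; this is an illustrative snippet of my own, not code from the PR, and the word list is made up:

```rust
use itertools::Itertools;

fn main() {
    let words = vec!["fig", "pear", "apple", "banana", "cherries"];

    // The two shortest words, shortest first, selected by key.
    let shortest: Vec<&str> = words
        .iter()
        .cloned()
        .k_smallest_by_key(2, |w| w.len())
        .collect();
    assert_eq!(shortest, vec!["fig", "pear"]);

    // The two longest words, longest first, via an explicit comparator.
    let longest: Vec<&str> = words
        .iter()
        .cloned()
        .k_largest_by(2, |a, b| a.len().cmp(&b.len()))
        .collect();
    assert_eq!(longest, vec!["cherries", "banana"]);
}
```

Note how `k_largest_by` is just `k_smallest_by` with the comparator's arguments flipped, which is exactly how the trait method is implemented.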
tests/test_std.rs (39 additions, 12 deletions)
```diff
@@ -492,23 +492,50 @@ fn sorted_by() {
 }
 
 qc::quickcheck! {
-    fn k_smallest_range(n: u64, m: u16, k: u16) -> () {
+    fn k_smallest_range(n: i64, m: u16, k: u16) -> () {
         // u16 is used to constrain k and m to 0..2¹⁶,
         // otherwise the test could use too much memory.
-        let (k, m) = (k as u64, m as u64);
-
-        // Generate a random permutation of n..n+m
-        let i = {
-            let mut v: Vec<u64> = (n..n.saturating_add(m)).collect();
-            v.shuffle(&mut thread_rng());
-            v.into_iter()
-        };
-
-        // Check that taking the k smallest elements yields n..n+min(k, m)
-        it::assert_equal(
-            i.k_smallest(k as usize),
-            n..n.saturating_add(min(k, m))
-        );
+        let (k, m) = (k as usize, m as u64);
+
+        let mut v: Vec<_> = (n..n.saturating_add(m as _)).collect();
+        v.shuffle(&mut thread_rng());
+
+        // Construct the right answers for the top and bottom elements
+        let mut sorted = v.clone();
+        sorted.sort();
+        // how many elements are we checking
+        let num_elements = min(k, m as _);
+
+        // Compute the top and bottom k in various combinations
+        let smallest = v.iter().cloned().k_smallest(k);
+        let smallest_by = v.iter().cloned().k_smallest_by(k, Ord::cmp);
+        let smallest_by_key = v.iter().cloned().k_smallest_by_key(k, |&x| x);
+
+        let largest = v.iter().cloned().k_largest(k);
+        let largest_by = v.iter().cloned().k_largest_by(k, Ord::cmp);
+        let largest_by_key = v.iter().cloned().k_largest_by_key(k, |&x| x);
+
+        // Check the variations produce the same answers and that they're right
+        for (a, b, c, d) in izip!(
+            sorted[..num_elements].iter().cloned(),
+            smallest,
+            smallest_by,
+            smallest_by_key
+        ) {
+            assert_eq!(a, b);
+            assert_eq!(a, c);
+            assert_eq!(a, d);
+        }
+
+        for (a, b, c, d) in izip!(
+            sorted[sorted.len() - num_elements..].iter().rev().cloned(),
+            largest,
+            largest_by,
+            largest_by_key
+        ) {
+            assert_eq!(a, b);
+            assert_eq!(a, c);
+            assert_eq!(a, d);
+        }
     }
 }
```
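The quickcheck property above checks all six variants against a sort-then-slice oracle on a random permutation. For readers who prefer a deterministic spot-check of the same property, something like the following would do; this is a hypothetical test of my own, not part of the PR:

```rust
use itertools::Itertools;

#[test]
fn k_smallest_matches_sort_then_take() {
    let v = vec![6, 9, 1, 14, 0, 4, 8, 7, 11, 2];
    let k = 4;

    // The oracle: a fully sorted copy of the input.
    let mut sorted = v.clone();
    sorted.sort();

    // `k_smallest` must agree with the naive sort-then-take answer.
    itertools::assert_equal(v.iter().cloned().k_smallest(k), sorted[..k].iter().cloned());

    // `k_largest` yields the top k in descending order.
    itertools::assert_equal(
        v.iter().cloned().k_largest(k),
        sorted[sorted.len() - k..].iter().rev().cloned(),
    );
}
```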