From 2cd44fb540bc8c24a4d30513af7cd91dedcddb69 Mon Sep 17 00:00:00 2001 From: Ewan Mount Date: Mon, 10 Oct 2022 19:19:22 +0100 Subject: [PATCH] Adding _by, by_key, largest variants of k_smallest --- src/k_smallest.rs | 104 +++++++++++++++++++++++++++++++++++++++------- src/lib.rs | 102 +++++++++++++++++++++++++++++++++++++++++++-- tests/test_std.rs | 51 +++++++++++++++++------ 3 files changed, 226 insertions(+), 31 deletions(-) diff --git a/src/k_smallest.rs b/src/k_smallest.rs index 6af66cfaf..766021b65 100644 --- a/src/k_smallest.rs +++ b/src/k_smallest.rs @@ -1,22 +1,96 @@ -use alloc::collections::BinaryHeap; -use core::cmp::Ord; +use alloc::vec::Vec; +use core::cmp::Ordering; + +/// Consumes a given iterator, returning the minimum elements in **ascending** order. +pub(crate) fn k_smallest_general(mut iter: I, k: usize, mut comparator: F) -> Vec +where + I: Iterator, + F: FnMut(&I::Item, &I::Item) -> Ordering, +{ + /// Sift the element currently at `origin` away from the root until it is properly ordered. + /// + /// This will leave **larger** elements closer to the root of the heap. + fn sift_down(heap: &mut [T], is_less_than: &mut F, mut origin: usize) + where + F: FnMut(&T, &T) -> bool, + { + #[inline] + fn children_of(n: usize) -> (usize, usize) { + (2 * n + 1, 2 * n + 2) + } + + while origin < heap.len() { + let (left_idx, right_idx) = children_of(origin); + if left_idx >= heap.len() { + return; + } + + let replacement_idx = + if right_idx < heap.len() && is_less_than(&heap[left_idx], &heap[right_idx]) { + right_idx + } else { + left_idx + }; + + if is_less_than(&heap[origin], &heap[replacement_idx]) { + heap.swap(origin, replacement_idx); + origin = replacement_idx; + } else { + return; + } + } + } -pub(crate) fn k_smallest>(mut iter: I, k: usize) -> BinaryHeap { if k == 0 { - return BinaryHeap::new(); + return Vec::new(); } + let mut storage: Vec = iter.by_ref().take(k).collect(); - let mut heap = iter.by_ref().take(k).collect::>(); + let mut is_less_than = move |a: &_, b: &_| comparator(a, b) == Ordering::Less; - iter.for_each(|i| { - debug_assert_eq!(heap.len(), k); - // Equivalent to heap.push(min(i, heap.pop())) but more efficient. - // This should be done with a single `.peek_mut().unwrap()` but - // `PeekMut` sifts-down unconditionally on Rust 1.46.0 and prior. - if *heap.peek().unwrap() > i { - *heap.peek_mut().unwrap() = i; - } - }); + // Rearrange the storage into a valid heap by reordering from the second-bottom-most layer up to the root. + // Slightly faster than ordering on each insert, but only by a factor of lg(k). + // The resulting heap has the **largest** item on top. + for i in (0..=(storage.len() / 2)).rev() { + sift_down(&mut storage, &mut is_less_than, i); + } + + if k == storage.len() { + // If we fill the storage, there may still be iterator elements left so feed them into the heap. + // Also avoids unexpected behaviour with restartable iterators. + iter.for_each(|val| { + if is_less_than(&val, &storage[0]) { + // Treating this as an push-and-pop saves having to write a sift-up implementation. + // https://en.wikipedia.org/wiki/Binary_heap#Insert_then_extract + storage[0] = val; + // We retain the smallest items we've seen so far, but ordered largest first so we can drop the largest efficiently. + sift_down(&mut storage, &mut is_less_than, 0); + } + }); + } + + // Ultimately the items need to be in least-first, strict order, but the heap is currently largest-first. + // To achieve this, repeatedly, + // 1) "pop" the largest item off the heap into the tail slot of the underlying storage, + // 2) shrink the logical size of the heap by 1, + // 3) restore the heap property over the remaining items. + let mut heap = &mut storage[..]; + while heap.len() > 1 { + let last_idx = heap.len() - 1; + heap.swap(0, last_idx); + // Sifting over a truncated slice means that the sifting will not disturb already popped elements. + heap = &mut heap[..last_idx]; + sift_down(heap, &mut is_less_than, 0); + } + + storage +} - heap +#[inline] +pub(crate) fn key_to_cmp(key: F) -> impl Fn(&T, &T) -> Ordering +where + F: Fn(&T) -> K, + K: Ord, +{ + move |a, b| key(a).cmp(&key(b)) } diff --git a/src/lib.rs b/src/lib.rs index d3dd39cc0..71c8234f5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2950,14 +2950,108 @@ pub trait Itertools: Iterator { /// itertools::assert_equal(five_smallest, 0..5); /// ``` #[cfg(feature = "use_alloc")] - fn k_smallest(self, k: usize) -> VecIntoIter + fn k_smallest(mut self, k: usize) -> VecIntoIter where Self: Sized, Self::Item: Ord, { - crate::k_smallest::k_smallest(self, k) - .into_sorted_vec() - .into_iter() + // The stdlib heap has optimised handling of "holes", which is not included in our heap implementation in k_smallest_general. + // While the difference is unlikely to have practical impact unless `Self::Item` is very large, this method uses the stdlib structure + // to maintain performance compared to previous versions of the crate. + use alloc::collections::BinaryHeap; + + if k == 0 { + return Vec::new().into_iter(); + } + + let mut heap = self.by_ref().take(k).collect::>(); + + self.for_each(|i| { + debug_assert_eq!(heap.len(), k); + // Equivalent to heap.push(min(i, heap.pop())) but more efficient. + // This should be done with a single `.peek_mut().unwrap()` but + // `PeekMut` sifts-down unconditionally on Rust 1.46.0 and prior. + if *heap.peek().unwrap() > i { + *heap.peek_mut().unwrap() = i; + } + }); + + heap.into_sorted_vec().into_iter() + } + + /// Sort the k smallest elements into a new iterator using the provided comparison. + /// + /// This corresponds to `self.sorted_by(cmp).take(k)` in the same way that + /// [Itertools::k_smallest] corresponds to `self.sorted().take(k)`, in both semantics and complexity. + /// Particularly, a custom heap implementation ensures the comparison is not cloned. + #[cfg(feature = "use_alloc")] + fn k_smallest_by(self, k: usize, cmp: F) -> VecIntoIter + where + Self: Sized, + F: Fn(&Self::Item, &Self::Item) -> Ordering, + { + k_smallest::k_smallest_general(self, k, cmp).into_iter() + } + + /// Return the elements producing the k smallest outputs of the provided function + /// + /// This corresponds to `self.sorted_by_key(cmp).take(k)` in the same way that + /// [Itertools::k_smallest] corresponds to `self.sorted().take(k)`, in both semantics and time complexity. + #[cfg(feature = "use_alloc")] + fn k_smallest_by_key(self, k: usize, key: F) -> VecIntoIter + where + Self: Sized, + F: Fn(&Self::Item) -> K, + K: Ord, + { + self.k_smallest_by(k, k_smallest::key_to_cmp(key)) + } + + /// Sort the k largest elements into a new iterator, in descending order. + /// Semantically equivalent to `k_smallest` with a reversed `Ord` + /// However, this is implemented by way of a custom binary heap + /// which does not have the same performance characteristics for very large `Self::Item` + /// ``` + /// use itertools::Itertools; + /// + /// // A random permutation of 0..15 + /// let numbers = vec![6, 9, 1, 14, 0, 4, 8, 7, 11, 2, 10, 3, 13, 12, 5]; + /// + /// let five_largest = numbers + /// .into_iter() + /// .k_largest(5); + /// + /// itertools::assert_equal(five_largest, vec![14,13,12,11,10]); + /// ``` + #[cfg(feature = "use_alloc")] + fn k_largest(self, k: usize) -> VecIntoIter + where + Self: Sized, + Self::Item: Ord, + { + self.k_largest_by(k, Self::Item::cmp) + } + + /// Sort the k largest elements into a new iterator using the provided comparison. + /// Functionally equivalent to `k_smallest_by` with a reversed `Ord` + #[cfg(feature = "use_alloc")] + fn k_largest_by(self, k: usize, cmp: F) -> VecIntoIter + where + Self: Sized, + F: Fn(&Self::Item, &Self::Item) -> Ordering, + { + self.k_smallest_by(k, move |a, b| cmp(b, a)) + } + + /// Return the elements producing the k largest outputs of the provided function + #[cfg(feature = "use_alloc")] + fn k_largest_by_key(self, k: usize, key: F) -> VecIntoIter + where + Self: Sized, + F: Fn(&Self::Item) -> K, + K: Ord, + { + self.k_largest_by(k, k_smallest::key_to_cmp(key)) } /// Collect all iterator elements into one of two diff --git a/tests/test_std.rs b/tests/test_std.rs index 793018f1c..412986dd0 100644 --- a/tests/test_std.rs +++ b/tests/test_std.rs @@ -492,23 +492,50 @@ fn sorted_by() { } qc::quickcheck! { - fn k_smallest_range(n: u64, m: u16, k: u16) -> () { + fn k_smallest_range(n: i64, m: u16, k: u16) -> () { // u16 is used to constrain k and m to 0..2¹⁶, // otherwise the test could use too much memory. - let (k, m) = (k as u64, m as u64); + let (k, m) = (k as usize, m as u64); + let mut v: Vec<_> = (n..n.saturating_add(m as _)).collect(); // Generate a random permutation of n..n+m - let i = { - let mut v: Vec = (n..n.saturating_add(m)).collect(); - v.shuffle(&mut thread_rng()); - v.into_iter() - }; + v.shuffle(&mut thread_rng()); + + // Construct the right answers for the top and bottom elements + let mut sorted = v.clone(); + sorted.sort(); + // how many elements are we checking + let num_elements = min(k, m as _); + + // Compute the top and bottom k in various combinations + let smallest = v.iter().cloned().k_smallest(k); + let smallest_by = v.iter().cloned().k_smallest_by(k, Ord::cmp); + let smallest_by_key = v.iter().cloned().k_smallest_by_key(k, |&x| x); + + let largest = v.iter().cloned().k_largest(k); + let largest_by = v.iter().cloned().k_largest_by(k, Ord::cmp); + let largest_by_key = v.iter().cloned().k_largest_by_key(k, |&x| x); + + // Check the variations produce the same answers and that they're right + for (a,b,c,d) in izip!( + sorted[..num_elements].iter().cloned(), + smallest, + smallest_by, + smallest_by_key) { + assert_eq!(a,b); + assert_eq!(a,c); + assert_eq!(a,d); + } - // Check that taking the k smallest elements yields n..n+min(k, m) - it::assert_equal( - i.k_smallest(k as usize), - n..n.saturating_add(min(k, m)) - ); + for (a,b,c,d) in izip!( + sorted[sorted.len()-num_elements..].iter().rev().cloned(), + largest, + largest_by, + largest_by_key) { + assert_eq!(a,b); + assert_eq!(a,c); + assert_eq!(a,d); + } } }