Skip to content

Commit

Permalink
BTreeMap: split off most code of append, slightly improve interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
ssomers committed Nov 8, 2020
1 parent b1277d0 commit 685fd53
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 114 deletions.
124 changes: 124 additions & 0 deletions library/alloc/src/collections/btree/append.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
use super::map::MIN_LEN;
use super::merge_iter::MergeIterInner;
use super::node::{self, ForceResult::*, Root};
use core::iter::FusedIterator;

impl<K, V> Root<K, V> {
/// Appends all key-value pairs from the union of two ascending iterators,
/// incrementing a `length` variable along the way. The latter makes it
/// easier for the caller to avoid a leak when a drop handler panicks.
///
/// If both iterators produce the same key, this method drops the pair from
/// the left iterator and appends the pair from the right iterator.
///
/// If you want the tree to end up in a strictly ascending order, like for
/// a `BTreeMap`, both iterators should produce keys in strictly ascending
/// order, each greater than all keys in the tree, including any keys
/// already in the tree upon entry.
pub fn append_from_sorted_iters<I>(&mut self, left: I, right: I, length: &mut usize)
where
K: Ord,
I: Iterator<Item = (K, V)> + FusedIterator,
{
// We prepare to merge `left` and `right` into a sorted sequence in linear time.
let iter = MergeIter(MergeIterInner::new(left, right));

// Meanwhile, we build a tree from the sorted sequence in linear time.
self.bulk_push(iter, length)
}

/// Pushes all key-value pairs to the end of the tree, incrementing a
/// `length` variable along the way. The latter makes it easier for the
/// caller to avoid a leak when the iterator panicks.
fn bulk_push<I>(&mut self, iter: I, length: &mut usize)
where
I: Iterator<Item = (K, V)>,
{
let mut cur_node = self.node_as_mut().last_leaf_edge().into_node();
// Iterate through all key-value pairs, pushing them into nodes at the right level.
for (key, value) in iter {
// Try to push key-value pair into the current leaf node.
if cur_node.len() < node::CAPACITY {
cur_node.push(key, value);
} else {
// No space left, go up and push there.
let mut open_node;
let mut test_node = cur_node.forget_type();
loop {
match test_node.ascend() {
Ok(parent) => {
let parent = parent.into_node();
if parent.len() < node::CAPACITY {
// Found a node with space left, push here.
open_node = parent;
break;
} else {
// Go up again.
test_node = parent.forget_type();
}
}
Err(_) => {
// We are at the top, create a new root node and push there.
open_node = self.push_internal_level();
break;
}
}
}

// Push key-value pair and new right subtree.
let tree_height = open_node.height() - 1;
let mut right_tree = Root::new_leaf();
for _ in 0..tree_height {
right_tree.push_internal_level();
}
open_node.push(key, value, right_tree);

// Go down to the right-most leaf again.
cur_node = open_node.forget_type().last_leaf_edge().into_node();
}

// Increment length every iteration, to make sure the map drops
// the appended elements even if advancing the iterator panicks.
*length += 1;
}
self.fix_right_edge();
}

fn fix_right_edge(&mut self) {
// Handle underfull nodes, start from the top.
let mut cur_node = self.node_as_mut();
while let Internal(internal) = cur_node.force() {
// Check if right-most child is underfull.
let mut last_edge = internal.last_edge();
let right_child_len = last_edge.reborrow().descend().len();
if right_child_len < MIN_LEN {
// We need to steal.
let mut last_kv = match last_edge.left_kv() {
Ok(left) => left,
Err(_) => unreachable!(),
};
last_kv.bulk_steal_left(MIN_LEN - right_child_len);
last_edge = last_kv.right_edge();
}

// Go further down.
cur_node = last_edge.descend();
}
}
}

// An iterator for merging two sorted sequences into one
struct MergeIter<K, V, I: Iterator<Item = (K, V)>>(MergeIterInner<I>);

impl<K: Ord, V, I> Iterator for MergeIter<K, V, I>
where
I: Iterator<Item = (K, V)> + FusedIterator,
{
type Item = (K, V);

/// If two keys are equal, returns the key-value pair from the right source.
fn next(&mut self) -> Option<(K, V)> {
let (a_next, b_next) = self.0.nexts(|a: &(K, V), b: &(K, V)| K::cmp(&a.0, &b.0));
b_next.or(a_next)
}
}
96 changes: 2 additions & 94 deletions library/alloc/src/collections/btree/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use core::ops::{Index, RangeBounds};
use core::ptr;

use super::borrow::DormantMutRef;
use super::merge_iter::MergeIterInner;
use super::node::{self, marker, ForceResult::*, Handle, NodeRef};
use super::search::{self, SearchResult::*};
use super::unwrap_unchecked;
Expand Down Expand Up @@ -458,9 +457,6 @@ impl<K: fmt::Debug, V: fmt::Debug> fmt::Debug for RangeMut<'_, K, V> {
}
}

// An iterator for merging two sorted sequences into one
struct MergeIter<K, V, I: Iterator<Item = (K, V)>>(MergeIterInner<I>);

impl<K: Ord, V> BTreeMap<K, V> {
/// Makes a new empty BTreeMap.
///
Expand Down Expand Up @@ -908,13 +904,10 @@ impl<K: Ord, V> BTreeMap<K, V> {
return;
}

// First, we merge `self` and `other` into a sorted sequence in linear time.
let self_iter = mem::take(self).into_iter();
let other_iter = mem::take(other).into_iter();
let iter = MergeIter(MergeIterInner::new(self_iter, other_iter));

// Second, we build a tree from the sorted sequence in linear time.
self.from_sorted_iter(iter);
let root = BTreeMap::ensure_is_owned(&mut self.root);
root.append_from_sorted_iters(self_iter, other_iter, &mut self.length)
}

/// Constructs a double-ended iterator over a sub-range of elements in the map.
Expand Down Expand Up @@ -1039,78 +1032,6 @@ impl<K: Ord, V> BTreeMap<K, V> {
}
}

fn from_sorted_iter<I: Iterator<Item = (K, V)>>(&mut self, iter: I) {
let root = Self::ensure_is_owned(&mut self.root);
let mut cur_node = root.node_as_mut().last_leaf_edge().into_node();
// Iterate through all key-value pairs, pushing them into nodes at the right level.
for (key, value) in iter {
// Try to push key-value pair into the current leaf node.
if cur_node.len() < node::CAPACITY {
cur_node.push(key, value);
} else {
// No space left, go up and push there.
let mut open_node;
let mut test_node = cur_node.forget_type();
loop {
match test_node.ascend() {
Ok(parent) => {
let parent = parent.into_node();
if parent.len() < node::CAPACITY {
// Found a node with space left, push here.
open_node = parent;
break;
} else {
// Go up again.
test_node = parent.forget_type();
}
}
Err(_) => {
// We are at the top, create a new root node and push there.
open_node = root.push_internal_level();
break;
}
}
}

// Push key-value pair and new right subtree.
let tree_height = open_node.height() - 1;
let mut right_tree = node::Root::new_leaf();
for _ in 0..tree_height {
right_tree.push_internal_level();
}
open_node.push(key, value, right_tree);

// Go down to the right-most leaf again.
cur_node = open_node.forget_type().last_leaf_edge().into_node();
}

self.length += 1;
}
Self::fix_right_edge(root)
}

fn fix_right_edge(root: &mut node::Root<K, V>) {
// Handle underfull nodes, start from the top.
let mut cur_node = root.node_as_mut();
while let Internal(internal) = cur_node.force() {
// Check if right-most child is underfull.
let mut last_edge = internal.last_edge();
let right_child_len = last_edge.reborrow().descend().len();
if right_child_len < MIN_LEN {
// We need to steal.
let mut last_kv = match last_edge.left_kv() {
Ok(left) => left,
Err(_) => unreachable!(),
};
last_kv.bulk_steal_left(MIN_LEN - right_child_len);
last_edge = last_kv.right_edge();
}

// Go further down.
cur_node = last_edge.descend();
}
}

/// Splits the collection into two at the given key. Returns everything after the given key,
/// including the key.
///
Expand Down Expand Up @@ -2220,18 +2141,5 @@ impl<K, V> BTreeMap<K, V> {
}
}

impl<K: Ord, V, I> Iterator for MergeIter<K, V, I>
where
I: Iterator<Item = (K, V)> + ExactSizeIterator + FusedIterator,
{
type Item = (K, V);

/// If two keys are equal, returns the key/value-pair from the right source.
fn next(&mut self) -> Option<(K, V)> {
let (a_next, b_next) = self.0.nexts(|a: &(K, V), b: &(K, V)| K::cmp(&a.0, &b.0));
b_next.or(a_next)
}
}

#[cfg(test)]
mod tests;
27 changes: 27 additions & 0 deletions library/alloc/src/collections/btree/map/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1667,6 +1667,33 @@ create_append_test!(test_append_239, 239);
#[cfg(not(miri))] // Miri is too slow
create_append_test!(test_append_1700, 1700);

#[test]
fn test_append_drop_leak() {
static DROPS: AtomicUsize = AtomicUsize::new(0);

struct D;

impl Drop for D {
fn drop(&mut self) {
if DROPS.fetch_add(1, Ordering::SeqCst) == 0 {
panic!("panic in `drop`");
}
}
}

let mut left = BTreeMap::new();
let mut right = BTreeMap::new();
left.insert(0, D);
left.insert(1, D); // first to be dropped during append
left.insert(2, D);
right.insert(1, D);
right.insert(2, D);

catch_unwind(move || left.append(&mut right)).unwrap_err();

assert_eq!(DROPS.load(Ordering::SeqCst), 4); // Rust issue #47949 ate one little piggy
}

fn rand_data(len: usize) -> Vec<(u32, u32)> {
assert!(len * 2 <= 70029); // from that point on numbers repeat
let mut rng = DeterministicRng::new();
Expand Down
42 changes: 22 additions & 20 deletions library/alloc/src/collections/btree/merge_iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,48 +2,43 @@ use core::cmp::Ordering;
use core::fmt::{self, Debug};
use core::iter::FusedIterator;

/// Core of an iterator that merges the output of two ascending iterators,
/// Core of an iterator that merges the output of two strictly ascending iterators,
/// for instance a union or a symmetric difference.
pub struct MergeIterInner<I>
where
I: Iterator,
{
pub struct MergeIterInner<I: Iterator> {
a: I,
b: I,
peeked: Option<Peeked<I>>,
}

/// Benchmarks faster than wrapping both iterators in a Peekable.
/// Benchmarks faster than wrapping both iterators in a Peekable,
/// probably because we can afford to impose a FusedIterator bound.
#[derive(Clone, Debug)]
enum Peeked<I: Iterator> {
A(I::Item),
B(I::Item),
}

impl<I> Clone for MergeIterInner<I>
impl<I: Iterator> Clone for MergeIterInner<I>
where
I: Clone + Iterator,
I: Clone,
I::Item: Clone,
{
fn clone(&self) -> Self {
Self { a: self.a.clone(), b: self.b.clone(), peeked: self.peeked.clone() }
}
}

impl<I> Debug for MergeIterInner<I>
impl<I: Iterator> Debug for MergeIterInner<I>
where
I: Iterator + Debug,
I: Debug,
I::Item: Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("MergeIterInner").field(&self.a).field(&self.b).finish()
f.debug_tuple("MergeIterInner").field(&self.a).field(&self.b).field(&self.peeked).finish()
}
}

impl<I> MergeIterInner<I>
where
I: ExactSizeIterator + FusedIterator,
{
impl<I: Iterator> MergeIterInner<I> {
/// Creates a new core for an iterator merging a pair of sources.
pub fn new(a: I, b: I) -> Self {
MergeIterInner { a, b, peeked: None }
Expand All @@ -52,13 +47,17 @@ where
/// Returns the next pair of items stemming from the pair of sources
/// being merged. If both returned options contain a value, that value
/// is equal and occurs in both sources. If one of the returned options
/// contains a value, that value doesn't occur in the other source.
/// If neither returned option contains a value, iteration has finished
/// and subsequent calls will return the same empty pair.
/// contains a value, that value doesn't occur in the other source (or
/// the sources are not strictly ascending). If neither returned option
/// contains a value, iteration has finished and subsequent calls will
/// return the same empty pair.
pub fn nexts<Cmp: Fn(&I::Item, &I::Item) -> Ordering>(
&mut self,
cmp: Cmp,
) -> (Option<I::Item>, Option<I::Item>) {
) -> (Option<I::Item>, Option<I::Item>)
where
I: FusedIterator,
{
let mut a_next;
let mut b_next;
match self.peeked.take() {
Expand Down Expand Up @@ -86,7 +85,10 @@ where
}

/// Returns a pair of upper bounds for the `size_hint` of the final iterator.
pub fn lens(&self) -> (usize, usize) {
pub fn lens(&self) -> (usize, usize)
where
I: ExactSizeIterator,
{
match self.peeked {
Some(Peeked::A(_)) => (1 + self.a.len(), self.b.len()),
Some(Peeked::B(_)) => (self.a.len(), 1 + self.b.len()),
Expand Down
1 change: 1 addition & 0 deletions library/alloc/src/collections/btree/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod append;
mod borrow;
pub mod map;
mod mem;
Expand Down

0 comments on commit 685fd53

Please sign in to comment.