Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Leak amplification for peek_mut() to ensure BinaryHeap's invariant is always met #105851

Merged
merged 3 commits into from
Jan 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
65 changes: 55 additions & 10 deletions library/alloc/src/collections/binary_heap/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@
use core::fmt;
use core::iter::{FromIterator, FusedIterator, InPlaceIterable, SourceIter, TrustedLen};
use core::mem::{self, swap, ManuallyDrop};
use core::num::NonZeroUsize;
use core::ops::{Deref, DerefMut};
use core::ptr;

Expand All @@ -165,12 +166,20 @@ mod tests;
/// It is a logic error for an item to be modified in such a way that the
/// item's ordering relative to any other item, as determined by the [`Ord`]
/// trait, changes while it is in the heap. This is normally only possible
/// through [`Cell`], [`RefCell`], global state, I/O, or unsafe code. The
/// through interior mutability, global state, I/O, or unsafe code. The
/// behavior resulting from such a logic error is not specified, but will
/// be encapsulated to the `BinaryHeap` that observed the logic error and not
/// result in undefined behavior. This could include panics, incorrect results,
/// aborts, memory leaks, and non-termination.
///
/// As long as no elements change their relative order while being in the heap
/// as described above, the API of `BinaryHeap` guarantees that the heap
/// invariant remains intact i.e. its methods all behave as documented. For
/// example if a method is documented as iterating in sorted order, that's
/// guaranteed to work as long as elements in the heap have not changed order,
/// even in the presence of closures getting unwinded out of, iterators getting
/// leaked, and similar foolishness.
///
/// # Examples
///
/// ```
Expand Down Expand Up @@ -279,7 +288,9 @@ pub struct BinaryHeap<T> {
#[stable(feature = "binary_heap_peek_mut", since = "1.12.0")]
pub struct PeekMut<'a, T: 'a + Ord> {
heap: &'a mut BinaryHeap<T>,
sift: bool,
// If a set_len + sift_down are required, this is Some. If a &mut T has not
// yet been exposed to peek_mut()'s caller, it's None.
original_len: Option<NonZeroUsize>,
}

#[stable(feature = "collection_debug", since = "1.17.0")]
Expand All @@ -292,7 +303,14 @@ impl<T: Ord + fmt::Debug> fmt::Debug for PeekMut<'_, T> {
#[stable(feature = "binary_heap_peek_mut", since = "1.12.0")]
impl<T: Ord> Drop for PeekMut<'_, T> {
fn drop(&mut self) {
if self.sift {
if let Some(original_len) = self.original_len {
// SAFETY: That's how many elements were in the Vec at the time of
// the PeekMut::deref_mut call, and therefore also at the time of
// the BinaryHeap::peek_mut call. Since the PeekMut did not end up
// getting leaked, we are now undoing the leak amplification that
// the DerefMut prepared for.
unsafe { self.heap.data.set_len(original_len.get()) };

// SAFETY: PeekMut is only instantiated for non-empty heaps.
unsafe { self.heap.sift_down(0) };
}
Expand All @@ -313,7 +331,26 @@ impl<T: Ord> Deref for PeekMut<'_, T> {
impl<T: Ord> DerefMut for PeekMut<'_, T> {
fn deref_mut(&mut self) -> &mut T {
debug_assert!(!self.heap.is_empty());
self.sift = true;

let len = self.heap.len();
if len > 1 {
// Here we preemptively leak all the rest of the underlying vector
// after the currently max element. If the caller mutates the &mut T
// we're about to give them, and then leaks the PeekMut, all these
// elements will remain leaked. If they don't leak the PeekMut, then
// either Drop or PeekMut::pop will un-leak the vector elements.
//
// This is technique is described throughout several other places in
// the standard library as "leak amplification".
unsafe {
// SAFETY: len > 1 so len != 0.
self.original_len = Some(NonZeroUsize::new_unchecked(len));
// SAFETY: len > 1 so all this does for now is leak elements,
// which is safe.
self.heap.data.set_len(1);
}
}

// SAFE: PeekMut is only instantiated for non-empty heaps
unsafe { self.heap.data.get_unchecked_mut(0) }
}
Expand All @@ -323,9 +360,16 @@ impl<'a, T: Ord> PeekMut<'a, T> {
/// Removes the peeked value from the heap and returns it.
#[stable(feature = "binary_heap_peek_mut_pop", since = "1.18.0")]
pub fn pop(mut this: PeekMut<'a, T>) -> T {
let value = this.heap.pop().unwrap();
this.sift = false;
value
if let Some(original_len) = this.original_len.take() {
// SAFETY: This is how many elements were in the Vec at the time of
// the BinaryHeap::peek_mut call.
unsafe { this.heap.data.set_len(original_len.get()) };

// Unlike in Drop, here we don't also need to do a sift_down even if
// the caller could've mutated the element. It is removed from the
// heap on the next line and pop() is not sensitive to its value.
}
this.heap.pop().unwrap()
}
}

Expand Down Expand Up @@ -398,8 +442,9 @@ impl<T: Ord> BinaryHeap<T> {
/// Returns a mutable reference to the greatest item in the binary heap, or
/// `None` if it is empty.
///
/// Note: If the `PeekMut` value is leaked, the heap may be in an
/// inconsistent state.
/// Note: If the `PeekMut` value is leaked, some heap elements might get
/// leaked along with it, but the remaining elements will remain a valid
/// heap.
///
/// # Examples
///
Expand All @@ -426,7 +471,7 @@ impl<T: Ord> BinaryHeap<T> {
/// otherwise it's *O*(1).
#[stable(feature = "binary_heap_peek_mut", since = "1.12.0")]
pub fn peek_mut(&mut self) -> Option<PeekMut<'_, T>> {
if self.is_empty() { None } else { Some(PeekMut { heap: self, sift: false }) }
if self.is_empty() { None } else { Some(PeekMut { heap: self, original_len: None }) }
}

/// Removes the greatest item from the binary heap and returns it, or `None` if it
Expand Down
19 changes: 19 additions & 0 deletions library/alloc/src/collections/binary_heap/tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::*;
use crate::boxed::Box;
use crate::testing::crash_test::{CrashTestDummy, Panic};
use core::mem;
use std::iter::TrustedLen;
use std::panic::{catch_unwind, AssertUnwindSafe};

Expand Down Expand Up @@ -146,6 +147,24 @@ fn test_peek_mut() {
assert_eq!(heap.peek(), Some(&9));
}

#[test]
fn test_peek_mut_leek() {
let data = vec![4, 2, 7];
let mut heap = BinaryHeap::from(data);
let mut max = heap.peek_mut().unwrap();
*max = -1;

// The PeekMut object's Drop impl would have been responsible for moving the
// -1 out of the max position of the BinaryHeap, but we don't run it.
mem::forget(max);

// Absent some mitigation like leak amplification, the -1 would incorrectly
// end up in the last position of the returned Vec, with the rest of the
// heap's original contents in front of it in sorted order.
let sorted_vec = heap.into_sorted_vec();
assert!(sorted_vec.is_sorted(), "{:?}", sorted_vec);
}

#[test]
fn test_peek_mut_pop() {
let data = vec![2, 4, 6, 2, 1, 8, 10, 3, 5, 7, 0, 9, 1];
Expand Down
1 change: 1 addition & 0 deletions library/alloc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
#![feature(hasher_prefixfree_extras)]
#![feature(inline_const)]
#![feature(inplace_iteration)]
#![cfg_attr(test, feature(is_sorted))]
#![feature(iter_advance_by)]
#![feature(iter_next_chunk)]
#![feature(iter_repeat_n)]
Expand Down