Auto merge of #105851 - dtolnay:peekmutleak, r=Mark-Simulacrum
Leak amplification for peek_mut() to ensure BinaryHeap's invariant is always met

In the libs-api team's discussion around #104210, some of the team had hesitations around exposing malformed BinaryHeaps of an element type whose Ord and Drop impls are trusted, and which does not contain interior mutability.

For example in the context of this kind of code:

```rust
use std::collections::BinaryHeap;
use std::ops::Range;
use std::slice;

fn main() {
    let slice = &mut ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'];
    let cut_points = BinaryHeap::from(vec![4, 2, 7]);
    println!("{:?}", chop(slice, cut_points));
}

// This is a souped-up slice::split_at_mut that splits in arbitrarily many places.
//
// usize's Ord impl is trusted, so a single bounds check guarantees that all the
// output slices are non-overlapping and in-bounds.
fn chop<T>(slice: &mut [T], mut cut_points: BinaryHeap<usize>) -> Vec<&mut [T]> {
    let mut vec = Vec::with_capacity(cut_points.len() + 1);
    let max = match cut_points.pop() {
        Some(max) => max,
        None => {
            vec.push(slice);
            return vec;
        }
    };

    assert!(max <= slice.len());

    let len = slice.len();
    let ptr: *mut T = slice.as_mut_ptr();
    let get_unchecked_mut = unsafe {
        |range: Range<usize>| &mut *slice::from_raw_parts_mut(ptr.add(range.start), range.len())
    };

    vec.push(get_unchecked_mut(max..len));
    let mut end = max;
    while let Some(start) = cut_points.pop() {
        vec.push(get_unchecked_mut(start..end));
        end = start;
    }
    vec.push(get_unchecked_mut(0..end));
    vec
}
```

```console
[['7', '8', '9'], ['4', '5', '6'], ['2', '3'], ['0', '1']]
```

In the current BinaryHeap API, `peek_mut()` is the only thing that makes the above function unsound.

```rust
let slice = &mut ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'];
let mut cut_points = BinaryHeap::from(vec![4, 2, 7]);
{
    let mut max = cut_points.peek_mut().unwrap();
    *max = 0;
    std::mem::forget(max);
}
println!("{:?}", chop(slice, cut_points));
```

```console
[['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], [], ['2', '3'], ['0', '1']]
```

Or worse:

```rust
let slice = &mut ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'];
let mut cut_points = BinaryHeap::from(vec![100, 100]);
{
    let mut max = cut_points.peek_mut().unwrap();
    *max = 0;
    std::mem::forget(max);
}
println!("{:?}", chop(slice, cut_points));
```

```console
[['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], [], ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '\u{1}', '\0', '?', '翾', '?', '翾', '\0', '\0', '?', '翾', '?', '翾', '?', '啿', '?', '啿', '?', '啿', '?', '啿', '?', '啿', '?', '翾', '\0', '\0', '񤬐', '啿', '\u{5}', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\u{8}', '\0', '@', '\0', '\u{1}', '\0', '?', '翾', '?', '翾', '?', '翾', '
thread 'main' panicked at 'index out of bounds: the len is 33 but the index is 33', library/core/src/unicode/unicode_data.rs:319:9
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
```

---

This PR makes `peek_mut()` use leak amplification (https://doc.rust-lang.org/1.66.0/nomicon/leaking.html#drain) to preserve the heap's invariant even in the situation that `PeekMut` gets leaked.
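
As a rough illustration of the idea, here is a minimal safe-Rust sketch of leak amplification. It is not the code in this PR: the real change uses `Vec::set_len` on the heap's backing vector, whereas this toy moves the tail into the guard with `split_off`. The names `MaxFirst`, `Guard`, `first_mut` and `restore_max_first` are hypothetical and exist only for this example. The invariant being protected is "`data[0]` is the maximum".

```rust
use std::mem;
use std::ops::{Deref, DerefMut};

/// Hypothetical toy collection whose invariant is "data[0] is the maximum".
struct MaxFirst {
    data: Vec<i32>,
}

/// Hypothetical guard giving `&mut` access to the first element.
struct Guard<'a> {
    owner: &'a mut MaxFirst,
    // Everything after the first element lives here while the guard exists.
    tail: Vec<i32>,
}

/// Moves a maximum element to index 0 (no-op on an empty slice).
fn restore_max_first(data: &mut [i32]) {
    if let Some(i) = (0..data.len()).max_by_key(|&i| data[i]) {
        data.swap(0, i);
    }
}

impl MaxFirst {
    fn new(mut data: Vec<i32>) -> Self {
        restore_max_first(&mut data);
        MaxFirst { data }
    }

    fn first_mut(&mut self) -> Option<Guard<'_>> {
        if self.data.is_empty() {
            return None;
        }
        // Preemptively take the tail out of the collection. If the guard is
        // leaked, the tail is leaked with it, and the one-element collection
        // left behind satisfies the invariant trivially.
        let tail = self.data.split_off(1);
        Some(Guard { owner: self, tail })
    }
}

impl Deref for Guard<'_> {
    type Target = i32;
    fn deref(&self) -> &i32 {
        &self.owner.data[0]
    }
}

impl DerefMut for Guard<'_> {
    fn deref_mut(&mut self) -> &mut i32 {
        &mut self.owner.data[0]
    }
}

impl Drop for Guard<'_> {
    fn drop(&mut self) {
        // Normal path: put the tail back, then repair the invariant.
        self.owner.data.append(&mut self.tail);
        restore_max_first(&mut self.owner.data);
    }
}

fn main() {
    let mut c = MaxFirst::new(vec![4, 2, 7]);
    {
        let mut first = c.first_mut().unwrap();
        *first = 0;
    } // Guard dropped here: all three elements are back, max is in front.
    assert_eq!(c.data.len(), 3);
    assert_eq!(c.data[0], 4);

    let mut first = c.first_mut().unwrap();
    *first = -1;
    mem::forget(first); // Drop never runs, so the tail stays leaked.
    assert_eq!(c.data, vec![-1]);
}
```

The trade-off is the same one the PR makes: a leaked guard costs the caller the rest of the elements, but nobody can observe a collection whose invariant is broken.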

I'll also follow up in the tracking issue of unstable `drain_sorted()` (#59278) and `retain()` (#71503).
bors committed Jan 15, 2023
2 parents 754f6d4 + 2350170 commit bbb36fe
Showing 3 changed files with 75 additions and 10 deletions.
65 changes: 55 additions & 10 deletions library/alloc/src/collections/binary_heap/mod.rs
@@ -146,6 +146,7 @@
use core::fmt;
use core::iter::{FromIterator, FusedIterator, InPlaceIterable, SourceIter, TrustedLen};
use core::mem::{self, swap, ManuallyDrop};
+use core::num::NonZeroUsize;
use core::ops::{Deref, DerefMut};
use core::ptr;

@@ -165,12 +166,20 @@ mod tests;
/// It is a logic error for an item to be modified in such a way that the
/// item's ordering relative to any other item, as determined by the [`Ord`]
/// trait, changes while it is in the heap. This is normally only possible
-/// through [`Cell`], [`RefCell`], global state, I/O, or unsafe code. The
+/// through interior mutability, global state, I/O, or unsafe code. The
/// behavior resulting from such a logic error is not specified, but will
/// be encapsulated to the `BinaryHeap` that observed the logic error and not
/// result in undefined behavior. This could include panics, incorrect results,
/// aborts, memory leaks, and non-termination.
///
+/// As long as no elements change their relative order while being in the heap
+/// as described above, the API of `BinaryHeap` guarantees that the heap
+/// invariant remains intact i.e. its methods all behave as documented. For
+/// example if a method is documented as iterating in sorted order, that's
+/// guaranteed to work as long as elements in the heap have not changed order,
+/// even in the presence of closures getting unwinded out of, iterators getting
+/// leaked, and similar foolishness.
+///
/// # Examples
///
/// ```
@@ -279,7 +288,9 @@ pub struct BinaryHeap<T> {
#[stable(feature = "binary_heap_peek_mut", since = "1.12.0")]
pub struct PeekMut<'a, T: 'a + Ord> {
heap: &'a mut BinaryHeap<T>,
-sift: bool,
+// If a set_len + sift_down are required, this is Some. If a &mut T has not
+// yet been exposed to peek_mut()'s caller, it's None.
+original_len: Option<NonZeroUsize>,
}

#[stable(feature = "collection_debug", since = "1.17.0")]
@@ -292,7 +303,14 @@ impl<T: Ord + fmt::Debug> fmt::Debug for PeekMut<'_, T> {
#[stable(feature = "binary_heap_peek_mut", since = "1.12.0")]
impl<T: Ord> Drop for PeekMut<'_, T> {
fn drop(&mut self) {
-if self.sift {
+if let Some(original_len) = self.original_len {
+// SAFETY: That's how many elements were in the Vec at the time of
+// the PeekMut::deref_mut call, and therefore also at the time of
+// the BinaryHeap::peek_mut call. Since the PeekMut did not end up
+// getting leaked, we are now undoing the leak amplification that
+// the DerefMut prepared for.
+unsafe { self.heap.data.set_len(original_len.get()) };
+
// SAFETY: PeekMut is only instantiated for non-empty heaps.
unsafe { self.heap.sift_down(0) };
}
@@ -313,7 +331,26 @@ impl<T: Ord> Deref for PeekMut<'_, T> {
impl<T: Ord> DerefMut for PeekMut<'_, T> {
fn deref_mut(&mut self) -> &mut T {
debug_assert!(!self.heap.is_empty());
-self.sift = true;
+
+let len = self.heap.len();
+if len > 1 {
+// Here we preemptively leak all the rest of the underlying vector
+// after the currently max element. If the caller mutates the &mut T
+// we're about to give them, and then leaks the PeekMut, all these
+// elements will remain leaked. If they don't leak the PeekMut, then
+// either Drop or PeekMut::pop will un-leak the vector elements.
+//
+// This is technique is described throughout several other places in
+// the standard library as "leak amplification".
+unsafe {
+// SAFETY: len > 1 so len != 0.
+self.original_len = Some(NonZeroUsize::new_unchecked(len));
+// SAFETY: len > 1 so all this does for now is leak elements,
+// which is safe.
+self.heap.data.set_len(1);
+}
+}
+
// SAFE: PeekMut is only instantiated for non-empty heaps
unsafe { self.heap.data.get_unchecked_mut(0) }
}
@@ -323,9 +360,16 @@ impl<'a, T: Ord> PeekMut<'a, T> {
/// Removes the peeked value from the heap and returns it.
#[stable(feature = "binary_heap_peek_mut_pop", since = "1.18.0")]
pub fn pop(mut this: PeekMut<'a, T>) -> T {
-let value = this.heap.pop().unwrap();
-this.sift = false;
-value
+if let Some(original_len) = this.original_len.take() {
+// SAFETY: This is how many elements were in the Vec at the time of
+// the BinaryHeap::peek_mut call.
+unsafe { this.heap.data.set_len(original_len.get()) };
+
+// Unlike in Drop, here we don't also need to do a sift_down even if
+// the caller could've mutated the element. It is removed from the
+// heap on the next line and pop() is not sensitive to its value.
+}
+this.heap.pop().unwrap()
}
}

@@ -398,8 +442,9 @@ impl<T: Ord> BinaryHeap<T> {
/// Returns a mutable reference to the greatest item in the binary heap, or
/// `None` if it is empty.
///
-/// Note: If the `PeekMut` value is leaked, the heap may be in an
-/// inconsistent state.
+/// Note: If the `PeekMut` value is leaked, some heap elements might get
+/// leaked along with it, but the remaining elements will remain a valid
+/// heap.
///
/// # Examples
///
@@ -426,7 +471,7 @@ impl<T: Ord> BinaryHeap<T> {
/// otherwise it's *O*(1).
#[stable(feature = "binary_heap_peek_mut", since = "1.12.0")]
pub fn peek_mut(&mut self) -> Option<PeekMut<'_, T>> {
-if self.is_empty() { None } else { Some(PeekMut { heap: self, sift: false }) }
+if self.is_empty() { None } else { Some(PeekMut { heap: self, original_len: None }) }
}

/// Removes the greatest item from the binary heap and returns it, or `None` if it
19 changes: 19 additions & 0 deletions library/alloc/src/collections/binary_heap/tests.rs
@@ -1,6 +1,7 @@
use super::*;
use crate::boxed::Box;
use crate::testing::crash_test::{CrashTestDummy, Panic};
+use core::mem;
use std::iter::TrustedLen;
use std::panic::{catch_unwind, AssertUnwindSafe};

@@ -146,6 +147,24 @@ fn test_peek_mut() {
assert_eq!(heap.peek(), Some(&9));
}

+#[test]
+fn test_peek_mut_leek() {
+let data = vec![4, 2, 7];
+let mut heap = BinaryHeap::from(data);
+let mut max = heap.peek_mut().unwrap();
+*max = -1;
+
+// The PeekMut object's Drop impl would have been responsible for moving the
+// -1 out of the max position of the BinaryHeap, but we don't run it.
+mem::forget(max);
+
+// Absent some mitigation like leak amplification, the -1 would incorrectly
+// end up in the last position of the returned Vec, with the rest of the
+// heap's original contents in front of it in sorted order.
+let sorted_vec = heap.into_sorted_vec();
+assert!(sorted_vec.is_sorted(), "{:?}", sorted_vec);
+}
+
#[test]
fn test_peek_mut_pop() {
let data = vec![2, 4, 6, 2, 1, 8, 10, 3, 5, 7, 0, 9, 1];
1 change: 1 addition & 0 deletions library/alloc/src/lib.rs
@@ -125,6 +125,7 @@
#![feature(hasher_prefixfree_extras)]
#![feature(inline_const)]
#![feature(inplace_iteration)]
+#![cfg_attr(test, feature(is_sorted))]
#![feature(iter_advance_by)]
#![feature(iter_next_chunk)]
#![feature(iter_repeat_n)]
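
The new test depends on the `is_sorted` feature gate enabled for tests in `lib.rs`. For trying the same scenario outside the standard library's test suite, a standalone sketch might look like the following. The helpers `leak_then_sort` and `is_sorted` are hypothetical names introduced here, and the assertion is only expected to hold on a toolchain whose `BinaryHeap` includes this change.

```rust
use std::collections::BinaryHeap;
use std::mem;

/// Hypothetical helper: overwrites the max element through `peek_mut()`,
/// leaks the `PeekMut` guard, and returns what the heap still yields.
fn leak_then_sort(data: Vec<i32>) -> Vec<i32> {
    let mut heap = BinaryHeap::from(data);
    if let Some(mut max) = heap.peek_mut() {
        *max = -1;
        // Skip PeekMut's Drop impl, which would normally sift the -1 down.
        mem::forget(max);
    }
    heap.into_sorted_vec()
}

/// Hand-rolled sortedness check, so no feature gate is needed.
fn is_sorted(values: &[i32]) -> bool {
    values.windows(2).all(|pair| pair[0] <= pair[1])
}

fn main() {
    let sorted = leak_then_sort(vec![4, 2, 7]);
    // Some elements may have been leaked, but what remains must be ordered.
    assert!(is_sorted(&sorted), "{:?}", sorted);
    println!("{:?}", sorted);
}
```

The check deliberately asserts only sortedness, not a particular length, since how many elements survive the leak is an implementation detail rather than a documented guarantee.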