Skip to content

Commit

Permalink
Fix HashSet::get_or_insert_with
Browse files Browse the repository at this point in the history
  • Loading branch information
JustForFun88 committed Feb 16, 2023
1 parent 5e4a982 commit 836fbed
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 7 deletions.
5 changes: 5 additions & 0 deletions src/map.rs
Expand Up @@ -4132,6 +4132,11 @@ impl<'a, K, V, S, A: Allocator + Clone> RawVacantEntryMut<'a, K, V, S, A> {
hash_builder: self.hash_builder,
}
}

#[inline]
pub(crate) fn hasher(&self) -> &S {
self.hash_builder
}
}

impl<K, V, S, A: Allocator + Clone> Debug for RawEntryBuilderMut<'_, K, V, S, A> {
Expand Down
127 changes: 120 additions & 7 deletions src/set.rs
Expand Up @@ -8,7 +8,10 @@ use core::iter::{Chain, FromIterator, FusedIterator};
use core::mem;
use core::ops::{BitAnd, BitOr, BitXor, Sub};

use super::map::{self, ConsumeAllOnDrop, DefaultHashBuilder, DrainFilterInner, HashMap, Keys};
use super::map::{
self, make_hash, make_insert_hash, ConsumeAllOnDrop, DefaultHashBuilder, DrainFilterInner,
HashMap, Keys, RawEntryMut,
};
use crate::raw::{Allocator, Global};

// Future Optimization (FIXME!)
Expand Down Expand Up @@ -953,6 +956,12 @@ where
/// Inserts a value computed from `f` into the set if the given `value` is
/// not present, then returns a reference to the value in the set.
///
/// # Panics
///
/// Panics if the value from the function and the provided lookup value
/// are not equivalent or have different hashes. See [`Equivalent`]
/// and [`Hash`] for more information.
///
/// # Examples
///
/// ```
Expand All @@ -967,20 +976,40 @@ where
/// assert_eq!(value, pet);
/// }
/// assert_eq!(set.len(), 4); // a new "fish" was inserted
/// assert!(set.contains("fish"));
/// ```
#[cfg_attr(feature = "inline-more", inline)]
pub fn get_or_insert_with<Q: ?Sized, F>(&mut self, value: &Q, f: F) -> &T
where
Q: Hash + Equivalent<T>,
F: FnOnce(&Q) -> T,
{
#[cold]
#[inline(never)]
fn assert_failed() {
panic!(
"the value from the function and the lookup value \
must be equivalent and have the same hash"
);
}

// Although the raw entry gives us `&mut T`, we only return `&T` to be consistent with
// `get`. Key mutation is "raw" because you're not supposed to affect `Eq` or `Hash`.
self.map
.raw_entry_mut()
.from_key(value)
.or_insert_with(|| (f(value), ()))
.0
let hash = make_hash::<Q, S>(&self.map.hash_builder, value);
let raw_entry_builder = self.map.raw_entry_mut();
match raw_entry_builder.from_key_hashed_nocheck(hash, value) {
RawEntryMut::Occupied(entry) => entry.into_key(),
RawEntryMut::Vacant(entry) => {
let insert_value = f(value);
let insert_value_hash = make_insert_hash::<T, S>(entry.hasher(), &insert_value);
if !(hash == insert_value_hash && value.equivalent(&insert_value)) {
assert_failed();
}
entry
.insert_hashed_nocheck(insert_value_hash, insert_value, ())
.0
}
}
}

/// Gets the given value's corresponding entry in the set for in-place manipulation.
Expand Down Expand Up @@ -2429,7 +2458,7 @@ fn assert_covariance() {
#[cfg(test)]
mod test_set {
use super::super::map::DefaultHashBuilder;
use super::HashSet;
use super::{make_hash, Equivalent, HashSet};
use std::vec::Vec;

#[test]
Expand Down Expand Up @@ -2886,4 +2915,88 @@ mod test_set {
set.insert(i);
}
}

#[test]
fn duplicate_insert() {
let mut set = HashSet::new();
set.insert(1);
set.get_or_insert_with(&1, |_| 1);
set.get_or_insert_with(&1, |_| 1);
assert!([1].iter().eq(set.iter()));
}

#[test]
#[allow(clippy::derived_hash_with_manual_eq)]
#[should_panic]
fn some_invalid_hash() {
use core::hash::{Hash, Hasher};
#[derive(Eq, PartialEq)]
struct Invalid {
count: u32,
}

struct InvalidRef {
count: u32,
}
impl Equivalent<Invalid> for InvalidRef {
fn equivalent(&self, key: &Invalid) -> bool {
self.count == key.count
}
}
impl Hash for Invalid {
fn hash<H: Hasher>(&self, state: &mut H) {
self.count.hash(state);
}
}
impl Hash for InvalidRef {
fn hash<H: Hasher>(&self, state: &mut H) {
let double = self.count * 2;
double.hash(state);
}
}
let mut set: HashSet<Invalid> = HashSet::new();
let key = InvalidRef { count: 1 };
let value = Invalid { count: 1 };
if key.equivalent(&value) {
set.get_or_insert_with(&key, |_| value);
}
}

#[test]
#[allow(clippy::derived_hash_with_manual_eq)]
#[should_panic]
fn some_invalid_equivalent() {
use core::hash::{Hash, Hasher};
#[derive(Eq, PartialEq)]
struct Invalid {
count: u32,
other: u32,
}

struct InvalidRef {
count: u32,
other: u32,
}
impl Equivalent<Invalid> for InvalidRef {
fn equivalent(&self, key: &Invalid) -> bool {
self.count == key.count && self.other == key.other
}
}
impl Hash for Invalid {
fn hash<H: Hasher>(&self, state: &mut H) {
self.count.hash(state);
}
}
impl Hash for InvalidRef {
fn hash<H: Hasher>(&self, state: &mut H) {
self.count.hash(state);
}
}
let mut set: HashSet<Invalid> = HashSet::new();
let key = InvalidRef { count: 1, other: 1 };
let value = Invalid { count: 1, other: 2 };
if make_hash(set.hasher(), &key) == make_hash(set.hasher(), &value) {
set.get_or_insert_with(&key, |_| value);
}
}
}

0 comments on commit 836fbed

Please sign in to comment.