Skip to content

Commit

Permalink
fix: reproducible Expr.hash (#4033)
Browse files Browse the repository at this point in the history
  • Loading branch information
thatlittleboy committed Jul 17, 2022
1 parent 6e1a7b1 commit 6695cce
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 44 deletions.
9 changes: 5 additions & 4 deletions polars/polars-lazy/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub enum FunctionExpr {
NullCount,
Pow,
#[cfg(feature = "row_hash")]
Hash(usize),
Hash(u64, u64, u64, u64),
#[cfg(feature = "is_in")]
IsIn,
#[cfg(feature = "arg_where")]
Expand Down Expand Up @@ -86,7 +86,7 @@ impl FunctionExpr {
NullCount => with_dtype(IDX_DTYPE),
Pow => float_dtype(),
#[cfg(feature = "row_hash")]
Hash(_) => with_dtype(DataType::UInt64),
Hash(..) => with_dtype(DataType::UInt64),
#[cfg(feature = "is_in")]
IsIn => with_dtype(DataType::Boolean),
#[cfg(feature = "arg_where")]
Expand Down Expand Up @@ -146,10 +146,11 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
wrap!(pow::pow)
}
#[cfg(feature = "row_hash")]
Hash(seed) => {
Hash(k0, k1, k2, k3) => {
let f = move |s: &mut [Series]| {
let s = &s[0];
Ok(s.hash(ahash::RandomState::with_seed(seed)).into_series())
Ok(s.hash(ahash::RandomState::with_seeds(k0, k1, k2, k3))
.into_series())
};
wrap!(f)
}
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-lazy/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2223,8 +2223,8 @@ impl Expr {

#[cfg(feature = "row_hash")]
/// Compute the hash of every element
pub fn hash(self, seed: usize) -> Expr {
self.map_private(FunctionExpr::Hash(seed), "hash")
pub fn hash(self, k0: u64, k1: u64, k2: u64, k3: u64) -> Expr {
self.map_private(FunctionExpr::Hash(k0, k1, k2, k3), "hash")
}

#[cfg(feature = "strings")]
Expand Down
32 changes: 23 additions & 9 deletions py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3088,7 +3088,13 @@ def is_between(
"include_bounds should be a boolean or [boolean, boolean]."
)

def hash(self, seed: int = 0, **kwargs: Any) -> Expr:
def hash(
self,
seed: int = 0,
seed_1: int | None = None,
seed_2: int | None = None,
seed_3: int | None = None,
) -> Expr:
"""
Hash the elements in the selection.
Expand All @@ -3097,7 +3103,13 @@ def hash(self, seed: int = 0, **kwargs: Any) -> Expr:
Parameters
----------
seed
The random seed to set.
Random seed parameter. Defaults to 0.
seed_1
Random seed parameter. Defaults to `seed` if not set.
seed_2
Random seed parameter. Defaults to `seed` if not set.
seed_3
Random seed parameter. Defaults to `seed` if not set.
Examples
--------
Expand All @@ -3107,24 +3119,26 @@ def hash(self, seed: int = 0, **kwargs: Any) -> Expr:
... "b": ["x", None, "z"],
... }
... )
>>> df.with_column(pl.all().hash(0)) # doctest: +IGNORE_RESULT
>>> df.with_column(pl.all().hash(10, 20, 30, 40))
shape: (3, 2)
┌──────────────────────┬──────────────────────┐
│ a ┆ b │
│ --- ┆ --- │
│ u64 ┆ u64 │
╞══════════════════════╪══════════════════════╡
115679325277134729216009526192193213963
246171685579122400016174362112783765148
├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
5223293301138196316 ┆ 12128596533331663936
13569566217648818014 ┆ 11638928888656214026
├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
1212859653333166393697974125653717906
116389288886562140266351727772611549480
└──────────────────────┴──────────────────────┘
"""
# kwargs is for backward compatibility
# can be removed later
return wrap_expr(self._pyexpr.hash(seed))
k0 = seed
k1 = seed_1 if seed_1 is not None else seed
k2 = seed_2 if seed_2 is not None else seed
k3 = seed_3 if seed_3 is not None else seed
return wrap_expr(self._pyexpr.hash(k0, k1, k2, k3))

def reinterpret(self, signed: bool) -> Expr:
"""
Expand Down
37 changes: 23 additions & 14 deletions py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5738,8 +5738,13 @@ def take_every(self: DF, n: int) -> DF:
"""
return self.select(pli.col("*").take_every(n))

@deprecated_alias(k0="seed", k1="seed_1", k2="seed_2", k3="seed_3")
def hash_rows(
self, k0: int = 0, k1: int = 1, k2: int = 2, k3: int = 3
self,
seed: int = 0,
seed_1: int | None = None,
seed_2: int | None = None,
seed_3: int | None = None,
) -> pli.Series:
"""
Hash and combine the rows in this DataFrame.
Expand All @@ -5748,14 +5753,14 @@ def hash_rows(
Parameters
----------
k0
Seed parameter.
k1
Seed parameter.
k2
Seed parameter.
k3
Seed parameter.
seed
Random seed parameter. Defaults to 0.
seed_1
Random seed parameter. Defaults to `seed` if not set.
seed_2
Random seed parameter. Defaults to `seed` if not set.
seed_3
Random seed parameter. Defaults to `seed` if not set.
Examples
--------
Expand All @@ -5765,17 +5770,21 @@ def hash_rows(
... "ham": ["a", "b", None, "d"],
... }
... )
>>> df.hash_rows(k0=42)
>>> df.hash_rows(seed=42)
shape: (4,)
Series: '' [u64]
[
13491910696687648691
5223969663565791681
4754614259239603444
162820313037838626
1381515935931787907
14326417405130769253
12561864296213327929
11391467306893437193
]
"""
k0 = seed
k1 = seed_1 if seed_1 is not None else seed
k2 = seed_2 if seed_2 is not None else seed
k3 = seed_3 if seed_3 is not None else seed
return pli.wrap_s(self._df.hash_rows(k0, k1, k2, k3))

def interpolate(self: DF) -> DF:
Expand Down
38 changes: 25 additions & 13 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
_datetime_to_pl_timestamp,
_ptr_to_numpy,
_to_python_datetime,
deprecated_alias,
range_to_slice,
)

Expand Down Expand Up @@ -3516,36 +3517,47 @@ def shrink_to_fit(self, in_place: bool = False) -> Series | None:
series._s.shrink_to_fit()
return series

def hash(self, k0: int = 0, k1: int = 1, k2: int = 2, k3: int = 3) -> pli.Series:
@deprecated_alias(k0="seed", k1="seed_1", k2="seed_2", k3="seed_3")
def hash(
self,
seed: int = 0,
seed_1: int | None = None,
seed_2: int | None = None,
seed_3: int | None = None,
) -> pli.Series:
"""
Hash the Series.
The hash value is of type `UInt64`.
Parameters
----------
k0
Seed parameter.
k1
Seed parameter.
k2
Seed parameter.
k3
Seed parameter.
seed
Random seed parameter. Defaults to 0.
seed_1
Random seed parameter. Defaults to `seed` if not set.
seed_2
Random seed parameter. Defaults to `seed` if not set.
seed_3
Random seed parameter. Defaults to `seed` if not set.
Examples
--------
>>> s = pl.Series("a", [1, 2, 3])
>>> s.hash(k0=42)
>>> s.hash(seed=42)
shape: (3,)
Series: 'a' [u64]
[
16679613936015749658
17801292685721255234
516997424509290289
89438004737668041
14107061265552512458
15437026767517145468
]
"""
k0 = seed
k1 = seed_1 if seed_1 is not None else seed
k2 = seed_2 if seed_2 is not None else seed
k3 = seed_3 if seed_3 is not None else seed
return wrap_s(self._s.hash(k0, k1, k2, k3))

def reinterpret(self, signed: bool = True) -> Series:
Expand Down
4 changes: 2 additions & 2 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1511,8 +1511,8 @@ impl PyExpr {
pub fn entropy(&self, base: f64, normalize: bool) -> Self {
self.inner.clone().entropy(base, normalize).into()
}
pub fn hash(&self, seed: usize) -> Self {
self.inner.clone().hash(seed).into()
pub fn hash(&self, seed: u64, seed_1: u64, seed_2: u64, seed_3: u64) -> Self {
self.inner.clone().hash(seed, seed_1, seed_2, seed_3).into()
}
}

Expand Down
27 changes: 27 additions & 0 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -1475,6 +1475,33 @@ def test_hash_rows() -> None:
assert df.select([pl.col("a").hash().alias("foo")])["foo"].dtype == pl.UInt64


def test_reproducible_hash_with_seeds() -> None:
"""Tests the reproducibility of DataFrame.hash_rows, Series.hash, and Expr.hash.
cf. issue #3966, hashes must always be reproducible across sessions when using
the same seeds.
"""
df = pl.DataFrame({"s": [1234, None, 5678]})
seeds = (11, 22, 33, 44)
expected = pl.Series(
"s",
[
15801072432137883943,
988796329533502010,
9604537446374444741,
],
dtype=pl.UInt64,
)

result = df.hash_rows(*seeds)
assert_series_equal(expected, result, check_names=False, check_exact=True)
result = df["s"].hash(*seeds)
assert_series_equal(expected, result, check_names=False, check_exact=True)
result = df.select([pl.col("s").hash(*seeds)])["s"]
assert_series_equal(expected, result, check_names=False, check_exact=True)


def test_create_df_from_object() -> None:
class Foo:
def __init__(self, value: int) -> None:
Expand Down

0 comments on commit 6695cce

Please sign in to comment.