Skip to content

Commit

Permalink
Add optional seeding for sampling (#3080)
Browse files Browse the repository at this point in the history
  • Loading branch information
cnpryer committed Apr 8, 2022
1 parent 837a548 commit 90c2d85
Show file tree
Hide file tree
Showing 14 changed files with 95 additions and 46 deletions.
4 changes: 2 additions & 2 deletions nodejs-polars/src/dataframe/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ pub(crate) fn sample_n(cx: CallContext) -> JsResult<JsExternal> {
let df = params.get_external::<DataFrame>(&cx, "_df")?;
let n = params.get_as::<usize>("n")?;
let with_replacement = params.get_as::<bool>("withReplacement")?;
df.sample_n(n, with_replacement, 0)
df.sample_n(n, with_replacement, Some(0))
.map_err(JsPolarsEr::from)?
.try_into_js(&cx)
}
Expand All @@ -77,7 +77,7 @@ pub(crate) fn sample_frac(cx: CallContext) -> JsResult<JsExternal> {
let df = params.get_external::<DataFrame>(&cx, "_df")?;
let frac = params.get_as::<f64>("frac")?;
let with_replacement = params.get_as::<bool>("withReplacement")?;
df.sample_frac(frac, with_replacement, 0)
df.sample_frac(frac, with_replacement, Some(0))
.map_err(JsPolarsEr::from)?
.try_into_js(&cx)
}
Expand Down
4 changes: 2 additions & 2 deletions nodejs-polars/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1180,7 +1180,7 @@ pub(crate) fn sample_n(cx: CallContext) -> JsResult<JsExternal> {
let n = params.get_as::<usize>("n")?;
let with_replacement = params.get_as::<bool>("withReplacement")?;
series
.sample_n(n, with_replacement, 0)
.sample_n(n, with_replacement, Some(0))
.map_err(JsPolarsEr::from)?
.try_into_js(&cx)
}
Expand All @@ -1192,7 +1192,7 @@ pub(crate) fn sample_frac(cx: CallContext) -> JsResult<JsExternal> {
let frac = params.get_as::<f64>("frac")?;
let with_replacement = params.get_as::<bool>("withReplacement")?;
series
.sample_frac(frac, with_replacement, 0)
.sample_frac(frac, with_replacement, Some(0))
.map_err(JsPolarsEr::from)?
.try_into_js(&cx)
}
Expand Down
67 changes: 50 additions & 17 deletions polars/polars-core/src/chunked_array/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,22 @@ use rand::distributions::Bernoulli;
use rand::prelude::*;
use rand_distr::{Distribution, Normal, Standard, StandardNormal, Uniform};

fn create_rand_index_with_replacement(n: usize, len: usize, seed: u64) -> IdxCa {
let mut rng = SmallRng::seed_from_u64(seed);
/// Produces a fresh 64-bit seed by taking the first output of a `SmallRng`
/// that was itself initialised with `from_entropy`.
///
/// The samplers in this module call this when no explicit seed is supplied.
fn get_random_seed() -> u64 {
    SmallRng::from_entropy().next_u64()
}

/// Builds an index array holding `n` positions, each drawn independently from
/// a uniform distribution over `0..len`, so a position may occur repeatedly.
///
/// `Some(seed)` seeds the generator with that value; `None` falls back to a
/// seed obtained from `get_random_seed`.
fn create_rand_index_with_replacement(n: usize, len: usize, seed: Option<u64>) -> IdxCa {
    let seed = seed.unwrap_or_else(get_random_seed);
    let mut rng = SmallRng::seed_from_u64(seed);
    let count = n as IdxSize;
    let upper = len as IdxSize;
    // One draw per requested element; the distribution is rebuilt for each draw.
    let drawn = (0..count)
        .map(move |_| Uniform::new(0, upper).sample(&mut rng))
        .collect_trusted::<NoNull<IdxCa>>();
    drawn.into_inner()
}

fn create_rand_index_no_replacement(n: usize, len: usize, seed: u64) -> IdxCa {
let mut rng = SmallRng::seed_from_u64(seed);
fn create_rand_index_no_replacement(n: usize, len: usize, seed: Option<u64>) -> IdxCa {
let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_random_seed));
let mut idx = Vec::from_iter_trusted_length(0..len as IdxSize);
idx.shuffle(&mut rng);
idx.truncate(n);
Expand All @@ -41,7 +47,7 @@ where
}

impl Series {
pub fn sample_n(&self, n: usize, with_replacement: bool, seed: u64) -> Result<Self> {
pub fn sample_n(&self, n: usize, with_replacement: bool, seed: Option<u64>) -> Result<Self> {
if !with_replacement && n > self.len() {
return Err(PolarsError::ShapeMisMatch(
"n is larger than the number of elements in this array".into(),
Expand All @@ -66,13 +72,18 @@ impl Series {
}

/// Sample a fraction between 0.0-1.0 of this Series.
pub fn sample_frac(&self, frac: f64, with_replacement: bool, seed: u64) -> Result<Self> {
pub fn sample_frac(
&self,
frac: f64,
with_replacement: bool,
seed: Option<u64>,
) -> Result<Self> {
// Turn the fraction into an element count. The float-to-integer `as` cast
// truncates toward zero and saturates, so a negative or NaN `frac` yields 0.
let n = (self.len() as f64 * frac) as usize;
// The optional seed is forwarded unchanged; `sample_n` rejects an `n` larger
// than the Series length when sampling without replacement.
self.sample_n(n, with_replacement, seed)
}

/// Samples every element of this Series without replacement, using `seed`.
pub fn shuffle(&self, seed: u64) -> Self {
// `n` equals `self.len()`, so the "n is larger than the number of elements"
// check in `sample_n` cannot trigger here.
// NOTE(review): the `unwrap` assumes `sample_n` has no other failure path — confirm.
self.sample_n(self.len(), false, seed).unwrap()
self.sample_n(self.len(), false, Some(seed)).unwrap()
}
}

Expand All @@ -81,7 +92,7 @@ where
ChunkedArray<T>: ChunkTake,
{
/// Sample n datapoints from this ChunkedArray.
pub fn sample_n(&self, n: usize, with_replacement: bool, seed: u64) -> Result<Self> {
pub fn sample_n(&self, n: usize, with_replacement: bool, seed: Option<u64>) -> Result<Self> {
if !with_replacement && n > self.len() {
return Err(PolarsError::ShapeMisMatch(
"n is larger than the number of elements in this array".into(),
Expand All @@ -106,15 +117,20 @@ where
}

/// Sample a fraction between 0.0-1.0 of this ChunkedArray.
pub fn sample_frac(&self, frac: f64, with_replacement: bool, seed: u64) -> Result<Self> {
pub fn sample_frac(
&self,
frac: f64,
with_replacement: bool,
seed: Option<u64>,
) -> Result<Self> {
// Turn the fraction into an element count. The float-to-integer `as` cast
// truncates toward zero and saturates, so a negative or NaN `frac` yields 0.
let n = (self.len() as f64 * frac) as usize;
// The optional seed is forwarded unchanged; `sample_n` rejects an `n` larger
// than the array length when sampling without replacement.
self.sample_n(n, with_replacement, seed)
}
}

impl DataFrame {
/// Sample n datapoints from this DataFrame.
pub fn sample_n(&self, n: usize, with_replacement: bool, seed: u64) -> Result<Self> {
pub fn sample_n(&self, n: usize, with_replacement: bool, seed: Option<u64>) -> Result<Self> {
if !with_replacement && n > self.height() {
return Err(PolarsError::ShapeMisMatch(
"n is larger than the number of elements in this array".into(),
Expand All @@ -131,7 +147,12 @@ impl DataFrame {
}

/// Sample a fraction between 0.0-1.0 of this DataFrame.
pub fn sample_frac(&self, frac: f64, with_replacement: bool, seed: u64) -> Result<Self> {
pub fn sample_frac(
&self,
frac: f64,
with_replacement: bool,
seed: Option<u64>,
) -> Result<Self> {
// Turn the fraction into a row count based on the frame height. The
// float-to-integer `as` cast truncates toward zero and saturates, so a
// negative or NaN `frac` yields 0.
let n = (self.height() as f64 * frac) as usize;
// The optional seed is forwarded unchanged; `sample_n` rejects an `n` larger
// than the frame height when sampling without replacement.
self.sample_n(n, with_replacement, seed)
}
Expand Down Expand Up @@ -212,13 +233,25 @@ mod test {
]
.unwrap();

assert!(df.sample_n(3, false, 0).is_ok());
assert!(df.sample_frac(0.4, false, 0).is_ok());
// default samples are random and don't require seeds
assert!(df.sample_n(3, false, None).is_ok());
assert!(df.sample_frac(0.4, false, None).is_ok());
assert!(!df
.sample_n(3, false, None)
.unwrap()
.frame_equal(&df.sample_n(3, false, None).unwrap()));
assert!(!df
.sample_frac(0.4, false, None)
.unwrap()
.frame_equal(&df.sample_frac(0.4, false, None).unwrap()));
// with seeding
assert!(df.sample_n(3, false, Some(0)).is_ok());
assert!(df.sample_frac(0.4, false, Some(0)).is_ok());
// without replacement can not sample more than 100%
assert!(df.sample_frac(2.0, false, 0).is_err());
assert!(df.sample_n(3, true, 0).is_ok());
assert!(df.sample_frac(0.4, true, 0).is_ok());
assert!(df.sample_frac(2.0, false, Some(0)).is_err());
assert!(df.sample_n(3, true, Some(0)).is_ok());
assert!(df.sample_frac(0.4, true, Some(0)).is_ok());
// with replacement can sample more than 100%
assert!(df.sample_frac(2.0, true, 0).is_ok());
assert!(df.sample_frac(2.0, true, Some(0)).is_ok());
}
}
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1971,7 +1971,7 @@ impl Expr {
}

#[cfg(feature = "random")]
pub fn sample_frac(self, frac: f64, with_replacement: bool, seed: u64) -> Self {
pub fn sample_frac(self, frac: f64, with_replacement: bool, seed: Option<u64>) -> Self {
self.apply(
move |s| s.sample_frac(frac, with_replacement, seed),
GetOutput::same_type(),
Expand Down
6 changes: 2 additions & 4 deletions py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2634,7 +2634,7 @@ def sample(
self,
fraction: float = 1.0,
with_replacement: bool = True,
seed: Optional[int] = 0,
seed: Optional[int] = None,
) -> "Expr":
"""
Sample a fraction of the `Series`.
Expand All @@ -2646,10 +2646,8 @@ def sample(
with_replacement
Allow values to be sampled more than once.
seed
Seed initialization. If None given numpy is used.
Seed initialization. If None given a random seed is used.
"""
if seed is None:
seed = int(np.random.randint(0, 10000))
return wrap_expr(self._pyexpr.sample_frac(fraction, with_replacement, seed))

def ewm_mean(
Expand Down
6 changes: 3 additions & 3 deletions py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4728,7 +4728,7 @@ def sample(
n: Optional[int] = None,
frac: Optional[float] = None,
with_replacement: bool = False,
seed: int = 0,
seed: Optional[int] = None,
) -> DF:
"""
Sample from this DataFrame by setting either `n` or `frac`.
Expand All @@ -4742,7 +4742,7 @@ def sample(
with_replacement
Sample with replacement.
seed
Initialization seed
Initialization seed. If None is given a random seed is used.
Examples
--------
Expand All @@ -4753,7 +4753,7 @@ def sample(
... "ham": ["a", "b", "c"],
... }
... )
>>> df.sample(n=2) # doctest: +IGNORE_RESULT
>>> df.sample(n=2, seed=0) # doctest: +IGNORE_RESULT
shape: (2, 3)
┌─────┬─────┬─────┐
│ foo ┆ bar ┆ ham │
Expand Down
6 changes: 3 additions & 3 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3017,7 +3017,7 @@ def sample(
n: Optional[int] = None,
frac: Optional[float] = None,
with_replacement: bool = False,
seed: int = 0,
seed: Optional[int] = None,
) -> "Series":
"""
Sample from this Series by setting either `n` or `frac`.
Expand All @@ -3031,12 +3031,12 @@ def sample(
with_replacement
sample with replacement.
seed
Initialization seed
Initialization seed. If None is given a random seed is used.
Examples
--------
>>> s = pl.Series("a", [1, 2, 3, 4, 5])
>>> s.sample(2) # doctest: +IGNORE_RESULT
>>> s.sample(2, seed=0) # doctest: +IGNORE_RESULT
shape: (2,)
Series: 'a' [i64]
[
Expand Down
9 changes: 7 additions & 2 deletions py-polars/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -658,15 +658,20 @@ impl PyDataFrame {
Ok(df.into())
}

pub fn sample_n(&self, n: usize, with_replacement: bool, seed: u64) -> PyResult<Self> {
/// Samples `n` rows from the wrapped `df` and returns the result as a new wrapper.
///
/// `seed` is passed through as given (including `None`); an error from the
/// underlying `sample_n` is converted through `PyPolarsErr` before being returned.
pub fn sample_n(&self, n: usize, with_replacement: bool, seed: Option<u64>) -> PyResult<Self> {
let df = self
.df
.sample_n(n, with_replacement, seed)
.map_err(PyPolarsErr::from)?;
Ok(df.into())
}

pub fn sample_frac(&self, frac: f64, with_replacement: bool, seed: u64) -> PyResult<Self> {
pub fn sample_frac(
&self,
frac: f64,
with_replacement: bool,
seed: Option<u64>,
) -> PyResult<Self> {
let df = self
.df
.sample_frac(frac, with_replacement, seed)
Expand Down
2 changes: 1 addition & 1 deletion py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1223,7 +1223,7 @@ impl PyExpr {
self.inner.clone().shuffle(seed).into()
}

pub fn sample_frac(&self, frac: f64, with_replacement: bool, seed: u64) -> Self {
pub fn sample_frac(&self, frac: f64, with_replacement: bool, seed: Option<u64>) -> Self {
self.inner
.clone()
.sample_frac(frac, with_replacement, seed)
Expand Down
9 changes: 7 additions & 2 deletions py-polars/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -585,15 +585,20 @@ impl PySeries {
Ok(ca.into_series().into())
}

pub fn sample_n(&self, n: usize, with_replacement: bool, seed: u64) -> PyResult<Self> {
/// Samples `n` elements from the wrapped `series` and returns the result as a new wrapper.
///
/// `seed` is passed through as given (including `None`); an error from the
/// underlying `sample_n` is converted through `PyPolarsErr` before being returned.
pub fn sample_n(&self, n: usize, with_replacement: bool, seed: Option<u64>) -> PyResult<Self> {
let s = self
.series
.sample_n(n, with_replacement, seed)
.map_err(PyPolarsErr::from)?;
Ok(s.into())
}

pub fn sample_frac(&self, frac: f64, with_replacement: bool, seed: u64) -> PyResult<Self> {
pub fn sample_frac(
&self,
frac: f64,
with_replacement: bool,
seed: Option<u64>,
) -> PyResult<Self> {
let s = self
.series
.sample_frac(frac, with_replacement, seed)
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/db-benchmark/various.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
maximum = 130352833.0

for _ in range(10):
permuted = df.sample(frac=1.0)
permuted = df.sample(frac=1.0, seed=0)
computed = permuted.select(
[pl.col("id").min().alias("min"), pl.col("id").max().alias("max")]
)
Expand Down
8 changes: 6 additions & 2 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -1572,8 +1572,12 @@ def test_is_unique() -> None:

# Checks DataFrame.sample: result shapes for `n` / `frac`, with and without an explicit seed.
def test_sample() -> None:
df = pl.DataFrame({"foo": [1, 2, 3], "bar": [6, 7, 8], "ham": ["a", "b", "c"]})
assert df.sample(n=2).shape == (2, 3)
assert df.sample(frac=0.4).shape == (1, 3)

# by default samples should be random
# NOTE(review): two unseeded draws of 2 rows from a 3-row frame can coincide by
# chance, so this inequality is not guaranteed to hold — confirm it is not flaky.
# NOTE(review): also confirm what `!=` between two DataFrames evaluates to under `assert`.
assert df.sample(n=2) != df.sample(n=2)

# With a fixed seed only the shape is checked: 2 of 3 rows, and int(3 * 0.4) == 1 row.
assert df.sample(n=2, seed=0).shape == (2, 3)
assert df.sample(frac=0.4, seed=0).shape == (1, 3)


@pytest.mark.parametrize("in_place", [True, False])
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_count_expr() -> None:

def test_sample() -> None:
a = pl.Series("a", range(0, 20))
out = pl.select(pl.lit(a).sample(0.5, False, 1)).to_series()
out = pl.select(pl.lit(a).sample(0.5, False, seed=1)).to_series()
assert out.shape == (10,)
assert out.to_list() != out.sort().to_list()
assert out.unique().shape == (10,)
Expand Down
14 changes: 9 additions & 5 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,16 +1082,20 @@ def test_dot() -> None:

# Checks Series.sample: result lengths for `n` / `frac`, with and without replacement and seed.
def test_sample() -> None:
s = pl.Series("a", [1, 2, 3, 4, 5])
assert len(s.sample(n=2)) == 2
assert len(s.sample(frac=0.4)) == 2

assert len(s.sample(n=2, with_replacement=True)) == 2
# by default samples should be random
# NOTE(review): two unseeded draws of 2 items from 5 can coincide by chance, so
# this inequality is not guaranteed to hold — confirm it is not flaky.
# NOTE(review): also confirm what `!=` between two Series evaluates to under `assert`.
assert s.sample(n=2) != s.sample(n=2)

# With a fixed seed only the length is checked: n=2, and int(5 * 0.4) == 2.
assert len(s.sample(n=2, seed=0)) == 2
assert len(s.sample(frac=0.4, seed=0)) == 2

assert len(s.sample(n=2, with_replacement=True, seed=0)) == 2

# on a series of length 5, you cannot sample more than 5 items
with pytest.raises(Exception):
s.sample(n=10, with_replacement=False)
s.sample(n=10, with_replacement=False, seed=0)
# unless you use with_replacement=True
assert len(s.sample(n=10, with_replacement=True)) == 10
assert len(s.sample(n=10, with_replacement=True, seed=0)) == 10


def test_peak_max_peak_min() -> None:
Expand Down

0 comments on commit 90c2d85

Please sign in to comment.