Skip to content

Commit

Permalink
feat[rust]: Add Expr.sample_n (#4623)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Aug 31, 2022
1 parent 355a847 commit 85aa94b
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 26 deletions.
9 changes: 6 additions & 3 deletions polars/polars-core/src/chunked_array/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ impl Series {
) -> Result<Self> {
if !with_replacement && n > self.len() {
return Err(PolarsError::ShapeMisMatch(
"n is larger than the number of elements in this array".into(),
"cannot take a larger sample than the total population when `with_replacement=false`"
.into(),
));
}
if n == 0 {
Expand Down Expand Up @@ -126,7 +127,8 @@ where
) -> Result<Self> {
if !with_replacement && n > self.len() {
return Err(PolarsError::ShapeMisMatch(
"n is larger than the number of elements in this array".into(),
"cannot take a larger sample than the total population when `with_replacement=false`"
.into(),
));
}
let len = self.len();
Expand Down Expand Up @@ -171,7 +173,8 @@ impl DataFrame {
) -> Result<Self> {
if !with_replacement && n > self.height() {
return Err(PolarsError::ShapeMisMatch(
"n is larger than the number of elements in this array".into(),
"cannot take a larger sample than the total population when `with_replacement=false`"
.into(),
));
}
// all columns should used the same indices. So we first create the indices.
Expand Down
15 changes: 15 additions & 0 deletions polars/polars-lazy/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2015,6 +2015,21 @@ impl Expr {
.with_fmt("shuffle")
}

#[cfg(feature = "random")]
pub fn sample_n(
self,
n: usize,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
) -> Self {
self.apply(
move |s| s.sample_n(n, with_replacement, shuffle, seed),
GetOutput::same_type(),
)
.with_fmt("sample_n")
}

#[cfg(feature = "random")]
pub fn sample_frac(
self,
Expand Down
15 changes: 8 additions & 7 deletions py-polars/polars/internals/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5599,20 +5599,22 @@ def sample(
seed: int | None = None,
) -> DF:
"""
Sample from this DataFrame by setting either `n` or `frac`.
Sample from this DataFrame.
Parameters
----------
n
Number of samples < self.len() .
Number of items to return. Cannot be used with `frac`. Defaults to 1 if
`frac` is None.
frac
Fraction between 0.0 and 1.0 .
Fraction of items to return. Cannot be used with `n`.
with_replacement
Sample with replacement.
Allow values to be sampled more than once.
shuffle
Shuffle the order of sampled data points.
seed
Initialization seed. If None is given a random seed is used.
Seed for the random number generator. If set to None (default), a random
seed is used.
Examples
--------
Expand All @@ -5637,7 +5639,7 @@ def sample(
"""
if n is not None and frac is not None:
raise ValueError("n and frac were both supplied")
raise ValueError("cannot specify both `n` and `frac`")

if n is None and frac is not None:
return self._from_pydf(
Expand All @@ -5646,7 +5648,6 @@ def sample(

if n is None:
n = 1

return self._from_pydf(self._df.sample_n(n, with_replacement, shuffle, seed))

def fold(
Expand Down
38 changes: 30 additions & 8 deletions py-polars/polars/internals/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import random
from datetime import date, datetime
from typing import TYPE_CHECKING, Any, Callable, Sequence
from warnings import warn

from polars import internals as pli
from polars.datatypes import (
Expand Down Expand Up @@ -5393,31 +5394,36 @@ def shuffle(self, seed: int | None = None) -> Expr:
seed = random.randint(0, 10000)
return wrap_expr(self._pyexpr.shuffle(seed))

@deprecated_alias(fraction="frac")
def sample(
self,
fraction: float = 1.0,
frac: float | None = None,
with_replacement: bool = True,
shuffle: bool = False,
seed: int | None = None,
n: int | None = None,
) -> Expr:
"""
Sample a fraction of the `Series`.
Sample from this expression.
Parameters
----------
fraction
Fraction 0.0 <= value <= 1.0
frac
Fraction of items to return. Cannot be used with `n`.
with_replacement
Allow values to be sampled more than once.
seed
Seed initialization. If None given a random seed is used.
shuffle
Shuffle the order of sampled data points.
seed
Seed for the random number generator. If set to None (default), a random
seed is used.
n
Number of items to return. Cannot be used with `frac`.
Examples
--------
>>> df = pl.DataFrame({"a": [1, 2, 3]})
>>> df.select(pl.col("a").sample(seed=1))
>>> df.select(pl.col("a").sample(frac=1.0, with_replacement=True, seed=1))
shape: (3, 1)
┌─────┐
│ a │
Expand All @@ -5432,8 +5438,24 @@ def sample(
└─────┘
"""
warn(
"The function signature for Expr.sample will change in a future"
" version. Explicitly set `frac` and `with_replacement` using keyword"
" arguments to retain the same behaviour.",
FutureWarning,
stacklevel=2,
)

if n is not None and frac is not None:
raise ValueError("cannot specify both `n` and `frac`")

if n is not None and frac is None:
return wrap_expr(self._pyexpr.sample_n(n, with_replacement, shuffle, seed))

if frac is None:
frac = 1.0
return wrap_expr(
self._pyexpr.sample_frac(fraction, with_replacement, shuffle, seed)
self._pyexpr.sample_frac(frac, with_replacement, shuffle, seed)
)

def ewm_mean(
Expand Down
15 changes: 8 additions & 7 deletions py-polars/polars/internals/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3539,20 +3539,22 @@ def sample(
seed: int | None = None,
) -> Series:
"""
Sample from this Series by setting either `n` or `frac`.
Sample from this Series.
Parameters
----------
n
Number of samples < self.len().
Number of items to return. Cannot be used with `frac`. Defaults to 1 if
`frac` is None.
frac
Fraction between 0.0 and 1.0 .
Fraction of items to return. Cannot be used with `n`.
with_replacement
sample with replacement.
Allow values to be sampled more than once.
shuffle
Shuffle the order of sampled data points.
seed
Initialization seed. If None is given a random seed is used.
Seed for the random number generator. If set to None (default), a random
seed is used.
Examples
--------
Expand All @@ -3567,14 +3569,13 @@ def sample(
"""
if n is not None and frac is not None:
raise ValueError("n and frac were both supplied")
raise ValueError("cannot specify both `n` and `frac`")

if n is None and frac is not None:
return wrap_s(self._s.sample_frac(frac, with_replacement, shuffle, seed))

if n is None:
n = 1

return wrap_s(self._s.sample_n(n, with_replacement, shuffle, seed))

def peak_max(self) -> Series:
Expand Down
13 changes: 13 additions & 0 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1431,6 +1431,19 @@ impl PyExpr {
self.inner.clone().shuffle(seed).into()
}

pub fn sample_n(
&self,
n: usize,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
) -> Self {
self.inner
.clone()
.sample_n(n, with_replacement, shuffle, seed)
.into()
}

pub fn sample_frac(
&self,
frac: f64,
Expand Down
11 changes: 10 additions & 1 deletion py-polars/tests/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import cast

import numpy as np
import pytest

import polars as pl
from polars.testing import assert_series_equal, verify_series_and_expr_api
Expand Down Expand Up @@ -93,9 +94,17 @@ def test_count_expr() -> None:
assert out["count"].to_list() == [4, 1]


@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_sample() -> None:
a = pl.Series("a", range(0, 20))
out = pl.select(pl.lit(a).sample(0.5, False, seed=1)).to_series()
out = pl.select(
pl.lit(a).sample(frac=0.5, with_replacement=False, seed=1)
).to_series()
assert out.shape == (10,)
assert out.to_list() != out.sort().to_list()
assert out.unique().shape == (10,)

out = pl.select(pl.lit(a).sample(n=10, with_replacement=False, seed=1)).to_series()
assert out.shape == (10,)
assert out.to_list() != out.sort().to_list()
assert out.unique().shape == (10,)
Expand Down

0 comments on commit 85aa94b

Please sign in to comment.