Skip to content

Commit

Permalink
feat(python)!: Update Expr.sample signature and change random seedi…
Browse files Browse the repository at this point in the history
…ng (#4648)
  • Loading branch information
stinodego committed Nov 23, 2022
1 parent c0fdc04 commit c71d7b5
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 62 deletions.
6 changes: 5 additions & 1 deletion py-polars/polars/internals/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import math
import os
import random
import sys
import typing
from collections.abc import Sized
Expand Down Expand Up @@ -6196,7 +6197,7 @@ def sample(
Shuffle the order of sampled data points.
seed
Seed for the random number generator. If set to None (default), a random
seed is used.
seed is generated using the ``random`` module.
Examples
--------
Expand All @@ -6223,6 +6224,9 @@ def sample(
if n is not None and frac is not None:
raise ValueError("cannot specify both `n` and `frac`")

if seed is None:
seed = random.randint(0, 10000)

if n is None and frac is not None:
return self._from_pydf(
self._df.sample_frac(frac, with_replacement, shuffle, seed)
Expand Down
38 changes: 16 additions & 22 deletions py-polars/polars/internals/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import warnings
from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING, Any, Callable, NoReturn, Sequence, cast
from warnings import warn

from polars import internals as pli
from polars.datatypes import (
Expand Down Expand Up @@ -5540,20 +5539,22 @@ def shuffle(self, seed: int | None = None) -> Expr:
seed = random.randint(0, 10000)
return wrap_expr(self._pyexpr.shuffle(seed))

@deprecated_alias(fraction="frac")
def sample(
self,
n: int | None = None,
frac: float | None = None,
with_replacement: bool = True,
with_replacement: bool = False,
shuffle: bool = False,
seed: int | None = None,
n: int | None = None,
) -> Expr:
"""
Sample from this expression.
Parameters
----------
n
Number of items to return. Cannot be used with `frac`. Defaults to 1 if
`frac` is None.
frac
Fraction of items to return. Cannot be used with `n`.
with_replacement
Expand All @@ -5562,9 +5563,7 @@ def sample(
Shuffle the order of sampled data points.
seed
Seed for the random number generator. If set to None (default), a random
seed is used.
n
Number of items to return. Cannot be used with `frac`.
seed is generated using the ``random`` module.
Examples
--------
Expand All @@ -5584,25 +5583,20 @@ def sample(
└─────┘
"""
warn(
"The function signature for Expr.sample will change in a future"
" version. Explicitly set `frac` and `with_replacement` using keyword"
" arguments to retain the same behaviour.",
FutureWarning,
stacklevel=2,
)

if n is not None and frac is not None:
raise ValueError("cannot specify both `n` and `frac`")

if n is not None and frac is None:
return wrap_expr(self._pyexpr.sample_n(n, with_replacement, shuffle, seed))
if seed is None:
seed = random.randint(0, 10000)

if frac is None:
frac = 1.0
return wrap_expr(
self._pyexpr.sample_frac(frac, with_replacement, shuffle, seed)
)
if frac is not None:
return wrap_expr(
self._pyexpr.sample_frac(frac, with_replacement, shuffle, seed)
)

if n is None:
n = 1
return wrap_expr(self._pyexpr.sample_n(n, with_replacement, shuffle, seed))

def ewm_mean(
self,
Expand Down
11 changes: 1 addition & 10 deletions py-polars/polars/internals/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4014,7 +4014,7 @@ def sample(
Shuffle the order of sampled data points.
seed
Seed for the random number generator. If set to None (default), a random
seed is used.
seed is generated using the ``random`` module.
Examples
--------
Expand All @@ -4028,15 +4028,6 @@ def sample(
]
"""
if n is not None and frac is not None:
raise ValueError("cannot specify both `n` and `frac`")

if n is None and frac is not None:
return wrap_s(self._s.sample_frac(frac, with_replacement, shuffle, seed))

if n is None:
n = 1
return wrap_s(self._s.sample_n(n, with_replacement, shuffle, seed))

def peak_max(self) -> Series:
"""
Expand Down
28 changes: 0 additions & 28 deletions py-polars/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -566,34 +566,6 @@ impl PySeries {
self.series.has_validity()
}

pub fn sample_n(
&self,
n: usize,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
) -> PyResult<Self> {
let s = self
.series
.sample_n(n, with_replacement, shuffle, seed)
.map_err(PyPolarsErr::from)?;
Ok(s.into())
}

pub fn sample_frac(
&self,
frac: f64,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
) -> PyResult<Self> {
let s = self
.series
.sample_frac(frac, with_replacement, shuffle, seed)
.map_err(PyPolarsErr::from)?;
Ok(s.into())
}

pub fn series_equal(&self, other: &PySeries, null_equal: bool, strict: bool) -> bool {
if strict {
self.series.eq(&other.series)
Expand Down
8 changes: 7 additions & 1 deletion py-polars/tests/unit/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ def test_shuffle() -> None:
assert result1.series_equal(result2)


@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_sample() -> None:
a = pl.Series("a", range(0, 20))
out = pl.select(
Expand All @@ -123,6 +122,13 @@ def test_sample() -> None:
assert out.to_list() != out.sort().to_list()
assert out.unique().shape == (10,)

# Setting random.seed should lead to reproducible results
random.seed(1)
result1 = pl.select(pl.lit(a).sample(n=10)).to_series()
random.seed(1)
result2 = pl.select(pl.lit(a).sample(n=10)).to_series()
assert result1.series_equal(result2)


def test_map_alias() -> None:
out = pl.DataFrame({"foo": [1, 2, 3]}).select(
Expand Down

0 comments on commit c71d7b5

Please sign in to comment.