Skip to content

Commit

Permalink
feat[python]!: Update Expr.sample signature and change random seeding (#4641)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Aug 31, 2022
1 parent 434dd8e commit 50eba3d
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 63 deletions.
6 changes: 5 additions & 1 deletion py-polars/polars/internals/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import os
import random
import sys
from io import BytesIO, IOBase, StringIO
from pathlib import Path
Expand Down Expand Up @@ -5614,7 +5615,7 @@ def sample(
Shuffle the order of sampled data points.
seed
Seed for the random number generator. If set to None (default), a random
seed is used.
seed is generated using the ``random`` module.
Examples
--------
Expand All @@ -5641,6 +5642,9 @@ def sample(
if n is not None and frac is not None:
raise ValueError("cannot specify both `n` and `frac`")

if seed is None:
seed = random.randint(0, 10000)

if n is None and frac is not None:
return self._from_pydf(
self._df.sample_frac(frac, with_replacement, shuffle, seed)
Expand Down
38 changes: 16 additions & 22 deletions py-polars/polars/internals/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import random
from datetime import date, datetime
from typing import TYPE_CHECKING, Any, Callable, Sequence
from warnings import warn

from polars import internals as pli
from polars.datatypes import (
Expand Down Expand Up @@ -5394,20 +5393,22 @@ def shuffle(self, seed: int | None = None) -> Expr:
seed = random.randint(0, 10000)
return wrap_expr(self._pyexpr.shuffle(seed))

@deprecated_alias(fraction="frac")
def sample(
self,
n: int | None = None,
frac: float | None = None,
with_replacement: bool = True,
with_replacement: bool = False,
shuffle: bool = False,
seed: int | None = None,
n: int | None = None,
) -> Expr:
"""
Sample from this expression.
Parameters
----------
n
Number of items to return. Cannot be used with `frac`. Defaults to 1 if
`frac` is None.
frac
Fraction of items to return. Cannot be used with `n`.
with_replacement
Expand All @@ -5416,9 +5417,7 @@ def sample(
Shuffle the order of sampled data points.
seed
Seed for the random number generator. If set to None (default), a random
seed is used.
n
Number of items to return. Cannot be used with `frac`.
seed is generated using the ``random`` module.
Examples
--------
Expand All @@ -5438,25 +5437,20 @@ def sample(
└─────┘
"""
warn(
"The function signature for Expr.sample will change in a future"
" version. Explicitly set `frac` and `with_replacement` using keyword"
" arguments to retain the same behaviour.",
FutureWarning,
stacklevel=2,
)

if n is not None and frac is not None:
raise ValueError("cannot specify both `n` and `frac`")

if n is not None and frac is None:
return wrap_expr(self._pyexpr.sample_n(n, with_replacement, shuffle, seed))
if seed is None:
seed = random.randint(0, 10000)

if frac is None:
frac = 1.0
return wrap_expr(
self._pyexpr.sample_frac(frac, with_replacement, shuffle, seed)
)
if frac is not None:
return wrap_expr(
self._pyexpr.sample_frac(frac, with_replacement, shuffle, seed)
)

if n is None:
n = 1
return wrap_expr(self._pyexpr.sample_n(n, with_replacement, shuffle, seed))

def ewm_mean(
self,
Expand Down
11 changes: 1 addition & 10 deletions py-polars/polars/internals/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3554,7 +3554,7 @@ def sample(
Shuffle the order of sampled data points.
seed
Seed for the random number generator. If set to None (default), a random
seed is used.
seed is generated using the ``random`` module.
Examples
--------
Expand All @@ -3568,15 +3568,6 @@ def sample(
]
"""
if n is not None and frac is not None:
raise ValueError("cannot specify both `n` and `frac`")

if n is None and frac is not None:
return wrap_s(self._s.sample_frac(frac, with_replacement, shuffle, seed))

if n is None:
n = 1
return wrap_s(self._s.sample_n(n, with_replacement, shuffle, seed))

def peak_max(self) -> Series:
"""
Expand Down
28 changes: 0 additions & 28 deletions py-polars/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -558,34 +558,6 @@ impl PySeries {
self.series.has_validity()
}

/// Sample `n` values from the wrapped series.
///
/// Every argument is forwarded unchanged to the wrapped series' own
/// `sample_n`. If that call fails, the error is converted through
/// `PyPolarsErr` into the error type of the returned `PyResult`.
pub fn sample_n(
    &self,
    n: usize,
    with_replacement: bool,
    shuffle: bool,
    seed: Option<u64>,
) -> PyResult<Self> {
    match self.series.sample_n(n, with_replacement, shuffle, seed) {
        // Wrap the sampled series back into the Python-facing type.
        Ok(sampled) => Ok(sampled.into()),
        Err(err) => Err(PyPolarsErr::from(err).into()),
    }
}

/// Sample a fraction `frac` of the wrapped series.
///
/// Every argument is forwarded unchanged to the wrapped series' own
/// `sample_frac`. If that call fails, the error is converted through
/// `PyPolarsErr` into the error type of the returned `PyResult`.
pub fn sample_frac(
    &self,
    frac: f64,
    with_replacement: bool,
    shuffle: bool,
    seed: Option<u64>,
) -> PyResult<Self> {
    match self.series.sample_frac(frac, with_replacement, shuffle, seed) {
        // Wrap the sampled series back into the Python-facing type.
        Ok(sampled) => Ok(sampled.into()),
        Err(err) => Err(PyPolarsErr::from(err).into()),
    }
}

pub fn series_equal(&self, other: &PySeries, null_equal: bool, strict: bool) -> bool {
if strict {
self.series.eq(&other.series)
Expand Down
10 changes: 8 additions & 2 deletions py-polars/tests/test_exprs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

import random
from typing import cast

import numpy as np
import pytest

import polars as pl
from polars.testing import assert_series_equal, verify_series_and_expr_api
Expand Down Expand Up @@ -94,7 +94,6 @@ def test_count_expr() -> None:
assert out["count"].to_list() == [4, 1]


@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_sample() -> None:
a = pl.Series("a", range(0, 20))
out = pl.select(
Expand All @@ -109,6 +108,13 @@ def test_sample() -> None:
assert out.to_list() != out.sort().to_list()
assert out.unique().shape == (10,)

# Setting random.seed should lead to reproducible results
random.seed(1)
result1 = pl.select(pl.lit(a).sample(n=10)).to_series()
random.seed(1)
result2 = pl.select(pl.lit(a).sample(n=10)).to_series()
assert result1.series_equal(result2)


def test_map_alias() -> None:
out = pl.DataFrame({"foo": [1, 2, 3]}).select(
Expand Down

0 comments on commit 50eba3d

Please sign in to comment.