Skip to content

Commit

Permalink
Revert "feat[python]!: Update Expr.sample signature and change rand…
Browse files Browse the repository at this point in the history
…om seeding (#4641)" (#4643)

This reverts commit 50eba3d.
  • Loading branch information
ritchie46 committed Aug 31, 2022
1 parent 0433e98 commit cae8e17
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 30 deletions.
6 changes: 1 addition & 5 deletions py-polars/polars/internals/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from __future__ import annotations

import os
import random
import sys
from io import BytesIO, IOBase, StringIO
from pathlib import Path
Expand Down Expand Up @@ -5615,7 +5614,7 @@ def sample(
Shuffle the order of sampled data points.
seed
Seed for the random number generator. If set to None (default), a random
seed is generated using the ``random`` module.
seed is used.
Examples
--------
Expand All @@ -5642,9 +5641,6 @@ def sample(
if n is not None and frac is not None:
raise ValueError("cannot specify both `n` and `frac`")

if seed is None:
seed = random.randint(0, 10000)

if n is None and frac is not None:
return self._from_pydf(
self._df.sample_frac(frac, with_replacement, shuffle, seed)
Expand Down
38 changes: 22 additions & 16 deletions py-polars/polars/internals/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import random
from datetime import date, datetime
from typing import TYPE_CHECKING, Any, Callable, Sequence
from warnings import warn

from polars import internals as pli
from polars.datatypes import (
Expand Down Expand Up @@ -5393,22 +5394,20 @@ def shuffle(self, seed: int | None = None) -> Expr:
seed = random.randint(0, 10000)
return wrap_expr(self._pyexpr.shuffle(seed))

@deprecated_alias(fraction="frac")
def sample(
self,
n: int | None = None,
frac: float | None = None,
with_replacement: bool = False,
with_replacement: bool = True,
shuffle: bool = False,
seed: int | None = None,
n: int | None = None,
) -> Expr:
"""
Sample from this expression.
Parameters
----------
n
Number of items to return. Cannot be used with `frac`. Defaults to 1 if
`frac` is None.
frac
Fraction of items to return. Cannot be used with `n`.
with_replacement
Expand All @@ -5417,7 +5416,9 @@ def sample(
Shuffle the order of sampled data points.
seed
Seed for the random number generator. If set to None (default), a random
seed is generated using the ``random`` module.
seed is used.
n
Number of items to return. Cannot be used with `frac`.
Examples
--------
Expand All @@ -5437,20 +5438,25 @@ def sample(
└─────┘
"""
warn(
"The function signature for Expr.sample will change in a future"
" version. Explicitly set `frac` and `with_replacement` using keyword"
" arguments to retain the same behaviour.",
FutureWarning,
stacklevel=2,
)

if n is not None and frac is not None:
raise ValueError("cannot specify both `n` and `frac`")

if seed is None:
seed = random.randint(0, 10000)

if frac is not None:
return wrap_expr(
self._pyexpr.sample_frac(frac, with_replacement, shuffle, seed)
)
if n is not None and frac is None:
return wrap_expr(self._pyexpr.sample_n(n, with_replacement, shuffle, seed))

if n is None:
n = 1
return wrap_expr(self._pyexpr.sample_n(n, with_replacement, shuffle, seed))
if frac is None:
frac = 1.0
return wrap_expr(
self._pyexpr.sample_frac(frac, with_replacement, shuffle, seed)
)

def ewm_mean(
self,
Expand Down
11 changes: 10 additions & 1 deletion py-polars/polars/internals/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3554,7 +3554,7 @@ def sample(
Shuffle the order of sampled data points.
seed
Seed for the random number generator. If set to None (default), a random
seed is generated using the ``random`` module.
seed is used.
Examples
--------
Expand All @@ -3568,6 +3568,15 @@ def sample(
]
"""
if n is not None and frac is not None:
raise ValueError("cannot specify both `n` and `frac`")

if n is None and frac is not None:
return wrap_s(self._s.sample_frac(frac, with_replacement, shuffle, seed))

if n is None:
n = 1
return wrap_s(self._s.sample_n(n, with_replacement, shuffle, seed))

def peak_max(self) -> Series:
"""
Expand Down
28 changes: 28 additions & 0 deletions py-polars/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,34 @@ impl PySeries {
self.series.has_validity()
}

/// Draw `n` values from the wrapped series.
///
/// All arguments are forwarded unchanged to the inner series' `sample_n`.
///
/// * `n` - number of items to sample.
/// * `with_replacement` - forwarded as-is to the inner call.
/// * `shuffle` - shuffle the order of the sampled data points.
/// * `seed` - seed for the random number generator; `None` lets the
///   callee choose one.
///
/// # Errors
///
/// Any error reported by the inner `sample_n` is wrapped in `PyPolarsErr`
/// and converted into the error type of `PyResult`.
pub fn sample_n(
    &self,
    n: usize,
    with_replacement: bool,
    shuffle: bool,
    seed: Option<u64>,
) -> PyResult<Self> {
    match self.series.sample_n(n, with_replacement, shuffle, seed) {
        Ok(sampled) => Ok(sampled.into()),
        // Same conversion chain the `?` operator would apply.
        Err(err) => Err(PyPolarsErr::from(err).into()),
    }
}

/// Draw a fraction of the values from the wrapped series.
///
/// All arguments are forwarded unchanged to the inner series' `sample_frac`.
///
/// * `frac` - fraction of items to sample.
/// * `with_replacement` - forwarded as-is to the inner call.
/// * `shuffle` - shuffle the order of the sampled data points.
/// * `seed` - seed for the random number generator; `None` lets the
///   callee choose one.
///
/// # Errors
///
/// Any error reported by the inner `sample_frac` is wrapped in `PyPolarsErr`
/// and converted into the error type of `PyResult`.
pub fn sample_frac(
    &self,
    frac: f64,
    with_replacement: bool,
    shuffle: bool,
    seed: Option<u64>,
) -> PyResult<Self> {
    match self.series.sample_frac(frac, with_replacement, shuffle, seed) {
        Ok(sampled) => Ok(sampled.into()),
        // Same conversion chain the `?` operator would apply.
        Err(err) => Err(PyPolarsErr::from(err).into()),
    }
}

pub fn series_equal(&self, other: &PySeries, null_equal: bool, strict: bool) -> bool {
if strict {
self.series.eq(&other.series)
Expand Down
10 changes: 2 additions & 8 deletions py-polars/tests/test_exprs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

import random
from typing import cast

import numpy as np
import pytest

import polars as pl
from polars.testing import assert_series_equal, verify_series_and_expr_api
Expand Down Expand Up @@ -94,6 +94,7 @@ def test_count_expr() -> None:
assert out["count"].to_list() == [4, 1]


@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_sample() -> None:
a = pl.Series("a", range(0, 20))
out = pl.select(
Expand All @@ -108,13 +109,6 @@ def test_sample() -> None:
assert out.to_list() != out.sort().to_list()
assert out.unique().shape == (10,)

# Setting random.seed should lead to reproducible results
random.seed(1)
result1 = pl.select(pl.lit(a).sample(n=10)).to_series()
random.seed(1)
result2 = pl.select(pl.lit(a).sample(n=10)).to_series()
assert result1.series_equal(result2)


def test_map_alias() -> None:
out = pl.DataFrame({"foo": [1, 2, 3]}).select(
Expand Down

0 comments on commit cae8e17

Please sign in to comment.