From be38947bce8e63b9502bcff6f3f06909d4550dfc Mon Sep 17 00:00:00 2001 From: Jacob Trueb Date: Sat, 26 Aug 2023 12:30:30 -0500 Subject: [PATCH] fix(rust): Prevent panic on sample_n with replacement from empty df (#10731) --- crates/polars-core/src/chunked_array/random.rs | 3 +++ py-polars/tests/unit/operations/test_random.py | 13 +++++++++++++ 2 files changed, 16 insertions(+) diff --git a/crates/polars-core/src/chunked_array/random.rs b/crates/polars-core/src/chunked_array/random.rs index 046bc8fb5a08..74ece90d2c3e 100644 --- a/crates/polars-core/src/chunked_array/random.rs +++ b/crates/polars-core/src/chunked_array/random.rs @@ -9,6 +9,9 @@ use crate::random::get_global_random_u64; use crate::utils::{CustomIterTools, NoNull}; fn create_rand_index_with_replacement(n: usize, len: usize, seed: Option) -> IdxCa { + if len == 0 { + return IdxCa::new_vec("", vec![]); + } let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64)); let dist = Uniform::new(0, len as IdxSize); (0..n as IdxSize) diff --git a/py-polars/tests/unit/operations/test_random.py b/py-polars/tests/unit/operations/test_random.py index a92dfbe69677..9b9954d8e79c 100644 --- a/py-polars/tests/unit/operations/test_random.py +++ b/py-polars/tests/unit/operations/test_random.py @@ -54,6 +54,19 @@ def test_sample_df() -> None: assert df.sample(fraction=0.4, seed=0).shape == (1, 3) +def test_sample_empty_df() -> None: + df = pl.DataFrame({"foo": []}) + + # // If with replacement, then expect empty df + assert df.sample(n=3, with_replacement=True).shape == (0, 1) + assert df.sample(fraction=0.4, with_replacement=True).shape == (0, 1) + + # // If without replacement, then expect shape mismatch on sample_n not sample_frac + with pytest.raises(pl.ShapeError): + df.sample(n=3, with_replacement=False) + assert df.sample(fraction=0.4, with_replacement=False).shape == (0, 1) + + def test_sample_series() -> None: s = pl.Series("a", [1, 2, 3, 4, 5])