Skip to content

Commit

Permalink
fix(rust): Prevent panic on sample_n with replacement from empty df (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
trueb2 committed Aug 26, 2023
1 parent 2b200f0 commit be38947
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
3 changes: 3 additions & 0 deletions crates/polars-core/src/chunked_array/random.rs
Expand Up @@ -9,6 +9,9 @@ use crate::random::get_global_random_u64;
use crate::utils::{CustomIterTools, NoNull};

fn create_rand_index_with_replacement(n: usize, len: usize, seed: Option<u64>) -> IdxCa {
if len == 0 {
return IdxCa::new_vec("", vec![]);
}
let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));
let dist = Uniform::new(0, len as IdxSize);
(0..n as IdxSize)
Expand Down
13 changes: 13 additions & 0 deletions py-polars/tests/unit/operations/test_random.py
Expand Up @@ -54,6 +54,19 @@ def test_sample_df() -> None:
assert df.sample(fraction=0.4, seed=0).shape == (1, 3)


def test_sample_empty_df() -> None:
    """Check ``DataFrame.sample`` on a frame that has one column and zero rows.

    Covers both ``n=`` and ``fraction=`` sampling, with and without
    replacement.
    """
    df = pl.DataFrame({"foo": []})

    # With replacement: the result is expected to be an empty frame that
    # still carries the single column.
    assert df.sample(n=3, with_replacement=True).shape == (0, 1)
    assert df.sample(fraction=0.4, with_replacement=True).shape == (0, 1)

    # Without replacement: asking for n=3 rows out of 0 is expected to raise
    # a shape error, whereas fraction-based sampling is expected to succeed
    # and return an empty frame.
    with pytest.raises(pl.ShapeError):
        df.sample(n=3, with_replacement=False)
    # Deliberately outside the `raises` block: this call must NOT raise.
    assert df.sample(fraction=0.4, with_replacement=False).shape == (0, 1)


def test_sample_series() -> None:
s = pl.Series("a", [1, 2, 3, 4, 5])

Expand Down

0 comments on commit be38947

Please sign in to comment.