Skip to content

Commit

Permalink
feat[rust]: Optional random seeds (#4624)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Aug 31, 2022
1 parent cae8e17 commit 8ea0b8e
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 11 deletions.
4 changes: 2 additions & 2 deletions polars/benches/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ fn add_benchmark(c: &mut Criterion) {
(10..=20).step_by(2).for_each(|log2_size| {
let size = 2usize.pow(log2_size);

let ca = Int32Chunked::init_rand(size, 0.0, 10);
let ca = Int32Chunked::init_rand(size, 0.0, Some(10));

c.bench_function(&format!("sort 2^{} i32", log2_size), |b| {
b.iter(|| bench_sort(&ca))
});

let ca = Int32Chunked::init_rand(size, 0.1, 10);
let ca = Int32Chunked::init_rand(size, 0.1, Some(10));
c.bench_function(&format!("sort null 2^{} i32", log2_size), |b| {
b.iter(|| bench_sort(&ca))
});
Expand Down
8 changes: 4 additions & 4 deletions polars/polars-core/src/chunked_array/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ where
T: PolarsNumericType,
Standard: Distribution<T::Native>,
{
pub fn init_rand(size: usize, null_density: f32, seed: u64) -> Self {
let mut rng = SmallRng::seed_from_u64(seed);
pub fn init_rand(size: usize, null_density: f32, seed: Option<u64>) -> Self {
let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_random_seed));
(0..size)
.map(|_| {
if rng.gen::<f32>() < null_density {
Expand Down Expand Up @@ -102,10 +102,10 @@ impl Series {
self.sample_n(n, with_replacement, shuffle, seed)
}

pub fn shuffle(&self, seed: u64) -> Self {
pub fn shuffle(&self, seed: Option<u64>) -> Self {
let len = self.len();
let n = len;
let idx = create_rand_index_no_replacement(n, len, Some(seed), true);
let idx = create_rand_index_no_replacement(n, len, seed, true);
// Safety we know that we never go out of bounds
debug_assert_eq!(len, self.len());
unsafe { self.take_unchecked(&idx).unwrap() }
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2012,7 +2012,7 @@ impl Expr {
}

#[cfg(feature = "random")]
pub fn shuffle(self, seed: u64) -> Self {
pub fn shuffle(self, seed: Option<u64>) -> Self {
self.apply(move |s| Ok(s.shuffle(seed)), GetOutput::same_type())
.with_fmt("shuffle")
}
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/internals/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5369,8 +5369,8 @@ def shuffle(self, seed: int | None = None) -> Expr:
Parameters
----------
seed
Seed initialization. If None given, the `random` module is used to generate
a random seed.
Seed for the random number generator. If set to None (default), a random
seed is generated using the ``random`` module.
Examples
--------
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/internals/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4007,7 +4007,7 @@ def shuffle(self, seed: int | None = None) -> Series:
Parameters
----------
seed
Seed initialization.
Seed for the random number generator.
"""
if seed is None:
Expand Down
2 changes: 1 addition & 1 deletion py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1427,7 +1427,7 @@ impl PyExpr {
.into()
}

pub fn shuffle(&self, seed: u64) -> Self {
pub fn shuffle(&self, seed: Option<u64>) -> Self {
self.inner.clone().shuffle(seed).into()
}

Expand Down
11 changes: 11 additions & 0 deletions py-polars/tests/test_exprs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import random
from typing import cast

import numpy as np
Expand Down Expand Up @@ -94,6 +95,16 @@ def test_count_expr() -> None:
assert out["count"].to_list() == [4, 1]


def test_shuffle() -> None:
# Setting random.seed should lead to reproducible results
s = pl.Series("a", range(20))
random.seed(1)
result1 = pl.select(pl.lit(s).shuffle()).to_series()
random.seed(1)
result2 = pl.select(pl.lit(s).shuffle()).to_series()
assert result1.series_equal(result2)


@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_sample() -> None:
a = pl.Series("a", range(0, 20))
Expand Down

0 comments on commit 8ea0b8e

Please sign in to comment.