Skip to content

Commit

Permalink
shuffle sample option (#3308)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed May 5, 2022
1 parent 49149a6 commit b97ba6f
Show file tree
Hide file tree
Showing 9 changed files with 93 additions and 33 deletions.
55 changes: 38 additions & 17 deletions polars/polars-core/src/chunked_array/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,13 @@ where
}

impl Series {
pub fn sample_n(&self, n: usize, with_replacement: bool, seed: Option<u64>) -> Result<Self> {
pub fn sample_n(
&self,
n: usize,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
) -> Result<Self> {
if !with_replacement && n > self.len() {
return Err(PolarsError::ShapeMisMatch(
"n is larger than the number of elements in this array".into(),
Expand All @@ -71,7 +77,7 @@ impl Series {
unsafe { self.take_unchecked(&idx) }
}
false => {
let idx = create_rand_index_no_replacement(n, len, seed, false);
let idx = create_rand_index_no_replacement(n, len, seed, shuffle);
// Safety: we know that we never go out of bounds
debug_assert_eq!(len, self.len());
unsafe { self.take_unchecked(&idx) }
Expand All @@ -84,10 +90,11 @@ impl Series {
&self,
frac: f64,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
) -> Result<Self> {
let n = (self.len() as f64 * frac) as usize;
self.sample_n(n, with_replacement, seed)
self.sample_n(n, with_replacement, shuffle, seed)
}

pub fn shuffle(&self, seed: u64) -> Self {
Expand All @@ -105,7 +112,13 @@ where
ChunkedArray<T>: ChunkTake,
{
/// Sample n datapoints from this ChunkedArray.
pub fn sample_n(&self, n: usize, with_replacement: bool, seed: Option<u64>) -> Result<Self> {
pub fn sample_n(
&self,
n: usize,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
) -> Result<Self> {
if !with_replacement && n > self.len() {
return Err(PolarsError::ShapeMisMatch(
"n is larger than the number of elements in this array".into(),
Expand All @@ -121,7 +134,7 @@ where
unsafe { Ok(self.take_unchecked((&idx).into())) }
}
false => {
let idx = create_rand_index_no_replacement(n, len, seed, false);
let idx = create_rand_index_no_replacement(n, len, seed, shuffle);
// Safety: we know that we never go out of bounds
debug_assert_eq!(len, self.len());
unsafe { Ok(self.take_unchecked((&idx).into())) }
Expand All @@ -134,16 +147,23 @@ where
&self,
frac: f64,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
) -> Result<Self> {
let n = (self.len() as f64 * frac) as usize;
self.sample_n(n, with_replacement, seed)
self.sample_n(n, with_replacement, shuffle, seed)
}
}

impl DataFrame {
/// Sample n datapoints from this DataFrame.
pub fn sample_n(&self, n: usize, with_replacement: bool, seed: Option<u64>) -> Result<Self> {
pub fn sample_n(
&self,
n: usize,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
) -> Result<Self> {
if !with_replacement && n > self.height() {
return Err(PolarsError::ShapeMisMatch(
"n is larger than the number of elements in this array".into(),
Expand All @@ -152,7 +172,7 @@ impl DataFrame {
// all columns should use the same indices, so we first create the indices.
let idx = match with_replacement {
true => create_rand_index_with_replacement(n, self.height(), seed),
false => create_rand_index_no_replacement(n, self.height(), seed, false),
false => create_rand_index_no_replacement(n, self.height(), seed, shuffle),
};
// Safety:
// indices are within bounds
Expand All @@ -164,10 +184,11 @@ impl DataFrame {
&self,
frac: f64,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
) -> Result<Self> {
let n = (self.height() as f64 * frac) as usize;
self.sample_n(n, with_replacement, seed)
self.sample_n(n, with_replacement, shuffle, seed)
}
}

Expand Down Expand Up @@ -247,16 +268,16 @@ mod test {
.unwrap();

// default samples are random and don't require seeds
assert!(df.sample_n(3, false, None).is_ok());
assert!(df.sample_frac(0.4, false, None).is_ok());
assert!(df.sample_n(3, false, false, None).is_ok());
assert!(df.sample_frac(0.4, false, false, None).is_ok());
// with seeding
assert!(df.sample_n(3, false, Some(0)).is_ok());
assert!(df.sample_frac(0.4, false, Some(0)).is_ok());
assert!(df.sample_n(3, false, false, Some(0)).is_ok());
assert!(df.sample_frac(0.4, false, false, Some(0)).is_ok());
// without replacement can not sample more than 100%
assert!(df.sample_frac(2.0, false, Some(0)).is_err());
assert!(df.sample_n(3, true, Some(0)).is_ok());
assert!(df.sample_frac(0.4, true, Some(0)).is_ok());
assert!(df.sample_frac(2.0, false, false, Some(0)).is_err());
assert!(df.sample_n(3, true, false, Some(0)).is_ok());
assert!(df.sample_frac(0.4, true, false, Some(0)).is_ok());
// with replacement can sample more than 100%
assert!(df.sample_frac(2.0, true, Some(0)).is_ok());
assert!(df.sample_frac(2.0, true, false, Some(0)).is_ok());
}
}
10 changes: 8 additions & 2 deletions polars/polars-lazy/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1636,9 +1636,15 @@ impl Expr {
}

#[cfg(feature = "random")]
pub fn sample_frac(self, frac: f64, with_replacement: bool, seed: Option<u64>) -> Self {
pub fn sample_frac(
self,
frac: f64,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
) -> Self {
self.apply(
move |s| s.sample_frac(frac, with_replacement, seed),
move |s| s.sample_frac(frac, with_replacement, shuffle, seed),
GetOutput::same_type(),
)
.with_fmt("shuffle")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ fn estimate_unique_count(keys: &[Series], mut sample_size: usize) -> usize {
if keys.len() == 1 {
// we sample because that also works with sorted data.
// note that sampling without replacement is very expensive; don't do that.
let s = keys[0].sample_n(sample_size, true, None).unwrap();
let s = keys[0].sample_n(sample_size, true, false, None).unwrap();
// fast multi-threaded way to get unique.
let groups = s.group_tuples(true, false);
finish(&groups)
Expand Down
7 changes: 6 additions & 1 deletion py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2645,6 +2645,7 @@ def sample(
self,
fraction: float = 1.0,
with_replacement: bool = True,
shuffle: bool = False,
seed: Optional[int] = None,
) -> "Expr":
"""
Expand All @@ -2658,8 +2659,12 @@ def sample(
Allow values to be sampled more than once.
seed
Seed initialization. If None is given, a random seed is used.
shuffle
Shuffle the order of sampled data points.
"""
return wrap_expr(self._pyexpr.sample_frac(fraction, with_replacement, seed))
return wrap_expr(
self._pyexpr.sample_frac(fraction, with_replacement, shuffle, seed)
)

def ewm_mean(
self,
Expand Down
9 changes: 7 additions & 2 deletions py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4892,6 +4892,7 @@ def sample(
n: Optional[int] = None,
frac: Optional[float] = None,
with_replacement: bool = False,
shuffle: bool = False,
seed: Optional[int] = None,
) -> DF:
"""
Expand All @@ -4905,6 +4906,8 @@ def sample(
Fraction between 0.0 and 1.0.
with_replacement
Sample with replacement.
shuffle
Shuffle the order of sampled data points.
seed
Initialization seed. If None is given a random seed is used.
Expand Down Expand Up @@ -4934,12 +4937,14 @@ def sample(
raise ValueError("n and frac were both supplied")

if n is None and frac is not None:
return self._from_pydf(self._df.sample_frac(frac, with_replacement, seed))
return self._from_pydf(
self._df.sample_frac(frac, with_replacement, shuffle, seed)
)

if n is None:
n = 1

return self._from_pydf(self._df.sample_n(n, with_replacement, seed))
return self._from_pydf(self._df.sample_n(n, with_replacement, shuffle, seed))

def fold(
self, operation: Callable[["pli.Series", "pli.Series"], "pli.Series"]
Expand Down
7 changes: 5 additions & 2 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3053,6 +3053,7 @@ def sample(
n: Optional[int] = None,
frac: Optional[float] = None,
with_replacement: bool = False,
shuffle: bool = False,
seed: Optional[int] = None,
) -> "Series":
"""
Expand All @@ -3066,6 +3067,8 @@ def sample(
Fraction between 0.0 and 1.0.
with_replacement
Sample with replacement.
shuffle
Shuffle the order of sampled data points.
seed
Initialization seed. If None is given a random seed is used.
Expand All @@ -3085,12 +3088,12 @@ def sample(
raise ValueError("n and frac were both supplied")

if n is None and frac is not None:
return wrap_s(self._s.sample_frac(frac, with_replacement, seed))
return wrap_s(self._s.sample_frac(frac, with_replacement, shuffle, seed))

if n is None:
n = 1

return wrap_s(self._s.sample_n(n, with_replacement, seed))
return wrap_s(self._s.sample_n(n, with_replacement, shuffle, seed))

def peak_max(self) -> "Series":
"""
Expand Down
13 changes: 10 additions & 3 deletions py-polars/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -665,10 +665,16 @@ impl PyDataFrame {
Ok(df.into())
}

pub fn sample_n(&self, n: usize, with_replacement: bool, seed: Option<u64>) -> PyResult<Self> {
pub fn sample_n(
&self,
n: usize,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
) -> PyResult<Self> {
let df = self
.df
.sample_n(n, with_replacement, seed)
.sample_n(n, with_replacement, shuffle, seed)
.map_err(PyPolarsErr::from)?;
Ok(df.into())
}
Expand All @@ -677,11 +683,12 @@ impl PyDataFrame {
&self,
frac: f64,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
) -> PyResult<Self> {
let df = self
.df
.sample_frac(frac, with_replacement, seed)
.sample_frac(frac, with_replacement, shuffle, seed)
.map_err(PyPolarsErr::from)?;
Ok(df.into())
}
Expand Down
10 changes: 8 additions & 2 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1266,10 +1266,16 @@ impl PyExpr {
self.inner.clone().shuffle(seed).into()
}

pub fn sample_frac(&self, frac: f64, with_replacement: bool, seed: Option<u64>) -> Self {
pub fn sample_frac(
&self,
frac: f64,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
) -> Self {
self.inner
.clone()
.sample_frac(frac, with_replacement, seed)
.sample_frac(frac, with_replacement, shuffle, seed)
.into()
}

Expand Down
13 changes: 10 additions & 3 deletions py-polars/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -571,10 +571,16 @@ impl PySeries {
Ok(ca.into_series().into())
}

pub fn sample_n(&self, n: usize, with_replacement: bool, seed: Option<u64>) -> PyResult<Self> {
pub fn sample_n(
&self,
n: usize,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
) -> PyResult<Self> {
let s = self
.series
.sample_n(n, with_replacement, seed)
.sample_n(n, with_replacement, shuffle, seed)
.map_err(PyPolarsErr::from)?;
Ok(s.into())
}
Expand All @@ -583,11 +589,12 @@ impl PySeries {
&self,
frac: f64,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
) -> PyResult<Self> {
let s = self
.series
.sample_frac(frac, with_replacement, seed)
.sample_frac(frac, with_replacement, shuffle, seed)
.map_err(PyPolarsErr::from)?;
Ok(s.into())
}
Expand Down

0 comments on commit b97ba6f

Please sign in to comment.