Skip to content

Commit

Permalink
fix[rust]: fix slice pushdown in rolling/dynamic groupby (#4542)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 23, 2022
1 parent 520b3b7 commit 5717543
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 21 deletions.
38 changes: 36 additions & 2 deletions polars/polars-core/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,43 @@
use std::borrow::Cow;
use std::fmt::{Display, Formatter};

use anyhow::Error;
use thiserror::Error as ThisError;

type ErrString = Cow<'static, str>;
/// Error-message storage for [`PolarsError`]: either a runtime-built `String`
/// or a zero-allocation `&'static str` (a drop-in replacement for
/// `Cow<'static, str>`). Constructed via the `From` impls below, which also
/// honor the `POLARS_PANIC_ON_ERR` environment variable.
#[derive(Debug)]
pub enum ErrString {
/// Heap-allocated message built at runtime (e.g. via `format!`).
Owned(String),
/// Static message known at compile time; stored without allocating.
Borrowed(&'static str),
}

impl From<&'static str> for ErrString {
    /// Wraps a static message without allocating. If the
    /// `POLARS_PANIC_ON_ERR` environment variable is set, panics at the
    /// error-construction site instead, so a backtrace points at the origin.
    fn from(msg: &'static str) -> Self {
        match std::env::var("POLARS_PANIC_ON_ERR") {
            Ok(_) => panic!("{}", msg),
            Err(_) => ErrString::Borrowed(msg),
        }
    }
}

impl From<String> for ErrString {
    /// Wraps an owned, runtime-built message. If the `POLARS_PANIC_ON_ERR`
    /// environment variable is set, panics at the error-construction site
    /// instead, mirroring the `&'static str` impl.
    fn from(msg: String) -> Self {
        match std::env::var("POLARS_PANIC_ON_ERR") {
            Ok(_) => panic!("{}", msg),
            Err(_) => ErrString::Owned(msg),
        }
    }
}

impl Display for ErrString {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Both variants hold plain text; forward it verbatim without
        // going through an extra `write!` formatting pass.
        match self {
            ErrString::Owned(msg) => f.write_str(msg),
            ErrString::Borrowed(msg) => f.write_str(msg),
        }
    }
}

#[derive(Debug, ThisError)]
pub enum PolarsError {
Expand Down
17 changes: 10 additions & 7 deletions polars/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,11 +226,14 @@ impl DataFrame {
pub fn new<S: IntoSeries>(columns: Vec<S>) -> Result<Self> {
let mut first_len = None;

let shape_err = || {
Err(PolarsError::ShapeMisMatch(
"Could not create a new DataFrame from Series. The Series have different lengths"
.into(),
))
// Builds the shape-mismatch error shown when input Series differ in length.
// Each wrapped line of the literal must end with "<space>\": the backslash
// continuation strips the newline AND the next line's leading whitespace, so
// without the trailing space the message would read "lengths.Got [...]".
let shape_err = |s: &[Series]| {
    let msg = format!(
        "Could not create a new DataFrame from Series. \
         The Series have different lengths. \
         Got {:?}",
        s
    );
    Err(PolarsError::ShapeMisMatch(msg.into()))
};

let series_cols = if S::is_series() {
Expand All @@ -244,7 +247,7 @@ impl DataFrame {
match first_len {
Some(len) => {
if s.len() != len {
return shape_err();
return shape_err(&series_cols);
}
}
None => first_len = Some(s.len()),
Expand All @@ -271,7 +274,7 @@ impl DataFrame {
match first_len {
Some(len) => {
if series.len() != len {
return shape_err();
return shape_err(&series_cols);
}
}
None => first_len = Some(series.len()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl Executor for GroupByDynamicExec {
.map(|e| e.evaluate(&df, state))
.collect::<Result<Vec<_>>>()?;

let (time_key, keys, groups) = df.groupby_dynamic(keys, &self.options)?;
let (mut time_key, mut keys, groups) = df.groupby_dynamic(keys, &self.options)?;

let mut groups = &groups;
#[allow(unused_assignments)]
Expand All @@ -45,6 +45,14 @@ impl Executor for GroupByDynamicExec {
if let Some((offset, len)) = self.slice {
sliced_groups = Some(groups.slice(offset, len));
groups = sliced_groups.as_deref().unwrap();

time_key = time_key.slice(offset, len);

// todo! optimize this, we can prevent an agg_first aggregation upstream
// the ordering has changed due to the groupby
for key in keys.iter_mut() {
*key = key.slice(offset, len)
}
}

let agg_columns = POOL.install(|| {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ impl Executor for GroupByRollingExec {
.map(|e| e.evaluate(&df, state))
.collect::<Result<Vec<_>>>()?;

let (time_key, keys, groups) = df.groupby_rolling(keys, &self.options)?;
let (mut time_key, mut keys, groups) = df.groupby_rolling(keys, &self.options)?;

let mut groups = &groups;
#[allow(unused_assignments)]
Expand All @@ -44,8 +44,19 @@ impl Executor for GroupByRollingExec {
if let Some((offset, len)) = self.slice {
sliced_groups = Some(groups.slice(offset, len));
groups = sliced_groups.as_deref().unwrap();

time_key = time_key.slice(offset, len);
}

// the ordering has changed due to the groupby
if !keys.is_empty() {
unsafe {
for key in keys.iter_mut() {
*key = key.agg_first(groups);
}
}
};

let agg_columns = POOL.install(|| {
self.aggs
.par_iter()
Expand Down
11 changes: 1 addition & 10 deletions polars/polars-time/src/groupby/dynamic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ impl Wrap<&DataFrame> {
fn impl_groupby_rolling(
&self,
dt: Series,
mut by: Vec<Series>,
by: Vec<Series>,
options: &RollingGroupOptions,
tu: TimeUnit,
time_type: &DataType,
Expand Down Expand Up @@ -386,15 +386,6 @@ impl Wrap<&DataFrame> {
};
let dt = dt.cast(time_type).unwrap();

// // the ordering has changed due to the groupby
if !by.is_empty() {
unsafe {
for key in by.iter_mut() {
*key = key.agg_first(&groups);
}
}
};

Ok((dt, by, groups))
}
}
Expand Down
1 change: 1 addition & 0 deletions polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@
//! * `POLARS_ALLOW_EXTENSION` -> allows for `[ObjectChunked<T>]` to be used in arrow, opening up possibilities like using
//! `T` in complex lazy expressions. However this does require `unsafe` code allow this.
//! * `POLARS_NO_PARQUET_STATISTICS` -> if set, statistics in parquet files are ignored.
//! * `POLARS_PANIC_ON_ERR` -> panic instead of returning an Error.
//!
//! ## User Guide
//! If you want to read more, [check the User Guide](https://pola-rs.github.io/polars-book/).
Expand Down
48 changes: 48 additions & 0 deletions py-polars/tests/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,51 @@ def test_rolling_groupby_extrema() -> None:
"col1_min": [3, 3, 3, 4, 2, 1, 0],
"col1_max": [3, 4, 5, 6, 6, 6, 2],
}


def test_rolling_slice_pushdown() -> None:
    # A head() after a rolling groupby must slice the group keys/time column
    # together with the groups, not just the aggregation output.
    lf = pl.DataFrame({"a": [1, 2, 3], "b": ["a", "a", "b"], "c": [1, 3, 5]}).lazy()

    # Per-window sum of consecutive differences of "c".
    diff_sum = (
        (pl.col("c") - pl.col("c").shift_and_fill(1, fill_value=0)).sum().alias("c")
    )

    lf = lf.sort("a").groupby_rolling("a", by="b", period="2i").agg([diff_sum])

    result = lf.head(2).collect().to_dict(False)
    assert result == {
        "b": ["a", "a"],
        "a": [1, 2],
        "c": [1, 3],
    }


def test_groupby_dynamic_slice_pushdown() -> None:
    # A head() after a dynamic groupby must slice the group keys/time column
    # together with the groups, not just the aggregation output.
    lf = pl.DataFrame({"a": [1, 2, 3], "b": ["a", "a", "b"], "c": [1, 3, 5]}).lazy()

    # Per-window sum of consecutive differences of "c".
    diff_sum = (
        (pl.col("c") - pl.col("c").shift_and_fill(1, fill_value=0)).sum().alias("c")
    )

    lf = lf.sort("a").groupby_dynamic("a", by="b", every="2i").agg([diff_sum])

    result = lf.head(2).collect().to_dict(False)
    assert result == {
        "b": ["a", "a"],
        "a": [0, 2],
        "c": [1, 3],
    }

0 comments on commit 5717543

Please sign in to comment.