Skip to content

Commit

Permalink
fix[rust, python]: accept schema in lazy groupby apply (#4756)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 7, 2022
1 parent bde537f commit e46ce2d
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 32 deletions.
13 changes: 0 additions & 13 deletions polars/polars-core/src/frame/groupby/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1146,19 +1146,6 @@ mod test {
);
}

#[test]
#[cfg_attr(miri, ignore)]
fn test_groupby_apply() {
let df = df! {
"a" => [1, 1, 2, 2, 2],
"b" => [1, 2, 3, 4, 5]
}
.unwrap();

let out = df.groupby(["a"]).unwrap().apply(Ok).unwrap();
assert!(out.sort(["b"], false).unwrap().frame_equal(&df));
}

#[test]
#[cfg_attr(miri, ignore)]
fn test_groupby_threaded() {
Expand Down
29 changes: 16 additions & 13 deletions polars/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1316,7 +1316,7 @@ impl LazyFrame {
/// Utility struct for lazy groupby operation.
#[derive(Clone)]
pub struct LazyGroupBy {
pub(crate) logical_plan: LogicalPlan,
pub logical_plan: LogicalPlan,
opt_state: OptState,
keys: Vec<Expr>,
maintain_order: bool,
Expand Down Expand Up @@ -1386,21 +1386,24 @@ impl LazyGroupBy {
}

/// Apply a function over the groups as a new `DataFrame`. It is not recommended that you use
/// this as materializing the `DataFrame` is quite expensive.
pub fn apply<F>(self, f: F) -> LazyFrame
/// this as materializing the `DataFrame` is very expensive.
pub fn apply<F>(self, f: F, schema: SchemaRef) -> LazyFrame
where
F: 'static + Fn(DataFrame) -> Result<DataFrame> + Send + Sync,
{
let lp = LogicalPlanBuilder::from(self.logical_plan)
.groupby(
Arc::new(self.keys),
vec![],
Some(Arc::new(f)),
self.maintain_order,
None,
None,
)
.build();
let lp = LogicalPlan::Aggregate {
input: Box::new(self.logical_plan),
keys: Arc::new(self.keys),
aggs: vec![],
schema,
apply: Some(Arc::new(f)),
maintain_order: self.maintain_order,
options: GroupbyOptions {
dynamic: None,
rolling: None,
slice: None,
},
};
LazyFrame::from_logical_plan(lp, self.opt_state)
}
}
Expand Down
13 changes: 11 additions & 2 deletions py-polars/polars/internals/lazyframe/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Callable, Generic, Sequence, TypeVar

import polars.internals as pli
from polars.datatypes import Schema
from polars.internals.expr import ensure_list_of_pyexpr
from polars.utils import is_expr_sequence

Expand Down Expand Up @@ -170,7 +171,9 @@ def tail(self, n: int = 5) -> LDF:
"""
return self._lazyframe_class._from_pyldf(self.lgb.tail(n))

def apply(self, f: Callable[[pli.DataFrame], pli.DataFrame]) -> LDF:
def apply(
self, f: Callable[[pli.DataFrame], pli.DataFrame], schema: Schema | None
) -> LDF:
"""
Apply a function over the groups as a new `DataFrame`.
Expand All @@ -189,6 +192,12 @@ def apply(self, f: Callable[[pli.DataFrame], pli.DataFrame]) -> LDF:
----------
f
Function to apply over each group of the `LazyFrame`.
schema
Schema of the output function. This has to be known statically.
If the schema provided is incorrect, this is a bug in the callers
query and may lead to errors.
If none given, polars assumes the schema is unchanged.
Examples
--------
Expand Down Expand Up @@ -232,4 +241,4 @@ def apply(self, f: Callable[[pli.DataFrame], pli.DataFrame]) -> LDF:
... ) # doctest: +IGNORE_RESULT
"""
return self._lazyframe_class._from_pyldf(self.lgb.apply(f))
return self._lazyframe_class._from_pyldf(self.lgb.apply(f, schema))
14 changes: 12 additions & 2 deletions py-polars/src/lazy/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,18 @@ impl PyLazyGroupBy {
lgb.tail(Some(n)).into()
}

pub fn apply(&mut self, lambda: PyObject) -> PyLazyFrame {
pub fn apply(
&mut self,
lambda: PyObject,
schema: Option<Wrap<Schema>>,
) -> PyResult<PyLazyFrame> {
let lgb = self.lgb.take().unwrap();
let schema = match schema {
Some(schema) => Arc::new(schema.0),
None => LazyFrame::from(lgb.logical_plan.clone())
.schema()
.map_err(PyPolarsErr::from)?,
};

let function = move |df: DataFrame| {
Python::with_gil(|py| {
Expand Down Expand Up @@ -82,7 +92,7 @@ impl PyLazyGroupBy {
Ok(pydf.df)
})
};
lgb.apply(function).into()
Ok(lgb.apply(function, schema).into())
}
}

Expand Down
10 changes: 8 additions & 2 deletions py-polars/tests/unit/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,14 @@ def test_or() -> None:

def test_groupby_apply() -> None:
df = pl.DataFrame({"a": [1, 1, 3], "b": [1.0, 2.0, 3.0]})
ldf = df.lazy().groupby("a").apply(lambda df: df)
assert ldf.collect().sort("b").frame_equal(df)
ldf = (
df.lazy()
.groupby("a")
.apply(lambda df: df * 2.0, schema={"a": pl.Float64, "b": pl.Float64})
)
out = ldf.collect()
assert out.schema == ldf.schema
assert out.shape == (3, 2)


def test_filter_str() -> None:
Expand Down

0 comments on commit e46ce2d

Please sign in to comment.