Skip to content

Commit

Permalink
fix(rust, python): fix streaming groupby aggregate types (#5636)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Nov 26, 2022
1 parent 7360c66 commit cc83dd2
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 43 deletions.
2 changes: 1 addition & 1 deletion polars/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ impl DataFrame {
let shape_err = |s: &[Series]| {
let msg = format!(
"Could not create a new DataFrame from Series. \
The Series have different lengths.\
The Series have different lengths. \
Got {:?}",
s
);
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/frame/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,13 +437,13 @@ impl<'a> AnyValueBuffer<'a> {
(Float32(builder), val) => builder.append_value(val.extract()?),
(Float64(builder), val) => builder.append_value(val.extract()?),
(Utf8(builder), AnyValue::Utf8(v)) => builder.append_value(v),
(Utf8(builder), AnyValue::Utf8Owned(v)) => builder.append_value(v),
(Utf8(builder), AnyValue::Null) => builder.append_null(),
// Struct and List can be recursive so use anyvalues for that
(All(_, vals), v) => vals.push(v),

// dynamic types
(Utf8(builder), av) => match av {
AnyValue::Utf8(v) => builder.append_value(v),
AnyValue::Int64(v) => builder.append_value(&format!("{}", v)),
AnyValue::Float64(v) => builder.append_value(&format!("{}", v)),
AnyValue::Boolean(true) => builder.append_value("true"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,6 @@ where
}
let agg_fn = match logical_dtype.to_physical() {
dt if dt.is_integer() => AggregateFunction::MeanF64(MeanAgg::<f64>::new()),
// Boolean is aggregated as the IDX type.
DataType::Boolean => AggregateFunction::MeanF64(MeanAgg::<f64>::new()),
DataType::Float32 => AggregateFunction::MeanF32(MeanAgg::<f32>::new()),
DataType::Float64 => AggregateFunction::MeanF64(MeanAgg::<f64>::new()),
dt => AggregateFunction::Null(NullAgg::new(dt)),
Expand All @@ -156,12 +154,18 @@ where
AAggExpr::First(input) => {
let phys_expr = to_physical(*input, expr_arena).unwrap();
let dtype = phys_expr.field(schema).unwrap().dtype;
(phys_expr, AggregateFunction::First(FirstAgg::new(dtype)))
(
phys_expr,
AggregateFunction::First(FirstAgg::new(dtype.to_physical())),
)
}
AAggExpr::Last(input) => {
let phys_expr = to_physical(*input, expr_arena).unwrap();
let dtype = phys_expr.field(schema).unwrap().dtype;
(phys_expr, AggregateFunction::Last(LastAgg::new(dtype)))
(
phys_expr,
AggregateFunction::Last(LastAgg::new(dtype.to_physical())),
)
}
AAggExpr::Count(input) => {
let phys_expr = to_physical(*input, expr_arena).unwrap();
Expand Down
33 changes: 16 additions & 17 deletions polars/polars-lazy/polars-plan/src/logical_plan/aexpr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,14 +298,16 @@ impl AExpr {
Ok(field)
}
Median(expr) => {
let mut field = arena.get(*expr).to_field(schema, ctxt, arena)?;
let mut field =
arena.get(*expr).to_field(schema, Context::Default, arena)?;
if field.data_type() != &DataType::Utf8 {
field.coerce(DataType::Float64);
}
Ok(field)
}
Mean(expr) => {
let mut field = arena.get(*expr).to_field(schema, ctxt, arena)?;
let mut field =
arena.get(*expr).to_field(schema, Context::Default, arena)?;
coerce_numeric_aggregation(&mut field);
Ok(field)
}
Expand All @@ -317,22 +319,26 @@ impl AExpr {
Ok(field)
}
Std(expr, _) => {
let mut field = arena.get(*expr).to_field(schema, ctxt, arena)?;
let mut field =
arena.get(*expr).to_field(schema, Context::Default, arena)?;
coerce_numeric_aggregation(&mut field);
Ok(field)
}
Var(expr, _) => {
let mut field = arena.get(*expr).to_field(schema, ctxt, arena)?;
let mut field =
arena.get(*expr).to_field(schema, Context::Default, arena)?;
coerce_numeric_aggregation(&mut field);
Ok(field)
}
NUnique(expr) => {
let mut field = arena.get(*expr).to_field(schema, ctxt, arena)?;
let mut field =
arena.get(*expr).to_field(schema, Context::Default, arena)?;
field.coerce(DataType::UInt32);
Ok(field)
}
Count(expr) => {
let mut field = arena.get(*expr).to_field(schema, ctxt, arena)?;
let mut field =
arena.get(*expr).to_field(schema, Context::Default, arena)?;
field.coerce(IDX_DTYPE);
Ok(field)
}
Expand All @@ -342,7 +348,8 @@ impl AExpr {
Ok(field)
}
Quantile { expr, .. } => {
let mut field = arena.get(*expr).to_field(schema, ctxt, arena)?;
let mut field =
arena.get(*expr).to_field(schema, Context::Default, arena)?;
coerce_numeric_aggregation(&mut field);
Ok(field)
}
Expand Down Expand Up @@ -394,15 +401,7 @@ impl AExpr {
}

fn coerce_numeric_aggregation(field: &mut Field) {
match field.dtype {
DataType::Duration(_) => {
// pass
}
DataType::Float32 => {
// pass
}
_ => {
field.coerce(DataType::Float64);
}
if field.dtype.is_numeric() && !matches!(&field.dtype, DataType::Float32) {
field.coerce(DataType::Float64)
}
}
11 changes: 8 additions & 3 deletions polars/polars-lazy/src/physical_plan/planner/lp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ pub fn create_physical_plan(
keys,
aggs,
apply,
schema: _schema,
schema: _output_schema,
maintain_order,
options,
} => {
Expand Down Expand Up @@ -372,17 +372,21 @@ pub fn create_physical_plan(
&& aggs.len() < 10
&& std::env::var("POLARS_NO_STREAMING_GROUPBY").is_err()
{
let key_dtype = _schema.get_index(0).unwrap().1.to_physical();
let key_dtype = _output_schema.get_index(0).unwrap().1.to_physical();
// only on numeric and string keys for now
let allowed_key = keys.len() == 1 && key_dtype.is_numeric()
|| matches!(key_dtype, DataType::Utf8);
let allowed_aggs = _output_schema.iter_dtypes().skip(1).all(|dtype| {
let dt = dtype.to_physical();
dt.is_numeric() || matches!(dt, DataType::Utf8 | DataType::Boolean)
});

let lp = Aggregate {
input,
keys,
aggs,
apply,
schema: _schema,
schema: _output_schema,
maintain_order,
options: options.clone(),
};
Expand All @@ -394,6 +398,7 @@ pub fn create_physical_plan(
.iter(root)
.any(|(_, lp)| matches!(lp, Join { .. }));
if allowed_key
&& allowed_aggs
&& !has_joins
&& insert_streaming_nodes(root, lp_arena, expr_arena, &mut vec![], false)?
{
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/internals/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ def describe_plan(self) -> str:
"""
return self._ldf.describe_plan()

@deprecated_alias(streaming="allow_streaming")
@deprecated_alias(allow_streaming="streaming")
def describe_optimized_plan(
self,
type_coercion: bool = True,
Expand All @@ -673,7 +673,7 @@ def describe_optimized_plan(

return ldf.describe_optimized_plan()

@deprecated_alias(streaming="allow_streaming")
@deprecated_alias(allow_streaming="streaming")
def show_graph(
self,
optimized: bool = True,
Expand Down
15 changes: 0 additions & 15 deletions py-polars/tests/unit/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from datetime import datetime, timedelta

import numpy as np
import pytest

import polars as pl
Expand Down Expand Up @@ -221,17 +220,3 @@ def test_groupby_wildcard() -> None:
assert df.groupby([pl.col("*")], maintain_order=True).agg(
[pl.col("a").first().suffix("_agg")]
).to_dict(False) == {"a": [1, 2], "b": [1, 2], "a_agg": [1, 2]}


def test_streaming_non_streaming_gb() -> None:
n = 100
df = pl.DataFrame({"a": np.random.randint(0, 20, n)})
q = df.lazy().groupby("a").agg(pl.count()).sort("a")
assert q.collect(streaming=True).frame_equal(q.collect())

q = df.lazy().with_column(pl.col("a").cast(pl.Utf8))
q = q.groupby("a").agg(pl.count()).sort("a")
assert q.collect(streaming=True).frame_equal(q.collect())
q = df.lazy().with_column(pl.col("a").alias("b"))
q = q.groupby(["a", "b"]).agg(pl.count()).sort("a")
assert q.collect(streaming=True).frame_equal(q.collect())
86 changes: 86 additions & 0 deletions py-polars/tests/unit/test_streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from datetime import date

import numpy as np

import polars as pl


def test_streaming_groupby_types() -> None:
df = pl.DataFrame(
{
"person_id": [1, 1],
"year": [1995, 1995],
"person_name": ["bob", "foo"],
"bool": [True, False],
"date": [date(2022, 1, 1), date(2022, 1, 1)],
}
)

for by in ["person_id", "year", "date", ["person_id", "year"]]:
out = (
(
df.lazy()
.groupby(by)
.agg(
[
pl.col("person_name").first().alias("str_first"),
pl.col("person_name").last().alias("str_last"),
pl.col("person_name").mean().alias("str_mean"),
pl.col("person_name").sum().alias("str_sum"),
pl.col("bool").first().alias("bool_first"),
pl.col("bool").last().alias("bool_last"),
pl.col("bool").mean().alias("bool_mean"),
pl.col("bool").sum().alias("bool_sum"),
pl.col("date").sum().alias("date_sum"),
pl.col("date").mean().alias("date_mean"),
pl.col("date").first().alias("date_first"),
pl.col("date").last().alias("date_last"),
]
)
)
.select(pl.all().exclude(by))
.collect(streaming=True)
)
assert out.schema == {
"str_first": pl.Utf8,
"str_last": pl.Utf8,
"str_mean": pl.Utf8,
"str_sum": pl.Utf8,
"bool_first": pl.Boolean,
"bool_last": pl.Boolean,
"bool_mean": pl.Boolean,
"bool_sum": pl.UInt32,
"date_sum": pl.Date,
"date_mean": pl.Date,
"date_first": pl.Date,
"date_last": pl.Date,
}

assert out.to_dict(False) == {
"str_first": ["bob"],
"str_last": ["foo"],
"str_mean": [None],
"str_sum": [None],
"bool_first": [True],
"bool_last": [False],
"bool_mean": [None],
"bool_sum": [1],
"date_sum": [date(2074, 1, 1)],
"date_mean": [date(2022, 1, 1)],
"date_first": [date(2022, 1, 1)],
"date_last": [date(2022, 1, 1)],
}


def test_streaming_non_streaming_gb() -> None:
n = 100
df = pl.DataFrame({"a": np.random.randint(0, 20, n)})
q = df.lazy().groupby("a").agg(pl.count()).sort("a")
assert q.collect(streaming=True).frame_equal(q.collect())

q = df.lazy().with_column(pl.col("a").cast(pl.Utf8))
q = q.groupby("a").agg(pl.count()).sort("a")
assert q.collect(streaming=True).frame_equal(q.collect())
q = df.lazy().with_column(pl.col("a").alias("b"))
q = q.groupby(["a", "b"]).agg(pl.count()).sort("a")
assert q.collect(streaming=True).frame_equal(q.collect())

0 comments on commit cc83dd2

Please sign in to comment.