Skip to content

Commit

Permalink
ensure that Cast expressions first updates groups before it flattens (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed May 15, 2022
1 parent dd67372 commit 4ce0b50
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 2 deletions.
15 changes: 15 additions & 0 deletions polars/polars-core/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,21 @@ impl DataType {
_ => true,
}
}
/// Returns `true` when this data type is a signed integer type.
///
/// `Int32` and `Int64` are always recognised. `Int8` and `Int16` are only
/// recognised when the `dtype-i8` / `dtype-i16` features are enabled; every
/// other data type yields `false`.
pub fn is_signed(&self) -> bool {
    // The narrow widths are checked behind their feature gates; cfg-gated
    // alternatives cannot be expressed inside a single `matches!` pattern,
    // so they are handled as separate early returns.
    #[cfg(feature = "dtype-i8")]
    {
        if matches!(self, DataType::Int8) {
            return true;
        }
    }
    #[cfg(feature = "dtype-i16")]
    {
        if matches!(self, DataType::Int16) {
            return true;
        }
    }
    matches!(self, DataType::Int32 | DataType::Int64)
}
pub fn is_unsigned(&self) -> bool {
self.is_numeric() && !self.is_signed()
}

/// Convert to an Arrow data type.
pub fn to_arrow(&self) -> ArrowDataType {
Expand Down
10 changes: 8 additions & 2 deletions polars/polars-lazy/src/logical_plan/optimizer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,17 @@ fn use_supertype(

// cast literal to right type
(AExpr::Literal(_), _) => {
st = type_right.clone();
// never cast signed to unsigned
if type_right.is_signed() {
st = type_right.clone();
}
}
// cast literal to left type
(_, AExpr::Literal(_)) => {
st = type_left.clone();
// never cast signed to unsigned
if type_left.is_signed() {
st = type_left.clone();
}
}
// do nothing
_ => {}
Expand Down
2 changes: 2 additions & 0 deletions polars/polars-lazy/src/physical_plan/expressions/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ impl PhysicalExpr for CastExpr {
state: &ExecutionState,
) -> Result<AggregationContext<'a>> {
let mut ac = self.input.evaluate_on_groups(df, groups, state)?;
// before we flatten, make sure that groups are updated
ac.groups();
let s = ac.flat_naive();
let s = self.finish(s.as_ref())?;

Expand Down
27 changes: 27 additions & 0 deletions polars/tests/it/lazy/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,30 @@ fn test_expand_list() -> Result<()> {

Ok(())
}

// Regression test: a cast applied inside a groupby aggregation must update
// the groups before the aggregated list column is flattened.
#[test]
#[cfg(feature = "unique_counts")]
fn test_update_groups_in_cast() -> Result<()> {
    let df = df![
        "group" => ["A" ,"A", "A", "B", "B", "B", "B"],
        "id"=> [1, 2, 1, 4, 5, 4, 6],
    ]?;

    // optimized to
    // col("id").unique_counts().cast(int64) * -1
    // in aggregation that cast coerces a list and the cast may forget to update groups
    let out = df
        .lazy()
        .groupby_stable([col("group")])
        .agg([col("id").unique_counts() * lit(-1)])
        .collect()?;

    // One list of negated unique counts per group, in first-appearance order.
    let expected = df![
        "group" => ["A" ,"B"],
        "id"=> [AnyValue::List(Series::new("", [-2i64, -1])), AnyValue::List(Series::new("", [-2i64, -1, -1]))]
    ]?;

    // Only print the frames when the comparison fails, instead of on every run.
    assert!(
        out.frame_equal(&expected),
        "aggregation result differs from expectation\nout: {:?}\nexpected: {:?}",
        out,
        expected
    );
    Ok(())
}
23 changes: 23 additions & 0 deletions py-polars/tests/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,26 @@ def test_overflow_uint16_agg_mean() -> None:
.to_dict(False)
== {"col1": ["A"], "col3": [64.0]}
)


def test_binary_on_list_agg_3345() -> None:
    # Binary arithmetic on list-producing aggregations inside a groupby;
    # the "3345" in the name presumably refers to the originating issue.
    frame = pl.DataFrame(
        {
            "group": ["A", "A", "A", "B", "B", "B", "B"],
            "id": [1, 2, 1, 4, 5, 4, 6],
        }
    )

    # Per-group share of each unique "id" value: unique counts / group length.
    share = pl.col("id").unique_counts() / pl.col("id").len()
    # Sum over the group of log(share) * -1 * share.
    aggregated = (share.log() * -1 * share).sum()

    result = (
        frame.groupby(["group"], maintain_order=True)
        .agg([aggregated])
        .to_dict(False)
    )
    expected = {"group": ["A", "B"], "id": [0.6365141682948128, 1.0397207708399179]}

    assert result == expected

0 comments on commit 4ce0b50

Please sign in to comment.