Skip to content

Commit

Permalink
fix branch supertypes (#3683)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 13, 2022
1 parent 0b6cdf5 commit 117e606
Show file tree
Hide file tree
Showing 8 changed files with 210 additions and 92 deletions.
282 changes: 199 additions & 83 deletions polars/polars-core/src/utils/mod.rs

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion polars/polars-lazy/src/logical_plan/lit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ impl LiteralValue {
UInt16(v) => AnyValue::UInt16(*v),
UInt32(v) => AnyValue::UInt32(*v),
UInt64(v) => AnyValue::UInt64(*v),
#[cfg(feature = "dtype-i16")]
#[cfg(feature = "dtype-i8")]
Int8(v) => AnyValue::Int8(*v),
#[cfg(feature = "dtype-i16")]
Int16(v) => AnyValue::Int16(*v),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ impl OptimizationRule for SimplifyExprRule {
Some(AExpr::Literal(LiteralValue::Float64(*v as f64)))
}

#[cfg(feature = "dtype-i16")]
#[cfg(feature = "dtype-i8")]
(AExpr::Literal(LiteralValue::Int8(v)), DataType::Float64) => {
Some(AExpr::Literal(LiteralValue::Float64(*v as f64)))
}
Expand Down Expand Up @@ -430,7 +430,7 @@ impl OptimizationRule for SimplifyExprRule {
Some(AExpr::Literal(LiteralValue::Float64(*v as f64)))
}

#[cfg(feature = "dtype-i16")]
#[cfg(feature = "dtype-i8")]
(AExpr::Literal(LiteralValue::Int8(v)), DataType::Float32) => {
Some(AExpr::Literal(LiteralValue::Float32(*v as f32)))
}
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/physical_plan/expressions/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ impl PartitionedAggregation for CastExpr {
state: &ExecutionState,
) -> Result<Series> {
let e = self.input.as_partitioned_aggregator().unwrap();
e.evaluate_partitioned(df, groups, state)
self.finish(&e.evaluate_partitioned(df, groups, state)?)
}

fn finalize(
Expand Down
5 changes: 3 additions & 2 deletions polars/polars-lazy/src/tests/optimization_checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ pub fn test_slice_pushdown_sort() -> Result<()> {
}

#[test]
#[cfg(feature = "dtype-i16")]
pub fn test_predicate_block_cast() -> Result<()> {
let df = df![
"value" => [10, 20, 30, 40]
Expand All @@ -253,12 +254,12 @@ pub fn test_predicate_block_cast() -> Result<()> {
let lf1 = df
.clone()
.lazy()
.with_column(col("value") * lit(0.1f32))
.with_column(col("value").cast(DataType::Int16) * lit(0.1f32))
.filter(col("value").lt(lit(2.5f32)));

let lf2 = df
.lazy()
.select([col("value") * lit(0.1f32)])
.select([col("value").cast(DataType::Int16) * lit(0.1f32)])
.filter(col("value").lt(lit(2.5f32)));

for lf in [lf1, lf2] {
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2053,7 +2053,7 @@ fn test_partitioned_gb_binary() -> Result<()> {
let out = df
.lazy()
.groupby([col("col")])
.agg([(col("col").cast(DataType::Float32) + lit(10))
.agg([(col("col").cast(DataType::Float32) + lit(10.0))
.sum()
.alias("sum")])
.collect()?;
Expand Down
3 changes: 2 additions & 1 deletion polars/tests/it/lazy/expressions/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,9 @@ fn test_exploded_window_function() -> Result<()> {
])
.collect()?;

// even though we fill with f32, cast i32 -> f32 can overflow so the result is f64
assert_eq!(
Vec::from(out.column("shifted")?.f32()?),
Vec::from(out.column("shifted")?.f64()?),
&[Some(-1.0), Some(3.0), Some(-1.0), Some(5.0), Some(4.0)]
);
Ok(())
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/test_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def test_list_eval_dtype_inference() -> None:
}
)

rank_pct = pl.col("").rank(reverse=True) / pl.col("").count()
rank_pct = pl.col("").rank(reverse=True) / pl.col("").count().cast(pl.UInt16)

# the .arr.first() would fail if .arr.eval did not correctly infer the output type
assert grades.with_column(
Expand Down

0 comments on commit 117e606

Please sign in to comment.