Skip to content

Commit

Permalink
fix regressions
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 9, 2021
1 parent 46c98e2 commit 672fe28
Show file tree
Hide file tree
Showing 11 changed files with 94 additions and 15 deletions.
3 changes: 2 additions & 1 deletion polars/polars-core/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,8 @@ impl<'a> AnyValue<'a> {
AnyValue::Int32(v) => AnyValue::Date(v),
#[cfg(feature = "dtype-datetime")]
AnyValue::Int64(v) => AnyValue::Datetime(v),
_ => panic!("cannot create date from other type"),
AnyValue::Null => AnyValue::Null,
dt => panic!("cannot create date from other type. dtype: {}", dt),
}
}

Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/fmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ impl Debug for Series {
limit,
f,
self.date().unwrap(),
"Date",
"date",
self.name(),
"Series"
),
Expand Down
34 changes: 33 additions & 1 deletion polars/polars-core/src/series/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ macro_rules! impl_compare {
.unwrap()
.$method(rhs.datetime().unwrap().deref()),
DataType::List(_) => lhs.list().unwrap().$method(rhs.list().unwrap()),
#[cfg(feature = "dtype-categorical")]
DataType::Categorical => lhs
.categorical()
.unwrap()
.$method(rhs.categorical().unwrap().deref()),

_ => unimplemented!(),
}
}};
Expand Down Expand Up @@ -76,7 +82,33 @@ where

impl ChunkCompare<&Series> for Series {
/// Build a boolean mask by applying `eq_missing` between `self` and `rhs`.
///
/// With the `dtype-categorical` feature enabled, a Categorical series
/// compared against a length-1 Utf8 series (on either side) is routed to
/// `compare_cat_to_str_series`. Every other dtype/length combination is
/// dispatched through the `impl_compare!` macro.
fn eq_missing(&self, rhs: &Series) -> BooleanChunked {
    #[cfg(feature = "dtype-categorical")]
    use DataType::*;
    // The match is the tail expression of the method, so each arm simply
    // evaluates to the resulting mask.
    match (self.dtype(), rhs.dtype(), self.len(), rhs.len()) {
        // categorical == single string
        #[cfg(feature = "dtype-categorical")]
        (Categorical, Utf8, _, 1) => compare_cat_to_str_series(
            self,
            rhs,
            self.name(),
            |s, idx| s.eq_missing(idx),
            // NOTE(review): meaning of this flag is defined by
            // `compare_cat_to_str_series` (not visible here) -- confirm.
            false,
        ),
        // single string == categorical: operands are swapped so the
        // categorical series is passed first; the output keeps `self`'s name.
        #[cfg(feature = "dtype-categorical")]
        (Utf8, Categorical, 1, _) => compare_cat_to_str_series(
            rhs,
            self,
            self.name(),
            |s, idx| s.eq_missing(idx),
            false,
        ),
        _ => {
            impl_compare!(self, rhs, eq_missing)
        }
    }
}

/// Create a boolean mask by checking for equality.
Expand Down
14 changes: 14 additions & 0 deletions polars/polars-io/src/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,20 @@ where
fn finish(self, df: &DataFrame) -> Result<()> {
let mut writer = self.writer_builder.from_writer(self.buffer);

// Temporarily cast categorical columns to Utf8 until the csv writer supports them.
let columns = df
.get_columns()
.iter()
.map(|s| {
if let DataType::Categorical = s.dtype() {
s.cast(&DataType::Utf8).unwrap()
} else {
s.clone()
}
})
.collect();
let df = DataFrame::new_no_checks(columns);

let iter = df.iter_record_batches();
write::write_header(&mut writer, &df.schema().to_arrow())?;
for batch in iter {
Expand Down
10 changes: 1 addition & 9 deletions polars/polars-lazy/src/physical_plan/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use crate::prelude::*;
use polars_core::frame::groupby::GroupTuples;
use polars_core::prelude::*;
use rayon::prelude::*;
use std::borrow::Cow;
use std::sync::Arc;

pub struct ApplyExpr {
Expand Down Expand Up @@ -44,14 +43,7 @@ impl PhysicalExpr for ApplyExpr {
"function with multiple inputs not yet supported in aggregation context".into(),
));
}
let mut ac = if let Ok(ae) = self.inputs[0].as_agg_expr() {
AggregationContext::new(
ae.aggregate(df, groups, state)?.unwrap(),
Cow::Borrowed(groups),
)
} else {
self.inputs[0].evaluate_on_groups(df, groups, state)?
};
let mut ac = self.inputs[0].evaluate_on_groups(df, groups, state)?;

match self.collect_groups {
ApplyOptions::ApplyGroups => {
Expand Down
12 changes: 10 additions & 2 deletions polars/polars-lazy/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,16 @@ impl DefaultPlanner {

for agg in &aggs {
// make sure that the expr tree contains no sort-by, filter, binary or function expression
let matches =
|e: &AExpr| matches!(e, AExpr::SortBy { .. } | AExpr::Filter { .. });
let matches = |e: &AExpr| {
matches!(
e,
AExpr::SortBy { .. }
| AExpr::Filter { .. }
| AExpr::BinaryExpr { .. }
| AExpr::BinaryFunction { .. }
| AExpr::Function { .. }
)
};
if aexpr_to_root_nodes(*agg, expr_arena).len() != 1
|| has_aexpr(*agg, expr_arena, matches)
{
Expand Down
20 changes: 20 additions & 0 deletions polars/polars-lazy/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1899,3 +1899,23 @@ fn test_fill_nan() -> Result<()> {

Ok(())
}

#[test]
fn test_agg_exprs() -> Result<()> {
    // Binary expression -> mapped function -> list aggregation, all evaluated
    // inside a groupby context. Primarily checks that the query runs.
    let doubled_diff = (lit(1) - col("A"))
        .map(|s| Ok(&s * 2), GetOutput::same_type())
        .list()
        .alias("foo");

    let grouped = fruits_cars()
        .lazy()
        .stable_groupby([col("cars")])
        .agg([doubled_diff])
        .collect()?;

    // Each group's list length should match the expected group sizes.
    let lengths = grouped.column("foo")?.list()?.lst_lengths();

    assert_eq!(Vec::from(&lengths), &[Some(4), Some(1)]);
    Ok(())
}
6 changes: 5 additions & 1 deletion py-polars/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ def df() -> pl.DataFrame:
}
)
return df.with_columns(
[pl.col("date").cast(pl.Date), pl.col("datetime").cast(pl.Datetime)]
[
pl.col("date").cast(pl.Date),
pl.col("datetime").cast(pl.Datetime),
pl.col("strings").cast(pl.Categorical).alias("cat"),
]
)


Expand Down
2 changes: 2 additions & 0 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,8 @@ def test_concat():


def test_to_pandas(df):
# pyarrow cannot deal with unsigned dictionary integer yet.
df = df.drop("cat")
df.to_arrow()
df.to_pandas()
# test shifted df
Expand Down
1 change: 1 addition & 0 deletions py-polars/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

def test_to_from_buffer(df):
df = df.drop("strings_nulls")
df

for to_fn, from_fn in zip(
[df.to_parquet, df.to_csv], [pl.read_parquet, pl.read_csv]
Expand Down
5 changes: 5 additions & 0 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,8 @@ def test_from_pydatetime():
assert s.name == "name"
assert s.null_count() == 1
assert s.dt[0] == dates[0]
# fmt dates and nulls
print(s)

dates = [date(2021, 1, 1), date(2021, 1, 2), date(2021, 1, 3), None]
s = pl.Series("name", dates)
Expand All @@ -466,6 +468,9 @@ def test_from_pydatetime():
assert s.null_count() == 1
assert s.dt[0] == dates[0]

# fmt dates and nulls
print(s)


def test_from_pandas_nan_to_none():
from pyarrow import ArrowInvalid
Expand Down

0 comments on commit 672fe28

Please sign in to comment.