Skip to content

Commit

Permalink
allow complex expr expansion in filter
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 28, 2021
1 parent 48ea911 commit c6611bc
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 27 deletions.
23 changes: 9 additions & 14 deletions polars/polars-core/src/chunked_array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,26 +220,21 @@ impl<T> ChunkedArray<T> {
pub(crate) unsafe fn unpack_series_matching_physical_type(
&self,
series: &Series,
) -> Result<&ChunkedArray<T>> {
) -> &ChunkedArray<T> {
let series_trait = &**series;
if self.dtype() == series.dtype() {
let ca = &*(series_trait as *const dyn SeriesTrait as *const ChunkedArray<T>);
Ok(ca)
&*(series_trait as *const dyn SeriesTrait as *const ChunkedArray<T>)
} else {
use DataType::*;
match (self.dtype(), series.dtype()) {
(Int64, Datetime) | (Int32, Date) => {
let ca = &*(series_trait as *const dyn SeriesTrait as *const ChunkedArray<T>);
Ok(ca)
&*(series_trait as *const dyn SeriesTrait as *const ChunkedArray<T>)
}
_ => Err(PolarsError::DataTypeMisMatch(
format!(
"cannot unpack series {:?} into matching type {:?}",
series,
self.dtype()
)
.into(),
)),
_ => panic!(
"cannot unpack series {:?} into matching type {:?}",
series,
self.dtype()
),
}
}
}
Expand All @@ -249,7 +244,7 @@ impl<T> ChunkedArray<T> {
if self.dtype() == series.dtype() {
// Safety
// dtype will be correct.
unsafe { self.unpack_series_matching_physical_type(series) }
Ok(unsafe { self.unpack_series_matching_physical_type(series) })
} else {
Err(PolarsError::DataTypeMisMatch(
format!(
Expand Down
33 changes: 25 additions & 8 deletions polars/polars-core/src/series/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,27 +64,35 @@ where
// we now only create the potentially wrong dtype for a short time.
// Note that the physical type correctness is checked!
// The ChunkedArray with the wrong dtype is dropped after this operation
let rhs = unsafe { self.unpack_series_matching_physical_type(rhs)? };
let rhs = unsafe { self.unpack_series_matching_physical_type(rhs) };
let out = self - rhs;
Ok(out.into_series())
}
fn add_to(&self, rhs: &Series) -> Result<Series> {
let rhs = unsafe { self.unpack_series_matching_physical_type(rhs)? };
// Safety:
// see subtract
let rhs = unsafe { self.unpack_series_matching_physical_type(rhs) };
let out = self + rhs;
Ok(out.into_series())
}
fn multiply(&self, rhs: &Series) -> Result<Series> {
let rhs = unsafe { self.unpack_series_matching_physical_type(rhs)? };
// Safety:
// see subtract
let rhs = unsafe { self.unpack_series_matching_physical_type(rhs) };
let out = self * rhs;
Ok(out.into_series())
}
fn divide(&self, rhs: &Series) -> Result<Series> {
let rhs = unsafe { self.unpack_series_matching_physical_type(rhs)? };
// Safety:
// see subtract
let rhs = unsafe { self.unpack_series_matching_physical_type(rhs) };
let out = self / rhs;
Ok(out.into_series())
}
fn remainder(&self, rhs: &Series) -> Result<Series> {
let rhs = unsafe { self.unpack_series_matching_physical_type(rhs)? };
// Safety:
// see subtract
let rhs = unsafe { self.unpack_series_matching_physical_type(rhs) };
let out = self % rhs;
Ok(out.into_series())
}
Expand Down Expand Up @@ -134,7 +142,12 @@ pub mod checked {
ChunkedArray<T>: IntoSeries,
{
fn checked_div(&self, rhs: &Series) -> Result<Series> {
let rhs = unsafe { self.unpack_series_matching_physical_type(rhs)? };
// Safety:
// There will be UB if a ChunkedArray is alive with the wrong datatype.
// we now only create the potentially wrong dtype for a short time.
// Note that the physical type correctness is checked!
// The ChunkedArray with the wrong dtype is dropped after this operation
let rhs = unsafe { self.unpack_series_matching_physical_type(rhs) };
let (l, r) = align_chunks_binary(self, rhs);

Ok((l)
Expand All @@ -159,7 +172,9 @@ pub mod checked {

impl NumOpsDispatchChecked for Float32Chunked {
fn checked_div(&self, rhs: &Series) -> Result<Series> {
let rhs = unsafe { self.unpack_series_matching_physical_type(rhs)? };
// Safety:
// see check_div for chunkedarray<T>
let rhs = unsafe { self.unpack_series_matching_physical_type(rhs) };
let (l, r) = align_chunks_binary(self, rhs);

Ok((l)
Expand Down Expand Up @@ -190,7 +205,9 @@ pub mod checked {

impl NumOpsDispatchChecked for Float64Chunked {
fn checked_div(&self, rhs: &Series) -> Result<Series> {
let rhs = unsafe { self.unpack_series_matching_physical_type(rhs)? };
// Safety:
// see check_div
let rhs = unsafe { self.unpack_series_matching_physical_type(rhs) };
let (l, r) = align_chunks_binary(self, rhs);

Ok((l)
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1780,7 +1780,7 @@ where
let mut acc = series.pop().unwrap();

for s in series {
acc = f(acc, s)?
acc = f(acc, s)?;
}
Ok(acc)
}) as Arc<dyn SeriesUdf>);
Expand Down
6 changes: 2 additions & 4 deletions polars/polars-lazy/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -909,10 +909,8 @@ impl LogicalPlanBuilder {
/// Apply a filter
pub fn filter(self, predicate: Expr) -> Self {
let predicate = if has_expr(&predicate, |e| matches!(e, Expr::Wildcard)) {
let it = self.0.schema().fields().iter().map(|field| {
replace_wildcard_with_column(predicate.clone(), Arc::new(field.name().clone()))
});
combine_predicates_expr(it)
let rewritten = rewrite_projections(vec![predicate], self.0.schema());
combine_predicates_expr(rewritten.into_iter())
} else {
predicate
};
Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,3 +1169,15 @@ def test_with_row_count():

out = df.lazy().with_row_count().collect()
assert out["row_nr"].to_list() == [0, 1, 2]


def test_filter_with_all_expansion():
df = pl.DataFrame(
{
"b": [1, 2, None],
"c": [1, 2, None],
"a": [None, None, None],
}
)
out = df.filter(~pl.fold(True, lambda acc, s: acc & s.is_null(), pl.all()))
assert out.shape == (2, 3)

0 comments on commit c6611bc

Please sign in to comment.