Skip to content

Commit

Permalink
fix consistency and ergonomics of take expression
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Nov 10, 2021
1 parent 96fe344 commit 41f419c
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 16 deletions.
12 changes: 6 additions & 6 deletions polars/polars-core/src/chunked_array/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ where
lhs.comparison(&rhs, |x, y| {
comparison::compare(x, y, comparison::Operator::Eq)
})
.expect("should not fail.")
.unwrap()
}
}
}
Expand Down Expand Up @@ -120,7 +120,7 @@ where
lhs.comparison(&rhs, |x, y| {
comparison::compare(x, y, comparison::Operator::Neq)
})
.expect("should not fail.")
.unwrap()
}
}
}
Expand Down Expand Up @@ -148,7 +148,7 @@ where
lhs.comparison(&rhs, |x, y| {
comparison::compare(x, y, comparison::Operator::Gt)
})
.expect("should not fail.")
.unwrap()
}
}
}
Expand Down Expand Up @@ -176,7 +176,7 @@ where
lhs.comparison(&rhs, |x, y| {
comparison::compare(x, y, comparison::Operator::GtEq)
})
.expect("should not fail.")
.unwrap()
}
}
}
Expand Down Expand Up @@ -204,7 +204,7 @@ where
lhs.comparison(&rhs, |x, y| {
comparison::compare(x, y, comparison::Operator::Lt)
})
.expect("should not fail.")
.unwrap()
}
}
}
Expand Down Expand Up @@ -232,7 +232,7 @@ where
lhs.comparison(&rhs, |x, y| {
comparison::compare(x, y, comparison::Operator::LtEq)
})
.expect("should not fail.")
.unwrap()
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions polars/polars-core/src/chunked_array/list/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ impl ListChunked {

/// Apply a closure `F` elementwise.
#[cfg(feature = "private")]
pub fn apply_amortized<'a, F>(&'a self, f: F) -> Self
pub fn apply_amortized<'a, F>(&'a self, mut f: F) -> Self
where
F: Fn(UnsafeSeries<'a>) -> Series + Copy,
F: FnMut(UnsafeSeries<'a>) -> Series,
{
if self.is_empty() {
return self.clone();
Expand All @@ -128,9 +128,9 @@ impl ListChunked {
ca
}

pub fn try_apply_amortized<'a, F>(&'a self, f: F) -> Result<Self>
pub fn try_apply_amortized<'a, F>(&'a self, mut f: F) -> Result<Self>
where
F: Fn(UnsafeSeries<'a>) -> Result<Series> + Copy,
F: FnMut(UnsafeSeries<'a>) -> Result<Series>,
{
if self.is_empty() {
return Ok(self.clone());
Expand Down
32 changes: 32 additions & 0 deletions polars/polars-core/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -881,3 +881,35 @@ pub(crate) fn index_to_chunked_index<
/// Copy the first `dst.len()` elements of `src` into `dst` without bounds checks.
///
/// The number of elements copied is taken from `dst`, not from `src`.
///
/// # Safety
///
/// The caller must guarantee that `src.len() >= dst.len()`; otherwise this reads
/// past the end of `src`. The copy is bitwise: previous contents of `dst` are
/// overwritten without being dropped and the elements of `src` are duplicated,
/// so `T` should be a type for which that is sound (e.g. plain `Copy` data).
pub(crate) unsafe fn copy_from_slice_unchecked<T>(src: &[T], dst: &mut [T]) {
    // Catch contract violations in debug/test builds; release builds stay unchecked.
    debug_assert!(
        src.len() >= dst.len(),
        "copy_from_slice_unchecked: src is shorter than dst"
    );
    // SAFETY: `src` and `dst` cannot overlap because `dst` is an exclusive borrow,
    // and the caller guarantees `src` holds at least `dst.len()` elements.
    std::ptr::copy_nonoverlapping(src.as_ptr(), dst.as_mut_ptr(), dst.len());
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::prelude::*;

    /// After `align_chunks_binary`, both arrays must report identical chunk ids.
    #[test]
    fn test_align_chunks() {
        // Four values built in one go vs. one value with three more appended.
        let whole = Int32Chunked::new_from_slice("", &[1, 2, 3, 4]);
        let mut pieces = Int32Chunked::new_from_slice("", &[1]);
        let tail = Int32Chunked::new_from_slice("", &[2, 3, 4]);
        pieces.append(&tail);

        let (lhs, rhs) = align_chunks_binary(&whole, &pieces);
        let lhs_ids = lhs.chunk_id().collect::<Vec<_>>();
        let rhs_ids = rhs.chunk_id().collect::<Vec<_>>();
        assert_eq!(lhs_ids, rhs_ids);

        // Four values built in one go vs. a single value with itself appended three times.
        let whole = Int32Chunked::new_from_slice("", &[1, 2, 3, 4]);
        let mut pieces = Int32Chunked::new_from_slice("", &[1]);
        let unit = pieces.clone();
        for _ in 0..3 {
            pieces.append(&unit);
        }

        let (lhs, rhs) = align_chunks_binary(&whole, &pieces);
        let lhs_ids = lhs.chunk_id().collect::<Vec<_>>();
        let rhs_ids = rhs.chunk_id().collect::<Vec<_>>();
        assert_eq!(lhs_ids, rhs_ids);
    }
}
3 changes: 1 addition & 2 deletions polars/polars-lazy/src/physical_plan/executors/groupby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@ fn groupby_helper(
let agg_columns = agg_columns?;

columns.extend(agg_columns.into_iter().flatten());
let df = DataFrame::new_no_checks(columns);
Ok(df)
DataFrame::new(columns)
}

impl Executor for GroupByExec {
Expand Down
73 changes: 71 additions & 2 deletions polars/polars-lazy/src/physical_plan/expressions/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub struct TakeExpr {

impl TakeExpr {
fn finish(&self, df: &DataFrame, state: &ExecutionState, series: Series) -> Result<Series> {
let idx = self.idx.evaluate(df, state)?;
let idx = self.idx.evaluate(df, state)?.cast(&DataType::UInt32)?;
let idx_ca = idx.u32()?;

series.take(idx_ca)
Expand All @@ -37,7 +37,7 @@ impl PhysicalExpr for TakeExpr {
state: &ExecutionState,
) -> Result<AggregationContext<'a>> {
let mut ac = self.phys_expr.evaluate_on_groups(df, groups, state)?;
let idx = self.idx.evaluate(df, state)?;
let idx = self.idx.evaluate(df, state)?.cast(&DataType::UInt32)?;
let idx_ca = idx.u32()?;

let taken = ac
Expand All @@ -54,4 +54,73 @@ impl PhysicalExpr for TakeExpr {
/// Resolve the output field for `input_schema`.
///
/// Delegates unchanged to the wrapped `phys_expr`; the index expression
/// (`self.idx`) is not consulted here.
fn to_field(&self, input_schema: &Schema) -> Result<Field> {
self.phys_expr.to_field(input_schema)
}

/// Expose this expression as a `PhysicalAggregation`.
///
/// `TakeExpr` implements `PhysicalAggregation` itself, so this always
/// succeeds and returns `self`.
fn as_agg_expr(&self) -> Result<&dyn PhysicalAggregation> {
Ok(self)
}
}

impl PhysicalAggregation for TakeExpr {
    /// Aggregate by taking the indices produced by `self.idx` from every group of
    /// `self.phys_expr`.
    ///
    /// Both expressions are evaluated on `groups`. If the evaluated indices form a
    /// list column, each group is taken with the index list at the same position;
    /// otherwise the same indices are applied to every group.
    ///
    /// If every take produced exactly one value, the list result is exploded into a
    /// flat column; otherwise the list column itself is returned. The result is
    /// renamed to the name of the evaluated input series.
    ///
    /// # Errors
    ///
    /// Propagates errors from evaluating either expression, from casting the indices
    /// to `UInt32`, from accessing the aggregated state as a list, and from the take
    /// operation itself.
    fn aggregate(
        &self,
        df: &DataFrame,
        groups: &GroupTuples,
        state: &ExecutionState,
    ) -> Result<Option<Series>> {
        let mut ac = self.phys_expr.evaluate_on_groups(df, groups, state)?;
        let idx = self.idx.evaluate_on_groups(df, groups, state)?;
        let idx = idx.series();

        // Flipped to `false` as soon as a take yields anything other than one value.
        let mut all_unit_length = true;
        let mut taken = if let Ok(idx) = idx.list() {
            // cast the indices up front.
            let idx = idx.cast(&DataType::List(Box::new(DataType::UInt32)))?;
            let idx = idx.list()?;

            let ca: ListChunked = ac
                .aggregated()
                .list()?
                .into_iter()
                .zip(idx.into_iter())
                .map(|(opt_s, opt_idx)| {
                    if let (Some(s), Some(idx)) = (opt_s, opt_idx) {
                        let idx = idx.u32()?;
                        let s = s.take(idx)?;
                        if s.len() != 1 {
                            all_unit_length = false;
                        }
                        Ok(Some(s))
                    } else {
                        // A missing group or missing index list yields a null entry.
                        Ok(None)
                    }
                })
                .collect::<Result<_>>()?;
            ca
        } else {
            let idx = idx.cast(&DataType::UInt32)?;
            let idx_ca = idx.u32()?;

            // Use `?` like the branch above instead of panicking when the aggregated
            // state cannot be viewed as a list.
            ac.aggregated().list()?.try_apply_amortized(|s| {
                let s = s.as_ref().take(idx_ca)?;
                if s.len() != 1 {
                    all_unit_length = false;
                }
                Ok(s)
            })?
        };

        taken.rename(ac.series().name());

        let out = if all_unit_length {
            taken.explode()?
        } else {
            taken.into_series()
        };
        Ok(Some(out))
    }
}
53 changes: 51 additions & 2 deletions polars/polars-lazy/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use polars_core::{df, prelude::*};

use crate::logical_plan::optimizer::simplify_expr::SimplifyExprRule;
use crate::prelude::*;
use itertools::assert_equal;
use polars_core::chunked_array::builder::get_list_builder;
use std::iter::FromIterator;

Expand Down Expand Up @@ -1601,11 +1602,13 @@ fn test_take_in_groups() -> Result<()> {
.select([col("B")
.take(lit(Series::new("", &[0u32])))
.over([col("fruits")])
.explode()
.alias("taken")])
.collect()?;

assert_eq!(Vec::from(out.column("taken")?.i32()?), &[Some(3), Some(5)]);
assert_eq!(
Vec::from(out.column("taken")?.i32()?),
&[Some(3), Some(3), Some(5), Some(5), Some(5)]
);
Ok(())
}

Expand Down Expand Up @@ -2064,3 +2067,49 @@ fn test_drop_and_select() -> Result<()> {
assert_eq!(out.get_column_names(), &["category"]);
Ok(())
}

/// `take` must give matching results in a plain selection and in a groupby
/// aggregation, including when the take indices are themselves an expression.
#[test]
fn test_take_consistency() -> Result<()> {
    let df = fruits_cars();

    // The expression under test: first element of the reversed arg-sort of "A".
    let first_arg_sorted = || col("A").arg_sort(true).take(lit(0));

    // Selection context.
    let selected = df
        .clone()
        .lazy()
        .select([first_arg_sorted()])
        .collect()?;
    assert_eq!(selected.column("A")?.get(0), AnyValue::UInt32(4));

    // Aggregation context: one index per "cars" group.
    let grouped = df
        .clone()
        .lazy()
        .stable_groupby([col("cars")])
        .agg([first_arg_sorted()])
        .collect()?;
    let per_group_idx = grouped.column("A")?.u32()?;
    assert_eq!(Vec::from(per_group_idx), &[Some(3), Some(0)]);

    // Aggregation context where the take indices come from another expression.
    let combined = df
        .clone()
        .lazy()
        .stable_groupby([col("cars")])
        .agg([
            col("A"),
            first_arg_sorted().alias("1"),
            col("A").take(first_arg_sorted()).alias("2"),
        ])
        .collect()?;

    let taken_values = combined.column("2")?.i32()?;
    assert_eq!(Vec::from(taken_values), &[Some(5), Some(2)]);

    let taken_idx = combined.column("1")?.u32()?;
    assert_eq!(Vec::from(taken_idx), &[Some(3), Some(0)]);

    Ok(())
}

0 comments on commit 41f419c

Please sign in to comment.