Skip to content

Commit

Permalink
setup serializable function + null_count expr (#3247)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Apr 28, 2022
1 parent f35f348 commit 1b4a516
Show file tree
Hide file tree
Showing 17 changed files with 203 additions and 47 deletions.
10 changes: 9 additions & 1 deletion polars/polars-lazy/src/dsl/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use polars_core::utils::get_supertype;
use std::fmt::{Debug, Formatter};
use std::ops::Deref;

use crate::dsl::function_expr::FunctionExpr;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -292,7 +293,7 @@ pub enum Expr {
falsy: Box<Expr>,
},
#[cfg_attr(feature = "serde", serde(skip))]
Function {
AnonymousFunction {
/// function arguments
input: Vec<Expr>,
/// function to apply
Expand All @@ -301,6 +302,13 @@ pub enum Expr {
output_type: GetOutput,
options: FunctionOptions,
},
Function {
/// function arguments
input: Vec<Expr>,
/// function to apply
function: FunctionExpr,
options: FunctionOptions,
},
Shift {
input: Box<Expr>,
periods: i64,
Expand Down
45 changes: 45 additions & 0 deletions polars/polars-lazy/src/dsl/function_expr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use super::*;
use polars_core::prelude::*;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

/// Named (non-closure) functions that can be stored in an expression.
///
/// Unlike an opaque closure, each variant is plain data, so the enum derives
/// `Clone`, `PartialEq` and `Debug`, and additionally `Serialize`/`Deserialize`
/// when the `serde` feature is enabled.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, PartialEq, Debug)]
pub enum FunctionExpr {
/// Count the null values of the first input series.
NullCount,
}

impl FunctionExpr {
pub(crate) fn get_field(
&self,
_input_schema: &Schema,
_cntxt: Context,
fields: &[Field],
) -> Result<Field> {
use FunctionExpr::*;
match self {
NullCount => Ok(Field::new(fields[0].name(), IDX_DTYPE)),
}
}
}

/// Put an expression behind an `Arc` and then inside a `NoEq` wrapper.
///
/// Expands to `NoEq::new(Arc::new(..))`; the concrete target type is decided
/// by the call site.
macro_rules! wrap {
    ($udf:expr) => {
        NoEq::new(Arc::new($udf))
    };
}

impl From<FunctionExpr> for NoEq<Arc<dyn SeriesUdf>> {
fn from(func: FunctionExpr) -> Self {
use FunctionExpr::*;
match func {
NullCount => {
let f = |s: &mut [Series]| {
let s = &s[0];
Ok(Series::new(s.name(), [s.null_count() as IdxSize]))
};
wrap!(f)
}
}
}
}
12 changes: 6 additions & 6 deletions polars/polars-lazy/src/dsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ pub fn argsort_by<E: AsRef<[Expr]>>(by: E, reverse: &[bool]) -> Expr {
polars_core::functions::argsort_by(by, &reverse).map(|ca| ca.into_series())
}) as Arc<dyn SeriesUdf>);

Expr::Function {
Expr::AnonymousFunction {
input: by.as_ref().to_vec(),
function,
output_type: GetOutput::from_type(IDX_DTYPE),
Expand All @@ -187,7 +187,7 @@ pub fn concat_str(s: Vec<Expr>, sep: &str) -> Expr {
let function = NoEq::new(Arc::new(move |s: &mut [Series]| {
polars_core::functions::concat_str(s, &sep).map(|ca| ca.into_series())
}) as Arc<dyn SeriesUdf>);
Expr::Function {
Expr::AnonymousFunction {
input: s,
function,
output_type: GetOutput::from_type(DataType::Utf8),
Expand Down Expand Up @@ -217,7 +217,7 @@ pub fn concat_lst(s: Vec<Expr>) -> Expr {
};
first_ca.lst_concat(other).map(|ca| ca.into_series())
}) as Arc<dyn SeriesUdf>);
Expr::Function {
Expr::AnonymousFunction {
input: s,
function,
output_type: GetOutput::map_dtype(|dt| DataType::List(Box::new(dt.clone()))),
Expand Down Expand Up @@ -395,7 +395,7 @@ pub fn datetime(args: DatetimeArgs) -> Expr {

Ok(ca.into_datetime(TimeUnit::Milliseconds, None).into_series())
}) as Arc<dyn SeriesUdf>);
Expr::Function {
Expr::AnonymousFunction {
input: vec![
year,
month,
Expand Down Expand Up @@ -472,7 +472,7 @@ pub fn duration(args: DurationArgs) -> Expr {
nanoseconds.cast(&DataType::Duration(TimeUnit::Nanoseconds))
}) as Arc<dyn SeriesUdf>);

Expr::Function {
Expr::AnonymousFunction {
input: vec![
args.days.unwrap_or_else(|| lit(0i64)),
args.seconds.unwrap_or_else(|| lit(0i64)),
Expand Down Expand Up @@ -678,7 +678,7 @@ where
}) as Arc<dyn SeriesUdf>);

// Todo! make sure that output type is correct
Expr::Function {
Expr::AnonymousFunction {
input: exprs,
function,
output_type: GetOutput::same_type(),
Expand Down
79 changes: 58 additions & 21 deletions polars/polars-lazy/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub use cat::*;
#[cfg(feature = "temporal")]
mod dt;
mod expr;
pub(crate) mod function_expr;
#[cfg(feature = "compile")]
mod functions;
#[cfg(feature = "list")]
Expand Down Expand Up @@ -38,6 +39,7 @@ pub use expr::*;
pub use functions::*;
pub use options::*;

use crate::dsl::function_expr::FunctionExpr;
use polars_arrow::array::default_arrays::FromData;
#[cfg(feature = "diff")]
use polars_core::series::ops::NullBehavior;
Expand Down Expand Up @@ -186,22 +188,36 @@ impl Expr {
where
F: Fn(FunctionOptions) -> FunctionOptions,
{
if let Self::Function {
input,
function,
output_type,
mut options,
} = self
{
options = func(options);
Self::Function {
match self {
Self::AnonymousFunction {
input,
function,
output_type,
options,
mut options,
} => {
options = func(options);
Self::AnonymousFunction {
input,
function,
output_type,
options,
}
}
Self::Function {
input,
function,
mut options,
} => {
options = func(options);
Self::Function {
input,
function,
options,
}
}
_ => {
panic!("implementation error")
}
} else {
panic!("implementation error")
}
}

Expand Down Expand Up @@ -558,7 +574,7 @@ impl Expr {
{
let f = move |s: &mut [Series]| function(std::mem::take(&mut s[0]));

Expr::Function {
Expr::AnonymousFunction {
input: vec![self],
function: NoEq::new(Arc::new(f)),
output_type,
Expand All @@ -581,7 +597,7 @@ impl Expr {
let mut input = vec![self];
input.extend_from_slice(arguments);

Expr::Function {
Expr::AnonymousFunction {
input,
function: NoEq::new(Arc::new(function)),
output_type,
Expand All @@ -607,7 +623,7 @@ impl Expr {
{
let f = move |s: &mut [Series]| function(std::mem::take(&mut s[0]));

Expr::Function {
Expr::AnonymousFunction {
input: vec![self],
function: NoEq::new(Arc::new(f)),
output_type,
Expand All @@ -632,7 +648,7 @@ impl Expr {
{
let f = move |s: &mut [Series]| function(std::mem::take(&mut s[0]));

Expr::Function {
Expr::AnonymousFunction {
input: vec![self],
function: NoEq::new(Arc::new(f)),
output_type,
Expand All @@ -655,7 +671,7 @@ impl Expr {
{
let f = move |s: &mut [Series]| function(std::mem::take(&mut s[0]));

Expr::Function {
Expr::AnonymousFunction {
input: vec![self],
function: NoEq::new(Arc::new(f)),
output_type,
Expand All @@ -668,6 +684,19 @@ impl Expr {
}
}

/// Build an `Expr::Function` node that applies `function_expr` to `self`.
///
/// `self` becomes the only input. The options are fixed to
/// `ApplyOptions::ApplyGroups`, no wildcard expansion of the input and no
/// auto-explode; `fmt_str` is stored in the options as given.
fn apply_private(self, function_expr: FunctionExpr, fmt_str: &'static str) -> Self {
    let options = FunctionOptions {
        fmt_str,
        auto_explode: false,
        input_wildcard_expansion: false,
        collect_groups: ApplyOptions::ApplyGroups,
    };
    let input = vec![self];

    Expr::Function {
        input,
        function: function_expr,
        options,
    }
}

/// Apply a function/closure over the groups with many arguments. This should only be used in a groupby aggregation.
///
/// See the [`Expr::apply`] function for the differences between [`map`](Expr::map) and [`apply`](Expr::apply).
Expand All @@ -678,7 +707,7 @@ impl Expr {
let mut input = vec![self];
input.extend_from_slice(arguments);

Expr::Function {
Expr::AnonymousFunction {
input,
function: NoEq::new(Arc::new(function)),
output_type,
Expand Down Expand Up @@ -1779,6 +1808,14 @@ impl Expr {
options
})
}
/// Get the null count of the column/group
pub fn null_count(self) -> Expr {
self.apply_private(FunctionExpr::NullCount, "null_count")
.with_function_options(|mut options| {
options.auto_explode = true;
options
})
}

#[cfg(feature = "strings")]
pub fn str(self) -> string::StringNameSpace {
Expand Down Expand Up @@ -1860,7 +1897,7 @@ where
{
let input = expr.as_ref().to_vec();

Expr::Function {
Expr::AnonymousFunction {
input,
function: NoEq::new(Arc::new(function)),
output_type,
Expand All @@ -1887,7 +1924,7 @@ where
{
let input = expr.as_ref().to_vec();

Expr::Function {
Expr::AnonymousFunction {
input,
function: NoEq::new(Arc::new(function)),
output_type,
Expand Down Expand Up @@ -1916,7 +1953,7 @@ where
{
let input = expr.as_ref().to_vec();

Expr::Function {
Expr::AnonymousFunction {
input,
function: NoEq::new(Arc::new(function)),
output_type,
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/dsl/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ impl StringNameSpace {
let function = NoEq::new(Arc::new(move |s: &mut [Series]| {
Ok(s[0].str_concat(&delimiter).into_series())
}) as Arc<dyn SeriesUdf>);
Expr::Function {
Expr::AnonymousFunction {
input: vec![self.0],
function,
output_type: GetOutput::from_type(DataType::Utf8),
Expand Down
21 changes: 19 additions & 2 deletions polars/polars-lazy/src/logical_plan/aexpr.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::dsl::function_expr::FunctionExpr;
use crate::logical_plan::Context;
use crate::prelude::*;
use polars_arrow::prelude::QuantileInterpolOptions;
Expand Down Expand Up @@ -74,12 +75,19 @@ pub enum AExpr {
truthy: Node,
falsy: Node,
},
Function {
AnonymousFunction {
input: Vec<Node>,
function: NoEq<Arc<dyn SeriesUdf>>,
output_type: GetOutput,
options: FunctionOptions,
},
Function {
/// function arguments
input: Vec<Node>,
/// function to apply
function: FunctionExpr,
options: FunctionOptions,
},
Shift {
input: Node,
periods: i64,
Expand Down Expand Up @@ -307,7 +315,7 @@ impl AExpr {
Ok(truthy)
}
}
Function {
AnonymousFunction {
output_type, input, ..
} => {
let fields = input
Expand All @@ -316,6 +324,15 @@ impl AExpr {
.collect::<Result<Vec<_>>>()?;
Ok(output_type.get_field(schema, ctxt, &fields))
}
Function {
function, input, ..
} => {
let fields = input
.iter()
.map(|node| arena.get(*node).to_field(schema, ctxt, arena))
.collect::<Result<Vec<_>>>()?;
function.get_field(schema, ctxt, &fields)
}
Shift { input, .. } => arena.get(*input).to_field(schema, ctxt, arena),
Slice { input, .. } => arena.get(*input).to_field(schema, ctxt, arena),
Wildcard => panic!("should be no wildcard at this point"),
Expand Down

0 comments on commit 1b4a516

Please sign in to comment.