Skip to content

Commit

Permalink
fix hmin/hmax of exprs
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jan 23, 2022
1 parent 7acc62f commit 1fe5e2e
Show file tree
Hide file tree
Showing 7 changed files with 353 additions and 278 deletions.
277 changes: 21 additions & 256 deletions polars/polars-lazy/src/dsl.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
//! Domain specific language for the Lazy api.
use crate::logical_plan::Context;
use crate::prelude::*;
use crate::utils::{has_expr, has_root_literal_expr, has_wildcard};
use crate::utils::{has_expr, has_root_literal_expr};
use polars_arrow::prelude::QuantileInterpolOptions;
use polars_core::export::arrow::{array::BooleanArray, bitmap::MutableBitmap};
use polars_core::prelude::*;

use std::fmt::{Debug, Formatter};
use std::ops::{BitAnd, BitOr, Deref};
use std::ops::Deref;
use std::{
fmt,
ops::{Add, Div, Mul, Rem, Sub},
sync::Arc,
};
// reexport the lazy method
pub use crate::frame::IntoLazy;
pub use crate::functions::*;
pub use crate::logical_plan::lit;

use polars_arrow::array::default_arrays::FromData;
#[cfg(feature = "diff")]
use polars_core::series::ops::NullBehavior;
Expand Down Expand Up @@ -156,6 +158,16 @@ impl GetOutput {
}))
}

/// Output type resolver that yields the common supertype of all input dtypes.
///
/// The dtypes are combined left to right with `get_supertype`, starting from
/// the first input dtype.
///
/// # Panics
///
/// The produced mapping panics when it is called with no input dtypes, or
/// when `get_supertype` returns an error for a pair of dtypes.
pub fn super_type() -> Self {
    Self::map_dtypes(|dtypes| {
        // Seed with the first dtype, then widen with every following one.
        let seed = dtypes[0].clone();
        dtypes[1..]
            .iter()
            .fold(seed, |acc, dt| get_supertype(&acc, dt).unwrap())
    })
}

pub fn map_dtypes<F>(f: F) -> Self
where
F: 'static + Fn(&[&DataType]) -> DataType + Send + Sync,
Expand Down Expand Up @@ -366,6 +378,12 @@ pub enum Expr {
},
}

impl Default for Expr {
    /// The default expression is a null literal.
    fn default() -> Self {
        Self::Literal(LiteralValue::Null)
    }
}

#[derive(Debug, Clone, PartialEq)]
pub enum Excluded {
Name(Arc<str>),
Expand Down Expand Up @@ -1306,7 +1324,7 @@ impl Expr {
}
},
&[fill_value],
GetOutput::map_dtypes(|dtypes| get_supertype(dtypes[0], dtypes[1]).unwrap()),
GetOutput::super_type(),
)
.with_fmt("fill_null")
}
Expand Down Expand Up @@ -2075,259 +2093,6 @@ impl Expr {
}
}

/// Build a column expression from a column name.
///
/// The name `"*"` is special-cased and produces [`Expr::Wildcard`]; every
/// other name is stored verbatim in an [`Expr::Column`].
///
/// # Arguments
///
/// * `name` - the name of the column to select
///
/// # Examples
///
/// ```ignore
/// // select a single column by name
/// col("foo")
/// ```
///
/// ```ignore
/// // select all columns with the wildcard
/// col("*")
/// ```
///
/// ```ignore
/// // select columns with a regular expression that starts with `^` and ends with `$`
/// // (only if the regex feature is activated)
/// col("^foo.*$")
/// ```
pub fn col(name: &str) -> Expr {
    if name == "*" {
        Expr::Wildcard
    } else {
        Expr::Column(Arc::from(name))
    }
}

/// Build an expression that selects several columns by their names.
///
/// The names are stored as given in an [`Expr::Columns`].
pub fn cols(names: Vec<String>) -> Expr {
    let selection = Expr::Columns(names);
    selection
}

/// Build an expression that selects every column of the given dtype.
///
/// The dtype is cloned into a one-element list held by [`Expr::DtypeColumn`].
pub fn dtype_col(dtype: &DataType) -> Expr {
    let dtypes = vec![dtype.clone()];
    Expr::DtypeColumn(dtypes)
}

/// Build an expression that selects every column whose dtype is one of `dtype`.
///
/// The dtypes are copied into an owned list held by [`Expr::DtypeColumn`].
pub fn dtype_cols<DT: AsRef<[DataType]>>(dtype: DT) -> Expr {
    Expr::DtypeColumn(dtype.as_ref().to_vec())
}

/// Count the number of values of the column called `name`.
///
/// When `name` is the empty string the resulting expression is aliased to
/// `"count"`; otherwise no alias is applied.
pub fn count(name: &str) -> Expr {
    let counted = col(name).count();
    if name.is_empty() {
        counted.alias("count")
    } else {
        counted
    }
}

/// Sum all the values of the column called `name`.
///
/// Shorthand for `col(name).sum()`.
pub fn sum(name: &str) -> Expr {
    let column = col(name);
    column.sum()
}

/// Find the minimum of all the values of the column called `name`.
///
/// Shorthand for `col(name).min()`.
pub fn min(name: &str) -> Expr {
    let column = col(name);
    column.min()
}

/// Find the maximum of all the values of the column called `name`.
///
/// Shorthand for `col(name).max()`.
pub fn max(name: &str) -> Expr {
    let column = col(name);
    column.max()
}

/// Find the mean of all the values of the column called `name`.
///
/// Shorthand for `col(name).mean()`.
pub fn mean(name: &str) -> Expr {
    let column = col(name);
    column.mean()
}

/// Find the average (mean) of all the values of the column called `name`.
///
/// Builds the same expression as `mean`: `col(name).mean()`.
pub fn avg(name: &str) -> Expr {
    let column = col(name);
    column.mean()
}

/// Find the median of all the values of the column called `name`.
///
/// Shorthand for `col(name).median()`.
pub fn median(name: &str) -> Expr {
    let column = col(name);
    column.median()
}

/// Find a specific quantile of all the values of the column called `name`.
///
/// `quantile` and `interpol` are forwarded unchanged to `Expr::quantile`.
pub fn quantile(name: &str, quantile: f64, interpol: QuantileInterpolOptions) -> Expr {
    let column = col(name);
    column.quantile(quantile, interpol)
}

// Adapt a binary function `$f(Series, Series)` to a closure over a mutable
// slice of `Series`.
//
// `$f` must be an identifier that is in scope at the expansion site. The
// produced closure is `move`, so it takes ownership of whatever `$f` names.
// It reads elements 0 and 1 of the slice and therefore panics on a slice
// with fewer than two elements.
macro_rules! prepare_binary_function {
($f:ident) => {
move |s: &mut [Series]| {
// `mem::take` moves each input out of the slice (leaving a default value
// behind) so `$f` receives owned `Series`.
let s0 = std::mem::take(&mut s[0]);
let s1 = std::mem::take(&mut s[1]);

$f(s0, s1)
}
};
}

/// Apply a closure on the two columns that are evaluated from `Expr` a and `Expr` b.
pub fn map_binary<F: 'static>(a: Expr, b: Expr, f: F, output_type: GetOutput) -> Expr
where
F: Fn(Series, Series) -> Result<Series> + Send + Sync,
{
let function = prepare_binary_function!(f);
a.map_many(function, &[b], output_type)
}

pub fn apply_binary<F: 'static>(a: Expr, b: Expr, f: F, output_type: GetOutput) -> Expr
where
F: Fn(Series, Series) -> Result<Series> + Send + Sync,
{
let function = prepare_binary_function!(f);
a.apply_many(function, &[b], output_type)
}

/// Accumulate over multiple columns horizontally / row wise.
///
/// `acc` is the initial accumulator and `f` combines the accumulator with one
/// input at a time.
///
/// * If none of `exprs` contains a wildcard, the result is a chain of
///   `map_binary` calls, one per expression, whose output dtype is the
///   supertype of the two operands.
/// * Otherwise a single `Expr::Function` with wildcard expansion is built.
///   `acc` is appended as the last input and is popped off again inside the
///   function before the remaining inputs are folded in order.
pub fn fold_exprs<F: 'static, E: AsRef<[Expr]>>(acc: Expr, f: F, exprs: E) -> Expr
where
    F: Fn(Series, Series) -> Result<Series> + Send + Sync + Clone,
{
    let mut inputs = exprs.as_ref().to_vec();

    if !inputs.iter().any(has_wildcard) {
        // No wildcard: nest one binary map per expression.
        return inputs.into_iter().fold(acc, |folded, e| {
            map_binary(
                folded,
                e,
                f.clone(),
                GetOutput::map_dtypes(|dt| get_supertype(dt[0], dt[1]).unwrap()),
            )
        });
    }

    // The accumulator travels as the last input of the function.
    inputs.push(acc);

    let function = NoEq::new(Arc::new(move |series: &mut [Series]| {
        let mut columns = series.to_vec();
        let seed = columns.pop().unwrap();

        columns.into_iter().try_fold(seed, |folded, s| f(folded, s))
    }) as Arc<dyn SeriesUdf>);

    // Todo! make sure that output type is correct
    Expr::Function {
        input: inputs,
        function,
        output_type: GetOutput::same_type(),
        options: FunctionOptions {
            collect_groups: ApplyOptions::ApplyFlat,
            input_wildcard_expansion: true,
            auto_explode: true,
            fmt_str: "",
        },
    }
}

/// Get the sum of the values per row.
///
/// Folds `exprs` with `+`, starting from the literal `0`.
pub fn sum_exprs<E: AsRef<[Expr]>>(exprs: E) -> Expr {
    let inputs = exprs.as_ref().to_vec();
    fold_exprs(lit(0), |lhs, rhs| Ok(&lhs + &rhs), inputs)
}

/// Get the the maximum value per row
pub fn max_exprs<E: AsRef<[Expr]>>(exprs: E) -> Expr {
let exprs = exprs.as_ref().to_vec();
let func = |s1: Series, s2: Series| {
let mask = s1.gt(&s2);
s1.zip_with(&mask, &s2)
};
fold_exprs(lit(0), func, exprs)
}

/// Get the the minimum value per row
pub fn min_exprs<E: AsRef<[Expr]>>(exprs: E) -> Expr {
let exprs = exprs.as_ref().to_vec();
let func = |s1: Series, s2: Series| {
let mask = s1.lt(&s2);
s1.zip_with(&mask, &s2)
};
fold_exprs(lit(0), func, exprs)
}

/// Combine all the expressions with a bitwise or.
///
/// Folds `exprs` starting from the literal `false`. Each operand is accessed
/// through `Series::bool`, whose error is propagated.
pub fn any_exprs<E: AsRef<[Expr]>>(exprs: E) -> Expr {
    let inputs = exprs.as_ref().to_vec();
    fold_exprs(
        lit(false),
        |lhs: Series, rhs: Series| Ok(lhs.bool()?.bitor(rhs.bool()?).into_series()),
        inputs,
    )
}

/// Combine all the expressions with a bitwise and.
///
/// Folds `exprs` starting from the literal `true`. Each operand is accessed
/// through `Series::bool`, whose error is propagated.
pub fn all_exprs<E: AsRef<[Expr]>>(exprs: E) -> Expr {
    let inputs = exprs.as_ref().to_vec();
    fold_exprs(
        lit(true),
        |lhs: Series, rhs: Series| Ok(lhs.bool()?.bitand(rhs.bool()?).into_series()),
        inputs,
    )
}

/// Wrap `expr` in a [Not](Expr::Not) expression.
pub fn not(expr: Expr) -> Expr {
    let inner = Box::new(expr);
    Expr::Not(inner)
}

/// Wrap `expr` in an [IsNull](Expr::IsNull) expression.
pub fn is_null(expr: Expr) -> Expr {
    let inner = Box::new(expr);
    Expr::IsNull(inner)
}

/// Wrap `expr` in an [IsNotNull](Expr::IsNotNull) expression.
pub fn is_not_null(expr: Expr) -> Expr {
    let inner = Box::new(expr);
    Expr::IsNotNull(inner)
}

/// Wrap `expr` in a [Cast](Expr::Cast) expression to `data_type`.
///
/// The cast is built with `strict: false`.
pub fn cast(expr: Expr, data_type: DataType) -> Expr {
    let expr = Box::new(expr);
    Expr::Cast {
        expr,
        data_type,
        strict: false,
    }
}

/// Conversion of a pair of bounds into a range expression.
///
/// `self` is the lower bound of the range.
pub trait Range<T> {
/// Build the range expression running from `self` to `high`.
fn into_range(self, high: T) -> Expr;
}

// Implement `Range<$dt>` for the primitive type `$dt`.
//
// Both bounds are converted to `i64` with `as` and stored in a
// `LiteralValue::Range` literal.
macro_rules! impl_into_range {
($dt: ty) => {
impl Range<$dt> for $dt {
fn into_range(self, high: $dt) -> Expr {
Expr::Literal(LiteralValue::Range {
low: self as i64,
high: high as i64,
// NOTE(review): the dtype tag is `Int32` for every `$dt`, including
// `i64` and `u32` — confirm this is intended for bounds outside i32.
data_type: DataType::Int32,
})
}
}
};
}

// Integer types that can be used as range bounds.
impl_into_range!(i32);
impl_into_range!(i64);
impl_into_range!(u32);

/// Create a range literal running from `low` to `high`.
///
/// Delegates to the `Range` implementation of `T`.
pub fn range<T: Range<T>>(low: T, high: T) -> Expr {
    <T as Range<T>>::into_range(low, high)
}

// Arithmetic ops
impl Add for Expr {
type Output = Expr;
Expand Down

0 comments on commit 1fe5e2e

Please sign in to comment.