Skip to content

Commit

Permalink
[lazy] with then otherwise implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 9, 2020
1 parent dd1218d commit feabfbe
Show file tree
Hide file tree
Showing 12 changed files with 227 additions and 47 deletions.
65 changes: 42 additions & 23 deletions polars/src/chunked_array/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -905,22 +905,24 @@ pub trait ChunkZip<T> {

// Element-wise selection between two iterables, driven by a boolean mask.
//
// Expands to an expression of type `Result<_>`:
// * `Err(PolarsError::HasNullValues)` when `$mask` reports any null values;
// * otherwise `Ok(..)` of the collection built by taking the element of
//   `$truthy` where the mask is `true` and the element of `$other` where it is
//   `false`. The concrete collection type is inferred from the call site.
//
// The three inputs are combined with `zip`, so iteration stops at the shortest
// one; callers are responsible for passing inputs of equal length.
//
// `$mask` is evaluated twice (null check and iteration), so pass a cheap,
// side-effect free expression such as a reference.
//
// TODO! fast paths.
macro_rules! impl_ternary {
    ($mask:expr, $truthy:expr, $other:expr) => {{
        if $mask.null_count() > 0 {
            Err(PolarsError::HasNullValues)
        } else {
            let val = $mask
                .into_no_null_iter()
                .zip($truthy)
                .zip($other)
                .map(|((mask_val, true_val), false_val)| {
                    if mask_val {
                        true_val
                    } else {
                        false_val
                    }
                })
                .collect();
            Ok(val)
        }
    }};
}

Expand All @@ -929,11 +931,36 @@ where
T: PolarsNumericType,
{
fn zip_with(&self, mask: &BooleanChunked, other: &ChunkedArray<T>) -> Result<ChunkedArray<T>> {
if mask.null_count() > 0 {
Err(PolarsError::HasNullValues)

let self_len = self.len();
let other_len = other.len();
let mask_len = mask.len();

if self_len != mask_len || other_len != mask_len {
match (self_len, other_len) {
(1, 1) => {
let self_ = self.expand_at_index(mask_len, 0);
let other = other.expand_at_index(mask_len, 0);
println!("{:?}", (other.len(), self_.len()));
impl_ternary!(mask, &self_, &other)
}
(_, 1) => {
let other = other.expand_at_index(mask_len, 0);
impl_ternary!(mask, self, &other)
}
(1, _) => {
let self_ = self.expand_at_index(mask_len, 0);
impl_ternary!(mask, &self_, other)
}
(_, _) => {
Err(PolarsError::ShapeMisMatch)
}
}
} else {
impl_ternary!(mask, self, other)
}


}

fn zip_with_series(&self, mask: &BooleanChunked, other: &Series) -> Result<ChunkedArray<T>> {
Expand All @@ -944,11 +971,7 @@ where

impl ChunkZip<BooleanType> for BooleanChunked {
/// Select values from `self` where `mask` is `true` and from `other` where it
/// is `false`.
///
/// Returns `Err(PolarsError::HasNullValues)` when `mask` contains null values;
/// that check is performed inside `impl_ternary!`, so it is not repeated here.
///
/// NOTE(review): no length validation or broadcasting happens in this method;
/// the inputs are zipped, so differing lengths yield a result as long as the
/// shortest input. Confirm whether that is intended.
fn zip_with(&self, mask: &BooleanChunked, other: &BooleanChunked) -> Result<BooleanChunked> {
    impl_ternary!(mask, self, other)
}

fn zip_with_series(
Expand All @@ -963,11 +986,7 @@ impl ChunkZip<BooleanType> for BooleanChunked {

impl ChunkZip<Utf8Type> for Utf8Chunked {
/// Select values from `self` where `mask` is `true` and from `other` where it
/// is `false`.
///
/// Returns `Err(PolarsError::HasNullValues)` when `mask` contains null values;
/// that check is performed inside `impl_ternary!`, so it is not repeated here.
///
/// NOTE(review): no length validation or broadcasting happens in this method;
/// the inputs are zipped, so differing lengths yield a result as long as the
/// shortest input. Confirm whether that is intended.
fn zip_with(&self, mask: &BooleanChunked, other: &Utf8Chunked) -> Result<Utf8Chunked> {
    impl_ternary!(mask, self, other)
}

fn zip_with_series(
Expand Down
64 changes: 58 additions & 6 deletions polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,17 @@ pub enum Expr {
quantile: f64,
},
AggSum(Box<Expr>),
AggGroups(Box<Expr>), // ScalarFunction {
// name: String,
// args: Vec<Expr>,
// return_type: ArrowDataType,
// },
// Wildcard
AggGroups(Box<Expr>),
Ternary {
predicate: Box<Expr>,
truthy: Box<Expr>,
falsy: Box<Expr>,
}, // ScalarFunction {
// name: String,
// args: Vec<Expr>,
// return_type: ArrowDataType,
// },
// Wildcard
}

impl Expr {
Expand Down Expand Up @@ -93,6 +98,7 @@ impl Expr {
AggNUnique(_) => Ok(ArrowDataType::UInt32),
AggQuantile { expr, .. } => expr.get_type(schema),
Cast { data_type, .. } => Ok(data_type.clone()),
Ternary { truthy, .. } => truthy.get_type(schema),
}
}

Expand Down Expand Up @@ -176,6 +182,7 @@ impl Expr {
field.is_nullable(),
))
}
Ternary { truthy, .. } => truthy.to_field(schema),
}
}
}
Expand Down Expand Up @@ -206,6 +213,11 @@ impl fmt::Debug for Expr {
AggGroups(expr) => write!(f, "AGGREGATE GROUPS {:?}", expr),
AggQuantile { expr, .. } => write!(f, "AGGREGATE QUANTILE {:?}", expr),
Cast { expr, data_type } => write!(f, "CAST {:?} TO {:?}", expr, data_type),
Ternary {
predicate,
truthy,
falsy,
} => write!(f, "WHEN {:?} {:?} OTHERWISE {:?}", predicate, truthy, falsy),
}
}
}
Expand Down Expand Up @@ -262,6 +274,46 @@ pub fn binary_expr(l: Expr, op: Operator, r: Expr) -> Expr {
}
}

/// First stage of the `when(..).then(..).otherwise(..)` builder: holds only the
/// predicate expression until a "then" branch is supplied.
pub struct When {
    // Condition expression; becomes the `predicate` of `Expr::Ternary`.
    predicate: Expr,
}

/// Second stage of the `when(..).then(..).otherwise(..)` builder: holds the
/// predicate and the "then" branch until the fallback branch is supplied.
pub struct WhenThen {
    // Condition expression; becomes the `predicate` of `Expr::Ternary`.
    predicate: Expr,
    // Expression for the true branch; becomes `truthy` of `Expr::Ternary`.
    then: Expr,
}

impl When {
pub fn then(self, expr: Expr) -> WhenThen {
WhenThen {
predicate: self.predicate,
then: expr,
}
}
}

impl WhenThen {
pub fn otherwise(self, expr: Expr) -> Expr {
Expr::Ternary {
predicate: Box::new(self.predicate),
truthy: Box::new(self.then),
falsy: Box::new(expr),
}
}
}

/// Start a `when(..).then(..).otherwise(..)` chain by wrapping `predicate` in a
/// [`When`] builder.
pub fn when(predicate: Expr) -> When {
    When { predicate }
}

/// Build an `Expr::Ternary` directly from its three parts, boxing each of them.
pub fn ternary_expr(predicate: Expr, truthy: Expr, falsy: Expr) -> Expr {
    let (predicate, truthy, falsy) = (Box::new(predicate), Box::new(truthy), Box::new(falsy));
    Expr::Ternary {
        predicate,
        truthy,
        falsy,
    }
}

impl Expr {
/// Compare `Expr` with other `Expr` on equality
pub fn eq(self, other: Expr) -> Expr {
Expand Down
17 changes: 17 additions & 0 deletions polars/src/lazy/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,23 @@ mod test {
use crate::lazy::tests::get_df;
use crate::prelude::*;

#[test]
fn test_lazy_ternary() {
    // Rows with `sepal.length` below 5.0 get 10, every other row gets 1.
    let new_column = when(col("sepal.length").lt(lit(5.0)))
        .then(lit(10))
        .otherwise(lit(1))
        .alias("new");

    let df = get_df().lazy().with_column(new_column).collect().unwrap();

    let total = df.column("new").unwrap().sum::<i32>();
    assert_eq!(Some(43), total);
}

#[test]
fn test_lazy_with_column() {
let df = get_df()
Expand Down
22 changes: 22 additions & 0 deletions polars/src/lazy/logical_plan/optimizer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,28 @@ impl TypeCoercion {
let expr = self.rewrite_expr(*expr, input_schema)?;
Ok(expr.agg_groups())
}
Ternary {
predicate,
truthy,
falsy,
} => {
let predicate = self.rewrite_expr(*predicate, input_schema)?;
let truthy = self.rewrite_expr(*truthy, input_schema)?;
let falsy = self.rewrite_expr(*falsy, input_schema)?;
let type_true = truthy.get_type(input_schema)?;
let type_false = falsy.get_type(input_schema)?;

if type_true == type_false {
Ok(ternary_expr(predicate, truthy, falsy))
} else {
let st = get_supertype(&type_true, &type_false)?;
Ok(ternary_expr(
predicate,
truthy.cast(st.clone()),
falsy.cast(st),
))
}
}
}
}

Expand Down
27 changes: 21 additions & 6 deletions polars/src/lazy/physical_plan/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,11 +519,26 @@ impl PhysicalExpr for CastExpr {
series.cast_with_arrow_datatype(&self.data_type)
}
fn to_field(&self, input_schema: &Schema) -> Result<Field> {
let field = self.expr.to_field(input_schema)?;
Ok(Field::new(
field.name(),
self.data_type.clone(),
field.is_nullable(),
))
self.expr.to_field(input_schema)
}
}

/// Physical expression made of a predicate and two branch expressions; its
/// `PhysicalExpr` implementation combines the branches using the predicate as
/// a boolean mask.
#[derive(Debug)]
pub struct TernaryExpr {
    // Expression that is evaluated and read as a boolean mask.
    pub predicate: Arc<dyn PhysicalExpr>,
    // Expression supplying the values for the true branch.
    pub truthy: Arc<dyn PhysicalExpr>,
    // Expression supplying the values for the false branch.
    pub falsy: Arc<dyn PhysicalExpr>,
}

impl PhysicalExpr for TernaryExpr {
    /// Evaluate the predicate, then both branches, against `df` and merge the
    /// branch results with `zip_with`, using the predicate as the mask.
    ///
    /// Any error from the sub-expressions, from reading the predicate as a
    /// boolean array, or from `zip_with` is returned to the caller.
    fn evaluate(&self, df: &DataFrame) -> Result<Series> {
        let predicate_series = self.predicate.evaluate(df)?;
        // Fails if the predicate did not evaluate to a boolean series.
        let mask = predicate_series.bool()?;
        let when_true = self.truthy.evaluate(df)?;
        let when_false = self.falsy.evaluate(df)?;
        when_true.zip_with(&mask, &when_false)
    }

    /// The output field is taken from the `truthy` branch.
    fn to_field(&self, input_schema: &Schema) -> Result<Field> {
        self.truthy.to_field(input_schema)
    }
}
14 changes: 14 additions & 0 deletions polars/src/lazy/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,20 @@ impl DefaultPlanner {
let phys_expr = self.create_physical_expr(expr)?;
Ok(Arc::new(CastExpr::new(phys_expr, data_type.clone())))
}
Expr::Ternary {
predicate,
truthy,
falsy,
} => {
let predicate = self.create_physical_expr(predicate)?;
let truthy = self.create_physical_expr(truthy)?;
let falsy = self.create_physical_expr(falsy)?;
Ok(Arc::new(TernaryExpr {
predicate,
truthy,
falsy,
}))
}
}
}
}
2 changes: 1 addition & 1 deletion py-polars/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "py-polars"
version = "0.0.3"
version = "0.0.4"
authors = ["ritchie46 <ritchie46@gmail.com>"]
edition = "2018"

Expand Down
5 changes: 1 addition & 4 deletions py-polars/pypolars/frame.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from __future__ import annotations
import os

if not os.environ.get("DOC_BUILDING", False):
from .pypolars import PyDataFrame, PySeries, PyLazyFrame
from .pypolars import PyDataFrame, PySeries, PyLazyFrame
from typing import (
Dict,
Sequence,
Expand Down
5 changes: 1 addition & 4 deletions py-polars/pypolars/lazy.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
from __future__ import annotations
import os
from typing import Union, List

from pypolars.frame import DataFrame, wrap_df

if not os.environ.get("DOC_BUILDING", False):
from .pypolars import PyLazyFrame, col, lit, binary_expr, PyExpr, PyLazyGroupBy
from .pypolars import PyLazyFrame, col, lit, binary_expr, PyExpr, PyLazyGroupBy, when


def lazy(self) -> "LazyFrame":
Expand Down
4 changes: 1 addition & 3 deletions py-polars/pypolars/series.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from __future__ import annotations
import os

if not os.environ.get("DOC_BUILDING", False):
from .pypolars import PySeries
from .pypolars import PySeries
import numpy as np
from typing import Optional, List, Sequence, Union, Any, Callable
from .ffi import ptr_to_numpy
Expand Down

0 comments on commit feabfbe

Please sign in to comment.