Skip to content

Commit

Permalink
more ergonomic for lazy api (#2352)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jan 12, 2022
1 parent 73416c4 commit 6dd2778
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 35 deletions.
20 changes: 13 additions & 7 deletions polars/polars-lazy/src/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2185,10 +2185,11 @@ where
}

/// Accumulate over multiple columns horizontally / row wise.
pub fn fold_exprs<F: 'static>(mut acc: Expr, f: F, mut exprs: Vec<Expr>) -> Expr
pub fn fold_exprs<F: 'static, E: AsRef<[Expr]>>(mut acc: Expr, f: F, exprs: E) -> Expr
where
F: Fn(Series, Series) -> Result<Series> + Send + Sync + Clone,
{
let mut exprs = exprs.as_ref().to_vec();
if exprs.iter().any(has_wildcard) {
exprs.push(acc);

Expand Down Expand Up @@ -2232,13 +2233,15 @@ where
}

/// Get the the sum of the values per row
pub fn sum_exprs(exprs: Vec<Expr>) -> Expr {
pub fn sum_exprs<E: AsRef<[Expr]>>(exprs: E) -> Expr {
let exprs = exprs.as_ref().to_vec();
let func = |s1, s2| Ok(&s1 + &s2);
fold_exprs(lit(0), func, exprs)
}

/// Get the the minimum value per row
pub fn max_exprs(exprs: Vec<Expr>) -> Expr {
/// Get the the maximum value per row
pub fn max_exprs<E: AsRef<[Expr]>>(exprs: E) -> Expr {
let exprs = exprs.as_ref().to_vec();
let func = |s1: Series, s2: Series| {
let mask = s1.gt(&s2);
s1.zip_with(&mask, &s2)
Expand All @@ -2247,7 +2250,8 @@ pub fn max_exprs(exprs: Vec<Expr>) -> Expr {
}

/// Get the the minimum value per row
pub fn min_exprs(exprs: Vec<Expr>) -> Expr {
pub fn min_exprs<E: AsRef<[Expr]>>(exprs: E) -> Expr {
let exprs = exprs.as_ref().to_vec();
let func = |s1: Series, s2: Series| {
let mask = s1.lt(&s2);
s1.zip_with(&mask, &s2)
Expand All @@ -2256,13 +2260,15 @@ pub fn min_exprs(exprs: Vec<Expr>) -> Expr {
}

/// Evaluate all the expressions with a bitwise or
pub fn any_exprs(exprs: Vec<Expr>) -> Expr {
pub fn any_exprs<E: AsRef<[Expr]>>(exprs: E) -> Expr {
let exprs = exprs.as_ref().to_vec();
let func = |s1: Series, s2: Series| Ok(s1.bool()?.bitor(s2.bool()?).into_series());
fold_exprs(lit(false), func, exprs)
}

/// Evaluate all the expressions with a bitwise and
pub fn all_exprs(exprs: Vec<Expr>) -> Expr {
pub fn all_exprs<E: AsRef<[Expr]>>(exprs: E) -> Expr {
let exprs = exprs.as_ref().to_vec();
let func = |s1: Series, s2: Series| Ok(s1.bool()?.bitand(s2.bool()?).into_series());
fold_exprs(lit(true), func, exprs)
}
Expand Down
7 changes: 4 additions & 3 deletions polars/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -834,7 +834,8 @@ impl LazyFrame {
}

/// Apply explode operation. [See eager explode](polars_core::frame::DataFrame::explode).
pub fn explode(self, columns: Vec<Expr>) -> LazyFrame {
pub fn explode<E: AsRef<[Expr]>>(self, columns: E) -> LazyFrame {
let columns = columns.as_ref().to_vec();
let opt_state = self.get_opt_state();
let lp = self.get_plan_builder().explode(columns).build();
Self::from_logical_plan(lp, opt_state)
Expand Down Expand Up @@ -1004,7 +1005,7 @@ impl LazyGroupBy {
.collect::<Vec<_>>();

self.agg([col("*").exclude(&keys).head(n).list().keep_name()])
.explode(vec![col("*").exclude(&keys)])
.explode([col("*").exclude(&keys)])
}

/// Return last n rows of each group
Expand All @@ -1016,7 +1017,7 @@ impl LazyGroupBy {
.collect::<Vec<_>>();

self.agg([col("*").exclude(&keys).tail(n).list().keep_name()])
.explode(vec![col("*").exclude(&keys)])
.explode([col("*").exclude(&keys)])
}

/// Apply a function over the groups as a new `DataFrame`. It is not recommended that you use
Expand Down
50 changes: 25 additions & 25 deletions polars/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ fn test_lazy_with_column() {
let df = get_df()
.lazy()
.with_column(lit(10).alias("foo"))
.select(&[col("foo"), col("sepal.width")])
.select([col("foo"), col("sepal.width")])
.collect()
.unwrap();
println!("{:?}", df);
Expand All @@ -43,7 +43,7 @@ fn test_lazy_exec() {
let new = df
.clone()
.lazy()
.select(&[col("sepal.width"), col("variety")])
.select([col("sepal.width"), col("variety")])
.sort("sepal.width", false)
.collect();
println!("{:?}", new);
Expand All @@ -63,7 +63,7 @@ fn test_lazy_alias() {
let df = get_df();
let new = df
.lazy()
.select(&[col("sepal.width").alias("petals"), col("sepal.width")])
.select([col("sepal.width").alias("petals"), col("sepal.width")])
.collect()
.unwrap();
assert_eq!(new.get_column_names(), &["petals", "sepal.width"]);
Expand Down Expand Up @@ -108,7 +108,7 @@ fn test_lazy_udf() {
let df = get_df();
let new = df
.lazy()
.select(&[col("sepal.width").map(|s| Ok(s * 200.0), GetOutput::same_type())])
.select([col("sepal.width").map(|s| Ok(s * 200.0), GetOutput::same_type())])
.collect()
.unwrap();
assert_eq!(
Expand Down Expand Up @@ -159,9 +159,9 @@ fn test_lazy_pushdown_through_agg() {
col("sepal.length").min(),
col("petal.length").min().alias("foo"),
])
.select(&[col("foo")])
.select([col("foo")])
// second selection is to test if optimizer can handle that
.select(&[col("foo").alias("bar")])
.select([col("foo").alias("bar")])
.collect()
.unwrap();

Expand Down Expand Up @@ -210,7 +210,7 @@ fn test_lazy_shift() {
let df = get_df();
let new = df
.lazy()
.select(&[col("sepal.width").alias("foo").shift(2)])
.select([col("sepal.width").alias("foo").shift(2)])
.collect()
.unwrap();
assert_eq!(new.column("foo").unwrap().f64().unwrap().get(0), None);
Expand Down Expand Up @@ -263,7 +263,7 @@ fn test_lazy_binary_ops() {
let df = df!("a" => &[1, 2, 3, 4, 5, ]).unwrap();
let new = df
.lazy()
.select(&[col("a").eq(lit(2)).alias("foo")])
.select([col("a").eq(lit(2)).alias("foo")])
.collect()
.unwrap();
assert_eq!(new.column("foo").unwrap().sum::<i32>(), Some(1));
Expand All @@ -288,7 +288,7 @@ fn test_lazy_query_1() {
.filter(col("a").lt(lit(2)))
.groupby([col("b")])
.agg([col("b").first(), col("c").first()])
.select(&[col("b"), col("c_first")])
.select([col("b"), col("c_first")])
.collect()
.unwrap();
}
Expand All @@ -304,7 +304,7 @@ fn test_lazy_query_2() {
.alias("foo"),
)
.filter(col("a").lt(lit(2)))
.select(&[col("b"), col("a")]);
.select([col("b"), col("a")]);

let new = ldf.collect().unwrap();
assert_eq!(new.shape(), (1, 2));
Expand Down Expand Up @@ -340,11 +340,11 @@ fn test_lazy_query_4() {
.apply(|s: Series| Ok(&s - &(s.shift(1))), GetOutput::same_type())
.alias("diff_cases"),
])
.explode(vec![col("day"), col("diff_cases")])
.explode([col("day"), col("diff_cases")])
.join(
base_df,
vec![col("uid"), col("day")],
vec![col("uid"), col("day")],
[col("uid"), col("day")],
[col("uid"), col("day")],
JoinType::Inner,
)
.collect()
Expand Down Expand Up @@ -467,8 +467,8 @@ fn test_lazy_query_9() -> Result<()> {
.lazy()
.join(
cities.lazy(),
vec![col("Sales.City")],
vec![col("Cities.City")],
[col("Sales.City")],
[col("Cities.City")],
JoinType::Inner,
)
.groupby([col("Cities.Country")])
Expand Down Expand Up @@ -643,7 +643,7 @@ fn test_simplify_expr() {
#[test]
fn test_lazy_wildcard() {
let df = load_df();
let new = df.clone().lazy().select(&[col("*")]).collect().unwrap();
let new = df.clone().lazy().select([col("*")]).collect().unwrap();
assert_eq!(new.shape(), (5, 3));

let new = df
Expand Down Expand Up @@ -678,7 +678,7 @@ fn test_lazy_filter_and_rename() {
|s: Series| Ok(s.gt(3).into_series()),
GetOutput::from_type(DataType::Boolean),
))
.select(&[col("x")]);
.select([col("x")]);

let correct = df! {
"x" => &[4, 5]
Expand Down Expand Up @@ -742,7 +742,7 @@ fn test_lazy_predicate_pushdown_binary_expr() {
let df = load_df();
df.lazy()
.filter(col("a").eq(col("b")))
.select(&[col("c")])
.select([col("c")])
.collect()
.unwrap();
}
Expand Down Expand Up @@ -798,7 +798,7 @@ fn test_lazy_window_functions() {
// test if partition aggregation is correct
let out = df
.lazy()
.select(&[col("groups"), sum("values").over([col("groups")])])
.select([col("groups"), sum("values").over([col("groups")])])
.collect()
.unwrap();
assert_eq!(
Expand All @@ -814,8 +814,8 @@ fn test_lazy_double_projection() {
}
.unwrap();
df.lazy()
.select(&[col("foo").alias("bar")])
.select(&[col("bar")])
.select([col("foo").alias("bar")])
.select([col("bar")])
.collect()
.unwrap();
}
Expand All @@ -828,7 +828,7 @@ fn test_type_coercion() {
}
.unwrap();

let lp = df.lazy().select(&[col("foo") * col("bar")]).logical_plan;
let lp = df.lazy().select([col("foo") * col("bar")]).logical_plan;

let mut expr_arena = Arena::new();
let mut lp_arena = Arena::new();
Expand Down Expand Up @@ -1257,7 +1257,7 @@ fn test_multiple_explode() -> Result<()> {
col("b").list().alias("b_list"),
col("c").list().alias("c_list"),
])
.explode(vec![col("c_list"), col("b_list")])
.explode([col("c_list"), col("b_list")])
.collect()?;
assert_eq!(out.shape(), (5, 3));

Expand Down Expand Up @@ -1368,7 +1368,7 @@ fn test_fold_wildcard() -> Result<()> {
let out = df1
.clone()
.lazy()
.select([fold_exprs(lit(0), |a, b| Ok(&a + &b), vec![col("*")]).alias("foo")])
.select([fold_exprs(lit(0), |a, b| Ok(&a + &b), [col("*")]).alias("foo")])
.collect()?;

assert_eq!(
Expand All @@ -1379,7 +1379,7 @@ fn test_fold_wildcard() -> Result<()> {
// test if we don't panic due to wildcard
let _out = df1
.lazy()
.select([all_exprs(vec![col("*").is_not_null()])])
.select([all_exprs([col("*").is_not_null()])])
.collect()?;
Ok(())
}
Expand Down

0 comments on commit 6dd2778

Please sign in to comment.