Skip to content

Commit

Permalink
col(dtypes).exclude() (#4001)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jul 13, 2022
1 parent 12f639f commit 4dff8f8
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 27 deletions.
64 changes: 37 additions & 27 deletions polars/polars-lazy/src/logical_plan/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ use polars_arrow::index::IndexToUsize;
/// expression chain.
pub(super) fn replace_wildcard_with_column(mut expr: Expr, column_name: Arc<str>) -> Expr {
expr.mutate().apply(|e| {
match &e {
match e {
Expr::Wildcard => {
*e = Expr::Column(column_name.clone());
}
Expr::Exclude(input, _) => {
*e = replace_wildcard_with_column(*input.clone(), column_name.clone());
*e = replace_wildcard_with_column(std::mem::take(input), column_name.clone());
}
_ => {}
}
Expand Down Expand Up @@ -130,21 +130,6 @@ fn replace_regex(expr: &Expr, result: &mut Vec<Expr>, schema: &Schema) {
let expr = rewrite_special_aliases(expr.clone());
result.push(expr)
}

// // only in simple expression (no binary expression)
// // we pattern match regex columns
// if roots.len() == 1 {
// let name = &*roots[0];
// if name.starts_with('^') && name.ends_with('$') {
// expand_regex(expr, result, schema, name)
// } else {
// let expr = rewrite_special_aliases(expr.clone());
// result.push(expr)
// }
// } else {
// let expr = rewrite_special_aliases(expr.clone());
// result.push(expr)
// }
}

/// replace `columns(["A", "B"])..` with `col("A")..`, `col("B")..`
Expand All @@ -164,21 +149,44 @@ fn expand_columns(expr: &Expr, result: &mut Vec<Expr>, names: &[String]) {
}
}

/// Rewrites `expr` so that it refers to the single column `column_name`.
///
/// Every `Expr::DtypeColumn` node is swapped for `Expr::Column(column_name)`.
/// An `Expr::Exclude` node is replaced by its (recursively rewritten) input, so
/// the returned expression no longer carries the exclude wrapper.
pub(super) fn replace_dtype_with_column(mut expr: Expr, column_name: Arc<str>) -> Expr {
    expr.mutate().apply(|node| {
        if let Expr::DtypeColumn(_) = node {
            *node = Expr::Column(column_name.clone());
        } else if let Expr::Exclude(inner, _) = node {
            // Move the wrapped input out instead of cloning it; the default value
            // left behind by `take` is overwritten by the assignment right away.
            *node = replace_dtype_with_column(std::mem::take(inner), column_name.clone());
        }
        // returning `true` keeps the traversal going over all inputs
        true
    });
    expr
}

/// replace `DtypeColumn` with `col("foo")..col("bar")`
fn expand_dtypes(expr: &Expr, result: &mut Vec<Expr>, schema: &Schema, dtypes: &[DataType]) {
fn expand_dtypes(
expr: &Expr,
result: &mut Vec<Expr>,
schema: &Schema,
dtypes: &[DataType],
exclude: &[Arc<str>],
) {
for dtype in dtypes {
for field in schema.iter_fields().filter(|f| f.data_type() == dtype) {
let name = field.name();

let mut new_expr = expr.clone();
new_expr.mutate().apply(|e| {
if let Expr::DtypeColumn(_) = &e {
*e = Expr::Column(Arc::from(name.as_str()));
}
// always keep iterating all inputs
true
});
// skip excluded names
if exclude.iter().any(|excl| excl.as_ref() == name.as_str()) {
continue;
}

let new_expr = expr.clone();
let new_expr = replace_dtype_with_column(new_expr, Arc::from(name.as_str()));
let new_expr = rewrite_special_aliases(new_expr);
result.push(new_expr)
}
Expand Down Expand Up @@ -370,7 +378,9 @@ pub(crate) fn rewrite_projections(exprs: Vec<Expr>, schema: &Schema, keys: &[Exp
if let Expr::Columns(names) = e {
expand_columns(&expr, &mut result, names)
} else if let Expr::DtypeColumn(dtypes) = e {
expand_dtypes(&expr, &mut result, schema, dtypes)
// keep track of the columns excluded from the dtypes selection
let exclude = prepare_excluded(&expr, schema, keys);
expand_dtypes(&expr, &mut result, schema, dtypes, &exclude)
}
continue;
}
Expand Down
9 changes: 9 additions & 0 deletions py-polars/tests/test_expr_multi_cols.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import polars as pl


def test_exclude_name_from_dtypes() -> None:
    """A name excluded from a dtype selection is left out of the expansion."""
    frame = pl.DataFrame({"a": ["a"], "b": ["b"]})
    # Select every Utf8 column except "a" and suffix the result.
    selection = pl.col(pl.Utf8).exclude("a").suffix("_foo")
    expected = pl.DataFrame({"a": ["a"], "b": ["b"], "b_foo": ["b"]})

    assert frame.with_column(selection).frame_equal(expected)

0 comments on commit 4dff8f8

Please sign in to comment.