Skip to content

Commit

Permalink
feat(rust, python): make map_alias fallible (#5532)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Nov 17, 2022
1 parent c981a38 commit 90da88e
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 45 deletions.
6 changes: 3 additions & 3 deletions polars/polars-lazy/polars-plan/src/dsl/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ impl Default for SpecialEq<Arc<dyn BinaryUdfOutputField>> {
}

pub trait RenameAliasFn: Send + Sync {
fn call(&self, name: &str) -> String;
fn call(&self, name: &str) -> PolarsResult<String>;
}

impl<F: Fn(&str) -> String + Send + Sync> RenameAliasFn for F {
fn call(&self, name: &str) -> String {
impl<F: Fn(&str) -> PolarsResult<String> + Send + Sync> RenameAliasFn for F {
fn call(&self, name: &str) -> PolarsResult<String> {
self(name)
}
}
Expand Down
6 changes: 3 additions & 3 deletions polars/polars-lazy/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1498,7 +1498,7 @@ impl Expr {
/// Define an alias by mapping a function over the original root column name.
pub fn map_alias<F>(self, function: F) -> Expr
where
F: Fn(&str) -> String + 'static + Send + Sync,
F: Fn(&str) -> PolarsResult<String> + 'static + Send + Sync,
{
let function = SpecialEq::new(Arc::new(function) as Arc<dyn RenameAliasFn>);
Expr::RenameAlias {
Expand All @@ -1510,13 +1510,13 @@ impl Expr {
/// Add a suffix to the root column name.
pub fn suffix(self, suffix: &str) -> Expr {
let suffix = suffix.to_string();
self.map_alias(move |name| format!("{}{}", name, suffix))
self.map_alias(move |name| Ok(format!("{}{}", name, suffix)))
}

/// Add a prefix to the root column name.
pub fn prefix(self, prefix: &str) -> Expr {
let prefix = prefix.to_string();
self.map_alias(move |name| format!("{}{}", prefix, name))
self.map_alias(move |name| Ok(format!("{}{}", prefix, name)))
}

/// Exclude a column from a wildcard/regex selection.
Expand Down
24 changes: 18 additions & 6 deletions polars/polars-lazy/polars-plan/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub(crate) fn prepare_projection(
exprs: Vec<Expr>,
schema: &Schema,
) -> PolarsResult<(Vec<Expr>, Schema)> {
let exprs = rewrite_projections(exprs, schema, &[]);
let exprs = rewrite_projections(exprs, schema, &[])?;
let schema = utils::expressions_to_schema(&exprs, schema, Context::Default)?;
Ok((exprs, schema))
}
Expand Down Expand Up @@ -376,7 +376,11 @@ impl LogicalPlanBuilder {
_ => false,
}) {
let schema = try_delayed!(self.0.schema(), &self.0, into);
let rewritten = rewrite_projections(vec![predicate], &schema, &[]);
let rewritten = try_delayed!(
rewrite_projections(vec![predicate], &schema, &[]),
&self.0,
into
);
combine_predicates_expr(rewritten.into_iter())
} else {
predicate
Expand All @@ -399,8 +403,16 @@ impl LogicalPlanBuilder {
) -> Self {
let current_schema = try_delayed!(self.0.schema(), &self.0, into);
let current_schema = current_schema.as_ref();
let keys = rewrite_projections(keys, current_schema, &[]);
let aggs = rewrite_projections(aggs.as_ref().to_vec(), current_schema, keys.as_ref());
let keys = try_delayed!(
rewrite_projections(keys, current_schema, &[]),
&self.0,
into
);
let aggs = try_delayed!(
rewrite_projections(aggs.as_ref().to_vec(), current_schema, keys.as_ref()),
&self.0,
into
);

let mut schema = try_delayed!(
utils::expressions_to_schema(&keys, current_schema, Context::Default),
Expand Down Expand Up @@ -476,7 +488,7 @@ impl LogicalPlanBuilder {

pub fn sort(self, by_column: Vec<Expr>, reverse: Vec<bool>, null_last: bool) -> Self {
let schema = try_delayed!(self.0.schema(), &self.0, into);
let by_column = rewrite_projections(by_column, &schema, &[]);
let by_column = try_delayed!(rewrite_projections(by_column, &schema, &[]), &self.0, into);
LogicalPlan::Sort {
input: Box::new(self.0),
by_column,
Expand All @@ -491,7 +503,7 @@ impl LogicalPlanBuilder {

pub fn explode(self, columns: Vec<Expr>) -> Self {
let schema = try_delayed!(self.0.schema(), &self.0, into);
let columns = rewrite_projections(columns, &schema, &[]);
let columns = try_delayed!(rewrite_projections(columns, &schema, &[]), &self.0, into);

let mut schema = (**schema).clone();

Expand Down
81 changes: 50 additions & 31 deletions polars/polars-lazy/polars-plan/src/logical_plan/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub(super) fn replace_wildcard_with_column(mut expr: Expr, column_name: Arc<str>
expr
}

fn rewrite_special_aliases(expr: Expr) -> Expr {
fn rewrite_special_aliases(expr: Expr) -> PolarsResult<Expr> {
// the blocks are added by cargo fmt
#[allow(clippy::blocks_in_if_conditions)]
if has_expr(&expr, |e| {
Expand All @@ -36,31 +36,37 @@ fn rewrite_special_aliases(expr: Expr) -> Expr {
let name = roots
.get(0)
.expect("expected root column to keep expression name");
Expr::Alias(expr, name.clone())
Ok(Expr::Alias(expr, name.clone()))
}
Expr::RenameAlias { expr, function } => {
let name = get_single_leaf(&expr).unwrap();
let name = function.call(&name);
Expr::Alias(expr, Arc::from(name))
let name = function.call(&name)?;
Ok(Expr::Alias(expr, Arc::from(name)))
}
_ => panic!("`keep_name`, `suffix`, `prefix` should be last expression"),
}
} else {
expr
Ok(expr)
}
}

/// Takes an expression with a wildcard root, `col("*")`, and copies that expression for every
/// column in the schema, except the column names listed in `exclude`.
/// The resulting expressions are pushed onto `result`.
fn replace_wildcard(expr: &Expr, result: &mut Vec<Expr>, exclude: &[Arc<str>], schema: &Schema) {
fn replace_wildcard(
expr: &Expr,
result: &mut Vec<Expr>,
exclude: &[Arc<str>],
schema: &Schema,
) -> PolarsResult<()> {
for name in schema.iter_names() {
if !exclude.iter().any(|excluded| &**excluded == name) {
let new_expr = replace_wildcard_with_column(expr.clone(), Arc::from(name.as_str()));
let new_expr = rewrite_special_aliases(new_expr);
let new_expr = rewrite_special_aliases(new_expr)?;
result.push(new_expr)
}
}
Ok(())
}

fn replace_nth(expr: &mut Expr, schema: &Schema) {
Expand All @@ -85,7 +91,12 @@ fn replace_nth(expr: &mut Expr, schema: &Schema) {
#[cfg(feature = "regex")]
/// This function takes an expression containing a regex in `col("..")` and expands the columns
/// that are selected by that regex in `result`.
fn expand_regex(expr: &Expr, result: &mut Vec<Expr>, schema: &Schema, pattern: &str) {
fn expand_regex(
expr: &Expr,
result: &mut Vec<Expr>,
schema: &Schema,
pattern: &str,
) -> PolarsResult<()> {
let re = regex::Regex::new(pattern)
.unwrap_or_else(|_| panic!("invalid regular expression in column: {}", pattern));
for name in schema.iter_names() {
Expand All @@ -100,24 +111,25 @@ fn expand_regex(expr: &Expr, result: &mut Vec<Expr>, schema: &Schema, pattern: &
_ => true,
});

let new_expr = rewrite_special_aliases(new_expr);
let new_expr = rewrite_special_aliases(new_expr)?;
result.push(new_expr)
}
}
Ok(())
}

#[cfg(feature = "regex")]
/// This function searches for a regex expression in `col("..")` and expands the columns
/// that are selected by that regex in `result`. The regex should start with `^` and end with `$`.
fn replace_regex(expr: &Expr, result: &mut Vec<Expr>, schema: &Schema) {
fn replace_regex(expr: &Expr, result: &mut Vec<Expr>, schema: &Schema) -> PolarsResult<()> {
let roots = expr_to_leaf_column_names(expr);
let mut regex = None;
for name in &roots {
if name.starts_with('^') && name.ends_with('$') {
match regex {
None => {
regex = Some(name);
expand_regex(expr, result, schema, name)
expand_regex(expr, result, schema, name)?
}
Some(r) => {
assert_eq!(
Expand All @@ -129,13 +141,14 @@ fn replace_regex(expr: &Expr, result: &mut Vec<Expr>, schema: &Schema) {
}
}
if regex.is_none() {
let expr = rewrite_special_aliases(expr.clone());
let expr = rewrite_special_aliases(expr.clone())?;
result.push(expr)
}
Ok(())
}

/// replace `columns(["A", "B"])..` with `col("A")..`, `col("B")..`
fn expand_columns(expr: &Expr, result: &mut Vec<Expr>, names: &[String]) {
fn expand_columns(expr: &Expr, result: &mut Vec<Expr>, names: &[String]) -> PolarsResult<()> {
for name in names {
let mut new_expr = expr.clone();
new_expr.mutate().apply(|e| {
Expand All @@ -146,9 +159,10 @@ fn expand_columns(expr: &Expr, result: &mut Vec<Expr>, names: &[String]) {
true
});

let new_expr = rewrite_special_aliases(new_expr);
let new_expr = rewrite_special_aliases(new_expr)?;
result.push(new_expr)
}
Ok(())
}

/// This replaces the dtypes Expr with a Column Expr. It also removes the Exclude Expr from the
Expand Down Expand Up @@ -177,7 +191,7 @@ fn expand_dtypes(
schema: &Schema,
dtypes: &[DataType],
exclude: &[Arc<str>],
) {
) -> PolarsResult<()> {
for dtype in dtypes {
for field in schema.iter_fields().filter(|f| f.data_type() == dtype) {
let name = field.name();
Expand All @@ -189,17 +203,18 @@ fn expand_dtypes(

let new_expr = expr.clone();
let new_expr = replace_dtype_with_column(new_expr, Arc::from(name.as_str()));
let new_expr = rewrite_special_aliases(new_expr);
let new_expr = rewrite_special_aliases(new_expr)?;
result.push(new_expr)
}
}
Ok(())
}

// `schema` is unused when the `regex` feature is not enabled
#[allow(unused_variables)]
fn prepare_excluded(expr: &Expr, schema: &Schema, keys: &[Expr]) -> Vec<Arc<str>> {
fn prepare_excluded(expr: &Expr, schema: &Schema, keys: &[Expr]) -> PolarsResult<Vec<Arc<str>>> {
let mut exclude = vec![];
expr.into_iter().for_each(|e| {
for e in expr {
if let Expr::Exclude(_, to_exclude) = e {
#[cfg(feature = "regex")]
{
Expand All @@ -212,7 +227,7 @@ fn prepare_excluded(expr: &Expr, schema: &Schema, keys: &[Expr]) -> Vec<Arc<str>
match to_exclude_single {
Excluded::Name(name) => {
let e = Expr::Column(name.clone());
replace_regex(&e, &mut buf, schema);
replace_regex(&e, &mut buf, schema)?;
// we cannot loop because of bchck
while let Some(col) = buf.pop() {
if let Expr::Column(name) = col {
Expand Down Expand Up @@ -247,7 +262,7 @@ fn prepare_excluded(expr: &Expr, schema: &Schema, keys: &[Expr]) -> Vec<Arc<str>
}
}
}
});
}
for mut expr in keys.iter() {
// Allow a number of aliases of a column expression, still exclude column from aggregation
loop {
Expand All @@ -265,7 +280,7 @@ fn prepare_excluded(expr: &Expr, schema: &Schema, keys: &[Expr]) -> Vec<Arc<str>
}
}
}
exclude
Ok(exclude)
}

// functions can have col(["a", "b"]) or col(Utf8) as inputs
Expand All @@ -274,7 +289,7 @@ fn expand_function_inputs(mut expr: Expr, schema: &Schema) -> Expr {
Expr::AnonymousFunction { input, options, .. } | Expr::Function { input, options, .. }
if options.input_wildcard_expansion =>
{
*input = rewrite_projections(input.clone(), schema, &[]);
*input = rewrite_projections(input.clone(), schema, &[]).unwrap();
false
}
_ => true,
Expand Down Expand Up @@ -306,7 +321,11 @@ fn early_supertype(inputs: &[Expr], schema: &Schema) -> Option<DataType> {

/// In case of single col(*) -> do nothing, no selection is the same as select all
/// In other cases replace the wildcard with an expression with all columns
pub(crate) fn rewrite_projections(exprs: Vec<Expr>, schema: &Schema, keys: &[Expr]) -> Vec<Expr> {
pub(crate) fn rewrite_projections(
exprs: Vec<Expr>,
schema: &Schema,
keys: &[Expr],
) -> PolarsResult<Vec<Expr>> {
let mut result = Vec::with_capacity(exprs.len() + schema.len());

for mut expr in exprs {
Expand Down Expand Up @@ -347,11 +366,11 @@ pub(crate) fn rewrite_projections(exprs: Vec<Expr>, schema: &Schema, keys: &[Exp
.find(|e| matches!(e, Expr::Columns(_) | Expr::DtypeColumn(_)))
{
match &e {
Expr::Columns(names) => expand_columns(&expr, &mut result, names),
Expr::Columns(names) => expand_columns(&expr, &mut result, names)?,
Expr::DtypeColumn(dtypes) => {
// keep track of column excluded from the dtypes
let exclude = prepare_excluded(&expr, schema, keys);
expand_dtypes(&expr, &mut result, schema, dtypes, &exclude)
let exclude = prepare_excluded(&expr, schema, keys)?;
expand_dtypes(&expr, &mut result, schema, dtypes, &exclude)?
}
_ => {}
}
Expand All @@ -360,20 +379,20 @@ pub(crate) fn rewrite_projections(exprs: Vec<Expr>, schema: &Schema, keys: &[Exp
// has multiple column names due to wildcards
else if has_wildcard {
// keep track of column excluded from the wildcard
let exclude = prepare_excluded(&expr, schema, keys);
let exclude = prepare_excluded(&expr, schema, keys)?;
// this path prepares the wildcard as input for the Function Expr
replace_wildcard(&expr, &mut result, &exclude, schema);
replace_wildcard(&expr, &mut result, &exclude, schema)?;
}
// can have multiple column names due to a regex
else {
#[allow(clippy::collapsible_else_if)]
#[cfg(feature = "regex")]
{
replace_regex(&expr, &mut result, schema)
replace_regex(&expr, &mut result, schema)?
}
#[cfg(not(feature = "regex"))]
{
let expr = rewrite_special_aliases(expr);
let expr = rewrite_special_aliases(expr)?;
result.push(expr)
}
}
Expand Down Expand Up @@ -404,5 +423,5 @@ pub(crate) fn rewrite_projections(exprs: Vec<Expr>, schema: &Schema, keys: &[Exp
}
}
}
result
Ok(result)
}
9 changes: 7 additions & 2 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1094,8 +1094,13 @@ impl PyExpr {
self.inner
.clone()
.map_alias(move |name| {
let out = Python::with_gil(|py| lambda.call1(py, (name,)).unwrap());
out.to_string()
let out = Python::with_gil(|py| lambda.call1(py, (name,)));
match out {
Ok(out) => Ok(out.to_string()),
Err(e) => Err(PolarsError::ComputeError(
format!("Python function in 'map_alias' produced an error: {}.", e).into(),
)),
}
})
.into()
}
Expand Down

0 comments on commit 90da88e

Please sign in to comment.