Skip to content

Commit

Permalink
feat[sql]: trim function (#4973)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 25, 2022
1 parent 548984b commit 4020a4c
Show file tree
Hide file tree
Showing 13 changed files with 221 additions and 98 deletions.
1 change: 1 addition & 0 deletions polars-sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ binary = ["clap"]

[dependencies]
clap = { version = "3.2.22", features = ["derive"], optional = true }
polars-arrow = { path = "../polars/polars-arrow", features = ["like"] }
polars-lazy = { path = "../polars/polars-lazy", features = ["compile"] }
serde = "1"
serde_json = { version = "1" }
Expand Down
78 changes: 45 additions & 33 deletions polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use polars::error::PolarsError;
use polars::prelude::*;
use sqlparser::ast::{
BinaryOperator as SQLBinaryOperator, BinaryOperator, DataType as SQLDataType, Expr as SqlExpr,
Function as SQLFunction, JoinConstraint, Value as SqlValue, WindowSpec,
Function as SQLFunction, JoinConstraint, TrimWhereField, Value as SqlValue, WindowSpec,
};

fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult<DataType> {
Expand Down Expand Up @@ -108,6 +108,16 @@ fn literal_expr(value: &SqlValue) -> PolarsResult<Expr> {
}

pub(crate) fn parse_sql_expr(expr: &SqlExpr) -> PolarsResult<Expr> {
let err = || {
Err(PolarsError::ComputeError(
format!(
"Expression: {:?} was not supported in polars-sql yet!",
expr
)
.into(),
))
};

Ok(match expr {
SqlExpr::Identifier(e) => col(&e.value),
SqlExpr::BinaryOp { left, op, right } => {
Expand Down Expand Up @@ -137,15 +147,41 @@ pub(crate) fn parse_sql_expr(expr: &SqlExpr) -> PolarsResult<Expr> {
expr.clone().gt(low).and(expr.lt(high))
}
}
_ => {
return Err(PolarsError::ComputeError(
format!(
"Expression: {:?} was not supported in polars-sql yet!",
expr
)
.into(),
))
SqlExpr::Trim {
expr: sql_expr,
trim_where,
} => {
let expr = parse_sql_expr(sql_expr)?;
match trim_where {
None => return Ok(expr.str().strip(None)),
Some((TrimWhereField::Both, sql_expr)) => {
let lit = parse_sql_expr(sql_expr)?;
if let Expr::Literal(LiteralValue::Utf8(val)) = lit {
if val.len() == 1 {
return Ok(expr.str().strip(Some(val.chars().next().unwrap())));
}
}
}
Some((TrimWhereField::Leading, sql_expr)) => {
let lit = parse_sql_expr(sql_expr)?;
if let Expr::Literal(LiteralValue::Utf8(val)) = lit {
if val.len() == 1 {
return Ok(expr.str().lstrip(Some(val.chars().next().unwrap())));
}
}
}
Some((TrimWhereField::Trailing, sql_expr)) => {
let lit = parse_sql_expr(sql_expr)?;
if let Expr::Literal(LiteralValue::Utf8(val)) = lit {
if val.len() == 1 {
return Ok(expr.str().rstrip(Some(val.chars().next().unwrap())));
}
}
}
}
return err();
}
_ => return err(),
})
}

Expand Down Expand Up @@ -248,30 +284,6 @@ pub(super) fn process_join_constraint(
}
_ => {}
}
// if let (SqlExpr::CompoundIdentifier(left), SqlExpr::CompoundIdentifier(right)) = (left.as_ref(), right.as_ref()) {
// if left.len() == 2 && right.len() == 2 {
// let tbl_a = &left[0].value;
// let col_a = &left[1].value;
//
// let tbl_b = &right[0].value;
// let col_b = &right[1].value;
//
// if let BinaryOperator::Eq = op {
// if left_name == tbl_a && right_name == tbl_b {
// return Ok((col(col_a), col(col_b)))
// } else if left_name == tbl_b && right_name == tbl_a {
// return Ok((col(col_b), col(col_a)))
// }
// }
// }
// }
// };
//
// let expr = parse_sql_expr(expr)?;
// if let Expr::BinaryExpr { left, right, op } = expr {
// if let Operator::Eq = op {
// return Ok((*left.clone(), *right.clone()));
// }
}
}
Err(PolarsError::ComputeError(
Expand Down
1 change: 1 addition & 0 deletions polars/polars-arrow/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ compute = ["arrow/compute_cast"]
temporal = ["arrow/compute_temporal"]
bigidx = []
performant = []
like = ["arrow/compute_like"]
3 changes: 3 additions & 0 deletions polars/polars-lazy/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,9 @@ impl From<StringFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
Replace { all, literal } => map_as_slice!(strings::replace, literal, all),
Uppercase => map!(strings::uppercase),
Lowercase => map!(strings::lowercase),
Strip(matches) => map!(strings::strip, matches),
LStrip(matches) => map!(strings::lstrip, matches),
RStrip(matches) => map!(strings::rstrip, matches),
}
}
}
Expand Down
4 changes: 3 additions & 1 deletion polars/polars-lazy/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ impl FunctionExpr {
ConcatVertical(_) | ConcatHorizontal(_) => with_dtype(DataType::Utf8),
#[cfg(feature = "regex")]
Replace { .. } => with_dtype(DataType::Utf8),
Uppercase | Lowercase => with_dtype(DataType::Utf8),
Uppercase | Lowercase | Strip(_) | LStrip(_) | RStrip(_) => {
with_dtype(DataType::Utf8)
}
}
}
#[cfg(feature = "temporal")]
Expand Down
76 changes: 59 additions & 17 deletions polars/polars-lazy/src/dsl/function_expr/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,35 +49,43 @@ pub enum StringFunction {
},
Uppercase,
Lowercase,
Strip(Option<char>),
RStrip(Option<char>),
LStrip(Option<char>),
}

impl Display for StringFunction {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
use self::*;
match self {
StringFunction::Contains { .. } => write!(f, "str.contains"),
StringFunction::StartsWith(_) => write!(f, "str.starts_with"),
StringFunction::EndsWith(_) => write!(f, "str.ends_with"),
StringFunction::Extract { .. } => write!(f, "str.extract"),
let s = match self {
StringFunction::Contains { .. } => "contains",
StringFunction::StartsWith(_) => "starts_with",
StringFunction::EndsWith(_) => "ends_with",
StringFunction::Extract { .. } => "extract",
#[cfg(feature = "string_justify")]
StringFunction::Zfill(_) => write!(f, "str.zfill"),
StringFunction::Zfill(_) => "zfill",
#[cfg(feature = "string_justify")]
StringFunction::LJust { .. } => write!(f, "str.ljust"),
StringFunction::LJust { .. } => "str.ljust",
#[cfg(feature = "string_justify")]
StringFunction::RJust { .. } => write!(f, "str.rjust"),
StringFunction::ExtractAll(_) => write!(f, "str.extract_all"),
StringFunction::CountMatch(_) => write!(f, "str.count_match"),
StringFunction::RJust { .. } => "rjust",
StringFunction::ExtractAll(_) => "extract_all",
StringFunction::CountMatch(_) => "count_match",
#[cfg(feature = "temporal")]
StringFunction::Strptime(_) => write!(f, "str.strptime"),
StringFunction::Strptime(_) => "strptime",
#[cfg(feature = "concat_str")]
StringFunction::ConcatVertical(_) => write!(f, "str.concat_vertical"),
StringFunction::ConcatVertical(_) => "concat_vertical",
#[cfg(feature = "concat_str")]
StringFunction::ConcatHorizontal(_) => write!(f, "str.concat_horizontal"),
StringFunction::ConcatHorizontal(_) => "concat_horizontal",
#[cfg(feature = "regex")]
StringFunction::Replace { .. } => write!(f, "str.replace"),
StringFunction::Uppercase => write!(f, "str.uppercase"),
StringFunction::Lowercase => write!(f, "str.lowercase"),
}
StringFunction::Replace { .. } => "replace",
StringFunction::Uppercase => "uppercase",
StringFunction::Lowercase => "lowercase",
StringFunction::Strip(_) => "strip",
StringFunction::LStrip(_) => "lstrip",
StringFunction::RStrip(_) => "rstrip",
};

write!(f, "str.{}", s)
}
}

Expand Down Expand Up @@ -134,6 +142,40 @@ pub(super) fn rjust(s: &Series, width: usize, fillchar: char) -> PolarsResult<Se
Ok(ca.rjust(width, fillchar).into_series())
}

pub(super) fn strip(s: &Series, matches: Option<char>) -> PolarsResult<Series> {
let ca = s.utf8()?;
if let Some(matches) = matches {
Ok(ca
.apply(|s| Cow::Borrowed(s.trim_matches(matches)))
.into_series())
} else {
Ok(ca.apply(|s| Cow::Borrowed(s.trim())).into_series())
}
}

pub(super) fn lstrip(s: &Series, matches: Option<char>) -> PolarsResult<Series> {
let ca = s.utf8()?;

if let Some(matches) = matches {
Ok(ca
.apply(|s| Cow::Borrowed(s.trim_start_matches(matches)))
.into_series())
} else {
Ok(ca.apply(|s| Cow::Borrowed(s.trim_start())).into_series())
}
}

pub(super) fn rstrip(s: &Series, matches: Option<char>) -> PolarsResult<Series> {
let ca = s.utf8()?;
if let Some(matches) = matches {
Ok(ca
.apply(|s| Cow::Borrowed(s.trim_end_matches(matches)))
.into_series())
} else {
Ok(ca.apply(|s| Cow::Borrowed(s.trim_end())).into_series())
}
}

pub(super) fn extract_all(s: &Series, pat: &str) -> PolarsResult<Series> {
let pat = pat.to_string();

Expand Down
18 changes: 18 additions & 0 deletions polars/polars-lazy/src/dsl/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,24 @@ impl StringNameSpace {
)
}

/// Remove whitespace on both sides.
pub fn strip(self, matches: Option<char>) -> Expr {
self.0
.map_private(FunctionExpr::StringExpr(StringFunction::Strip(matches)))
}

/// Remove leading whitespace.
pub fn lstrip(self, matches: Option<char>) -> Expr {
self.0
.map_private(FunctionExpr::StringExpr(StringFunction::LStrip(matches)))
}

/// Remove trailing whitespace.
pub fn rstrip(self, matches: Option<char>) -> Expr {
self.0
.map_private(FunctionExpr::StringExpr(StringFunction::RStrip(matches)))
}

/// Convert all characters to lowercase.
pub fn to_lowercase(self) -> Expr {
self.0
Expand Down
3 changes: 3 additions & 0 deletions py-polars/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 22 additions & 6 deletions py-polars/polars/internals/expr/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,15 @@ def to_lowercase(self) -> pli.Expr:
"""
return pli.wrap_expr(self._pyexpr.str_to_lowercase())

def strip(self) -> pli.Expr:
def strip(self, matches: None | str = None) -> pli.Expr:
"""
Remove leading and trailing whitespace.
Parameters
----------
matches
An optional single character that should be trimmed
Examples
--------
>>> df = pl.DataFrame({"foo": [" lead", "trail ", " both "]})
Expand All @@ -228,9 +233,11 @@ def strip(self) -> pli.Expr:
└───────┘
"""
return pli.wrap_expr(self._pyexpr.str_strip())
if matches is not None and len(matches) > 1:
raise ValueError("matches should contain a single character")
return pli.wrap_expr(self._pyexpr.str_strip(matches))

def lstrip(self) -> pli.Expr:
def lstrip(self, matches: None | str = None) -> pli.Expr:
"""
Remove leading whitespace.
Expand All @@ -252,12 +259,19 @@ def lstrip(self) -> pli.Expr:
└────────┘
"""
return pli.wrap_expr(self._pyexpr.str_lstrip())
if matches is not None and len(matches) > 1:
raise ValueError("matches should contain a single character")
return pli.wrap_expr(self._pyexpr.str_lstrip(matches))

def rstrip(self) -> pli.Expr:
def rstrip(self, matches: None | str = None) -> pli.Expr:
"""
Remove trailing whitespace.
Parameters
----------
matches
An optional single character that should be trimmed
Examples
--------
>>> df = pl.DataFrame({"foo": [" lead", "trail ", " both "]})
Expand All @@ -276,7 +290,9 @@ def rstrip(self) -> pli.Expr:
└───────┘
"""
return pli.wrap_expr(self._pyexpr.str_rstrip())
if matches is not None and len(matches) > 1:
raise ValueError("matches should contain a single character")
return pli.wrap_expr(self._pyexpr.str_rstrip(matches))

def zfill(self, alignment: int) -> pli.Expr:
"""
Expand Down
36 changes: 30 additions & 6 deletions py-polars/polars/internals/series/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,14 +626,38 @@ def replace_all(
"""

def strip(self) -> pli.Series:
"""Remove leading and trailing whitespace."""
def strip(self, matches: None | str = None) -> pli.Series:
"""
Remove leading and trailing whitespace.
Parameters
----------
matches
An optional single character that should be trimmed
"""

def lstrip(self, matches: None | str = None) -> pli.Series:
"""
Remove leading whitespace.
Parameters
----------
matches
An optional single character that should be trimmed
def lstrip(self) -> pli.Series:
"""Remove leading whitespace."""
"""

def rstrip(self, matches: None | str = None) -> pli.Series:
"""
Remove trailing whitespace.
def rstrip(self) -> pli.Series:
"""Remove trailing whitespace."""
Parameters
----------
matches
An optional single character that should be trimmed
"""

def zfill(self, alignment: int) -> pli.Series:
"""
Expand Down

0 comments on commit 4020a4c

Please sign in to comment.