Skip to content

Commit

Permalink
Added literal param to string-replace functions, optimized `replace…
Browse files Browse the repository at this point in the history
…` performance in small-string regime (30-80% faster) (#4057)
  • Loading branch information
alexander-beedie committed Jul 17, 2022
1 parent 34b8adf commit 35304d5
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 59 deletions.
64 changes: 39 additions & 25 deletions polars/polars-ops/src/chunked_array/strings/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use polars_arrow::{
export::arrow::{self, compute::substring::substring},
kernels::string::*,
};
use polars_core::export::regex::Regex;
use polars_core::export::regex::{escape, Regex};
use std::borrow::Cow;

fn f_regex_extract<'a>(reg: &Regex, input: &'a str, group_index: usize) -> Option<Cow<'a, str>> {
Expand Down Expand Up @@ -100,28 +100,19 @@ pub trait Utf8NameSpaceImpl: AsUtf8 {
ca.apply(f)
}

/// Check if strings contain a regex pattern; select literal fast-path if no special chars
/// Check if strings contain a regex pattern; take literal fast-path if
/// no special chars and strlen <= 96 chars (otherwise regex faster).
fn contains(&self, pat: &str) -> Result<BooleanChunked> {
if pat.chars().all(|c| !c.is_ascii_punctuation()) {
self.contains_literal(pat)
} else {
let ca = self.as_utf8();
let reg = Regex::new(pat)?;
let f = |s| reg.is_match(s);
let mut out: BooleanChunked = if !ca.has_validity() {
ca.into_no_null_iter().map(f).collect()
} else {
ca.into_iter().map(|opt_s| opt_s.map(f)).collect()
};
out.rename(ca.name());
Ok(out)
}
}

/// Check if strings contain a given literal
fn contains_literal(&self, lit: &str) -> Result<BooleanChunked> {
let lit = pat.chars().all(|c| !c.is_ascii_punctuation());
let ca = self.as_utf8();
let f = |s: &str| s.contains(lit);
let reg = Regex::new(pat)?;
let f = |s: &str| {
if lit && (s.len() <= 96) {
s.contains(pat)
} else {
reg.is_match(s)
}
};
let mut out: BooleanChunked = if !ca.has_validity() {
ca.into_no_null_iter().map(f).collect()
} else {
Expand All @@ -131,6 +122,11 @@ pub trait Utf8NameSpaceImpl: AsUtf8 {
Ok(out)
}

/// Test every string for the presence of `lit`, matched verbatim.
///
/// The literal is regex-escaped and then handed to `contains`, so characters
/// that would otherwise be read as regex syntax are matched as-is.
fn contains_literal(&self, lit: &str) -> Result<BooleanChunked> {
    let escaped = escape(lit);
    self.contains(&escaped)
}

/// Check if strings ends with a substring
fn ends_with(&self, sub: &str) -> BooleanChunked {
let ca = self.as_utf8();
Expand All @@ -149,22 +145,40 @@ pub trait Utf8NameSpaceImpl: AsUtf8 {
out
}

/// Replace the leftmost (sub)string by a regex pattern
fn replace(&self, pat: &str, val: &str) -> Result<Utf8Chunked> {
/// Replace the leftmost regex-matched (sub)string with another string.
///
/// A plain-string fast path is taken for short inputs (at most 32 *bytes*,
/// as measured by `str::len`); per the original note, the regex engine is
/// faster beyond that size. The fast path is only used when it is guaranteed
/// to give the same answer as the regex path:
///
/// * `pat` contains no ASCII punctuation, so it cannot carry regex syntax, and
/// * `val` contains no `$`, so there are no capture-group references that the
///   regex engine would expand but `str::replacen` would insert verbatim.
///
/// # Errors
///
/// Returns an error if `pat` is not a valid regular expression. The pattern
/// is always compiled, even when every string ends up on the fast path.
fn replace<'a>(&'a self, pat: &str, val: &str) -> Result<Utf8Chunked> {
    // Both conditions are loop-invariant, so evaluate them once up front.
    let lit = pat.chars().all(|c| !c.is_ascii_punctuation()) && !val.contains('$');
    let ca = self.as_utf8();
    let reg = Regex::new(pat)?;
    let f = |s: &'a str| {
        if lit && (s.len() <= 32) {
            // `replacen(.., 1)` mirrors the "leftmost match only" regex semantics.
            Cow::Owned(s.replacen(pat, val, 1))
        } else {
            reg.replace(s, val)
        }
    };
    Ok(ca.apply(f))
}

/// Replace all (sub)strings by a regex pattern
/// Replace the first occurrence of the literal string `pat` with `val`.
///
/// `pat` is regex-escaped and forwarded to `replace`, so it is matched
/// verbatim rather than interpreted as a pattern.
// NOTE(review): `val` is forwarded unchanged, so `$`-style group references in
// it are presumably still expanded by the regex engine -- confirm whether that
// is intended for literal mode.
fn replace_literal(&self, pat: &str, val: &str) -> Result<Utf8Chunked> {
    let escaped = escape(pat);
    self.replace(&escaped, val)
}

/// Replace every regex match of `pat` in each string with `val`.
///
/// # Errors
///
/// Returns an error if `pat` fails to compile as a regular expression.
fn replace_all(&self, pat: &str, val: &str) -> Result<Utf8Chunked> {
    // Compile once; the closure below runs for every string in the array.
    let reg = Regex::new(pat)?;
    let replaced = self.as_utf8().apply(|s| reg.replace_all(s, val));
    Ok(replaced)
}

/// Replace every occurrence of the literal string `pat` with `val`.
///
/// `pat` is regex-escaped and forwarded to `replace_all`, so it is matched
/// verbatim rather than interpreted as a pattern.
fn replace_literal_all(&self, pat: &str, val: &str) -> Result<Utf8Chunked> {
    let escaped = escape(pat);
    self.replace_all(&escaped, val)
}

/// Extract the nth capture group from pattern
fn extract(&self, pat: &str, group_index: usize) -> Result<Utf8Chunked> {
let ca = self.as_utf8();
Expand Down
20 changes: 12 additions & 8 deletions py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6319,20 +6319,22 @@ def split_exact(self, by: str, n: int, inclusive: bool = False) -> Expr:
return wrap_expr(self._pyexpr.str_split_exact_inclusive(by, n))
return wrap_expr(self._pyexpr.str_split_exact(by, n))

def replace(self, pattern: str, value: str) -> Expr:
def replace(self, pattern: str, value: str, literal: bool = False) -> Expr:
"""
Replace first regex match with a string value.
Replace first matching regex/literal substring with a new string value.
Parameters
----------
pattern
Regex pattern.
value
Replacement string.
literal
Treat pattern as a literal string.
See Also
--------
replace_all : Replace substring on all regex pattern matches.
replace_all : Replace all matching regex/literal substrings.
Examples
--------
Expand All @@ -6352,22 +6354,24 @@ def replace(self, pattern: str, value: str) -> Expr:
└─────┴────────┘
"""
return wrap_expr(self._pyexpr.str_replace(pattern, value))
return wrap_expr(self._pyexpr.str_replace(pattern, value, literal))

def replace_all(self, pattern: str, value: str) -> Expr:
def replace_all(self, pattern: str, value: str, literal: bool = False) -> Expr:
"""
Replace substring on all regex pattern matches.
Replace all matching regex/literal substrings with a new string value.
Parameters
----------
pattern
Regex pattern.
value
Replacement string.
literal
Treat pattern as a literal string.
See Also
--------
replace : Replace first regex match with a string value.
replace : Replace first matching regex/literal substring.
Examples
--------
Expand All @@ -6384,7 +6388,7 @@ def replace_all(self, pattern: str, value: str) -> Expr:
│ 2 ┆ 123-123 │
└─────┴─────────┘
"""
return wrap_expr(self._pyexpr.str_replace_all(pattern, value))
return wrap_expr(self._pyexpr.str_replace_all(pattern, value, literal))

def slice(self, start: int, length: int | None = None) -> Expr:
"""
Expand Down
20 changes: 12 additions & 8 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4588,20 +4588,22 @@ def split_exact(self, by: str, n: int, inclusive: bool = False) -> Series:
.to_series()
)

def replace(self, pattern: str, value: str) -> Series:
def replace(self, pattern: str, value: str, literal: bool = False) -> Series:
"""
Replace first regex match with a string value.
Replace first matching regex/literal substring with a new string value.
Parameters
----------
pattern
A valid regex pattern.
value
Substring to replace.
literal
Treat pattern as a literal string.
See Also
--------
replace_all : Replace all regex matches with a string value.
replace_all : Replace all matching regex/literal substrings.
Examples
--------
Expand All @@ -4615,22 +4617,24 @@ def replace(self, pattern: str, value: str) -> Series:
]
"""
return wrap_s(self._s.str_replace(pattern, value))
return wrap_s(self._s.str_replace(pattern, value, literal))

def replace_all(self, pattern: str, value: str) -> Series:
def replace_all(self, pattern: str, value: str, literal: bool = False) -> Series:
"""
Replace all regex matches with a string value.
Replace all matching regex/literal substrings with a new string value.
Parameters
----------
pattern
A valid regex pattern.
value
Substring to replace.
literal
Treat pattern as a literal string.
See Also
--------
replace : Replace first regex match with a string value.
replace : Replace first matching regex/literal substring.
Examples
--------
Expand All @@ -4643,7 +4647,7 @@ def replace_all(self, pattern: str, value: str) -> Series:
"123-123"
]
"""
return wrap_s(self._s.str_replace_all(pattern, value))
return wrap_s(self._s.str_replace_all(pattern, value, literal))

def strip(self) -> Series:
"""
Expand Down
16 changes: 12 additions & 4 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -591,10 +591,14 @@ impl PyExpr {
.into()
}

pub fn str_replace(&self, pat: String, val: String) -> PyExpr {
pub fn str_replace(&self, pat: String, val: String, literal: Option<bool>) -> PyExpr {
let function = move |s: Series| {
let ca = s.utf8()?;
match ca.replace(&pat, &val) {
let replaced = match literal {
Some(true) => ca.replace_literal(&pat, &val),
_ => ca.replace(&pat, &val),
};
match replaced {
Ok(ca) => Ok(ca.into_series()),
Err(e) => Err(PolarsError::ComputeError(format!("{:?}", e).into())),
}
Expand All @@ -606,10 +610,14 @@ impl PyExpr {
.into()
}

pub fn str_replace_all(&self, pat: String, val: String) -> PyExpr {
pub fn str_replace_all(&self, pat: String, val: String, literal: Option<bool>) -> PyExpr {
let function = move |s: Series| {
let ca = s.utf8()?;
match ca.replace_all(&pat, &val) {
let replaced = match literal {
Some(true) => ca.replace_literal_all(&pat, &val),
_ => ca.replace_all(&pat, &val),
};
match replaced {
Ok(ca) => Ok(ca.into_series()),
Err(e) => Err(PolarsError::ComputeError(format!("{:?}", e).into())),
}
Expand Down
31 changes: 17 additions & 14 deletions py-polars/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1117,10 +1117,9 @@ impl PySeries {

pub fn str_contains(&self, pat: &str, literal: Option<bool>) -> PyResult<Self> {
let ca = self.series.utf8().map_err(PyPolarsErr::from)?;
let s = if literal.unwrap_or(false) {
ca.contains_literal(pat)
} else {
ca.contains(pat)
let s = match literal {
Some(true) => ca.contains_literal(pat),
_ => ca.contains(pat),
}
.map_err(PyPolarsErr::from)?
.into_series();
Expand All @@ -1145,21 +1144,25 @@ impl PySeries {
Ok(s.into())
}

pub fn str_replace(&self, pat: &str, val: &str) -> PyResult<Self> {
/// Replace the first match of `pat` with `val` in a Utf8 series.
///
/// `literal == Some(true)` selects `replace_literal` (pattern matched
/// verbatim); `None` and `Some(false)` both select the regex-based `replace`.
/// Errors from the Utf8 downcast or from the replacement are converted via
/// `PyPolarsErr`.
pub fn str_replace(&self, pat: &str, val: &str, literal: Option<bool>) -> PyResult<Self> {
    let ca = self.series.utf8().map_err(PyPolarsErr::from)?;
    // Dispatch exactly once: running the regex variant unconditionally first
    // would reject literal patterns that are not valid regexes (e.g. "*").
    let replaced = if literal == Some(true) {
        ca.replace_literal(pat, val)
    } else {
        ca.replace(pat, val)
    };
    let s = replaced.map_err(PyPolarsErr::from)?.into_series();
    Ok(s.into())
}

pub fn str_replace_all(&self, pat: &str, val: &str) -> PyResult<Self> {
/// Replace every match of `pat` with `val` in a Utf8 series.
///
/// `literal == Some(true)` selects `replace_literal_all` (pattern matched
/// verbatim); `None` and `Some(false)` both select the regex-based
/// `replace_all`. Errors from the Utf8 downcast or from the replacement are
/// converted via `PyPolarsErr`.
pub fn str_replace_all(&self, pat: &str, val: &str, literal: Option<bool>) -> PyResult<Self> {
    let ca = self.series.utf8().map_err(PyPolarsErr::from)?;
    // Dispatch exactly once: running the regex variant unconditionally first
    // would reject literal patterns that are not valid regexes (e.g. "*").
    let replaced = if literal == Some(true) {
        ca.replace_literal_all(pat, val)
    } else {
        ca.replace_all(pat, val)
    };
    let s = replaced.map_err(PyPolarsErr::from)?.into_series();
    Ok(s.into())
}

Expand Down
69 changes: 69 additions & 0 deletions py-polars/tests/test_strings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import pytest

import polars as pl


Expand Down Expand Up @@ -62,6 +64,73 @@ def test_null_comparisons() -> None:
assert (s.shift() != s).null_count() == 0


def test_replace() -> None:
    """Check first-match `str.replace` in regex and literal modes, via Series and Expr."""
    df = pl.DataFrame(
        data=[(1, "* * text"), (2, "(with) special\n * chars **etc...?$")],
        columns=["idx", "text"],
        orient="row",
    )
    # Each case: (pattern, replacement, literal flag, expected "text" column).
    cases = [
        (r"\*", "-", False, ["- * text", "(with) special\n - chars **etc...?$"]),
        (r"*", "-", True, ["- * text", "(with) special\n - chars **etc...?$"]),
        (r"^\(", "[", False, ["* * text", "[with) special\n * chars **etc...?$"]),
        (r"^\(", "[", True, ["* * text", "(with) special\n * chars **etc...?$"]),
        (r"t$", "an", False, ["* * texan", "(with) special\n * chars **etc...?$"]),
        (r"t$", "an", True, ["* * text", "(with) special\n * chars **etc...?$"]),
    ]
    for pattern, replacement, as_literal, expected in cases:
        # Series namespace
        via_series = df["text"].str.replace(pattern, replacement, literal=as_literal)
        assert expected == via_series.to_list()

        # Expression namespace
        via_expr = df.select(
            pl.col("text").str.replace(pattern, replacement, literal=as_literal)
        )
        assert expected == via_expr["text"].to_list()


def test_replace_all() -> None:
    """Check `str.replace_all` in regex and literal modes, via Series and Expr."""
    df = pl.DataFrame(
        data=[(1, "* * text"), (2, "(with) special * chars **etc...?$")],
        columns=["idx", "text"],
        orient="row",
    )
    # Each case: (pattern, replacement, literal flag, expected "text" column).
    cases = [
        (r"\*", "-", False, ["- - text", "(with) special - chars --etc...?$"]),
        (r"*", "-", True, ["- - text", "(with) special - chars --etc...?$"]),
        (r"\W", "", False, ["text", "withspecialcharsetc"]),
        (r".?$", "", True, ["* * text", "(with) special * chars **etc.."]),
        (
            r"(\b)[\w\s]{2,}(\b)",
            "$1(blah)$3",
            False,
            ["* * (blah)", "((blah)) (blah) * (blah) **(blah)...?$"],
        ),
    ]
    for pattern, replacement, as_literal, expected in cases:
        # Series namespace
        via_series = df["text"].str.replace_all(
            pattern, replacement, literal=as_literal
        )
        assert expected == via_series.to_list()

        # Expression namespace
        via_expr = df.select(
            pl.col("text").str.replace_all(pattern, replacement, literal=as_literal)
        )
        assert expected == via_expr["text"].to_list()

    # invalid regex (but valid literal - requires "literal=True")
    with pytest.raises(pl.ComputeError):
        df["text"].str.replace_all("*", "")


def test_extract_all_count() -> None:
df = pl.DataFrame({"foo": ["123 bla 45 asd", "xyz 678 910t"]})
assert (
Expand Down

0 comments on commit 35304d5

Please sign in to comment.