Skip to content

Commit

Permalink
feat: Add str.head and str.tail (#14425)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller committed Apr 13, 2024
1 parent e5a3620 commit 429d3dd
Show file tree
Hide file tree
Showing 11 changed files with 591 additions and 17 deletions.
23 changes: 23 additions & 0 deletions crates/polars-ops/src/chunked_array/strings/namespace.rs
Expand Up @@ -612,6 +612,29 @@ pub trait StringNameSpaceImpl: AsString {

Ok(substring::substring(ca, offset.i64()?, length.u64()?))
}

/// Keep the leading characters of every string.
///
/// `n` is strictly cast to `Int64` before use; a failed cast is returned as an error.
/// A positive `n` keeps at most the first `n` characters of each element. A negative `n`
/// keeps everything except the last `|n|` characters.
fn str_head(&self, n: &Series) -> PolarsResult<StringChunked> {
    let strings = self.as_string();
    // The cast result must outlive the borrowed `Int64Chunked` view taken below.
    let counts = n.strict_cast(&DataType::Int64)?;
    let counts = counts.i64()?;

    Ok(substring::head(strings, counts))
}

/// Keep the trailing characters of every string.
///
/// `n` is strictly cast to `Int64` before use; a failed cast is returned as an error.
/// A positive `n` keeps at most the last `n` characters of each element. A negative `n`
/// keeps everything except the first `|n|` characters.
fn str_tail(&self, n: &Series) -> PolarsResult<StringChunked> {
    let strings = self.as_string();
    // The cast result must outlive the borrowed `Int64Chunked` view taken below.
    let counts = n.strict_cast(&DataType::Int64)?;
    let counts = counts.i64()?;

    Ok(substring::tail(strings, counts))
}
}

impl StringNameSpaceImpl for StringChunked {}
112 changes: 105 additions & 7 deletions crates/polars-ops/src/chunked_array/strings/substring.rs
@@ -1,6 +1,72 @@
use polars_core::prelude::arity::{binary_elementwise, ternary_elementwise, unary_elementwise};
use polars_core::prelude::{Int64Chunked, StringChunked, UInt64Chunked};

/// Returns the leading part of `opt_str_val` selected by `opt_n`.
///
/// * `n > 0`: the first `n` characters (the whole string if it has fewer).
/// * `n == 0`: the empty string.
/// * `n < 0`: everything except the last `|n|` characters (empty if the string has
///   `|n|` characters or fewer).
///
/// Characters are Unicode scalar values, so slicing always happens on a `char` boundary.
/// Returns `None` if either argument is `None`.
fn head_binary(opt_str_val: Option<&str>, opt_n: Option<i64>) -> Option<&str> {
    let (str_val, n) = (opt_str_val?, opt_n?);

    // The byte length is an upper bound on the number of characters.
    let max_len = str_val.len();
    // `unsigned_abs` is well-defined for `i64::MIN`, unlike negation.
    let count = n.unsigned_abs();

    // Compare in `u64` so that a huge `n` is never truncated on 32-bit targets.
    // (`usize` -> `u64` is lossless on every supported platform.)
    if count >= max_len as u64 {
        // Asking for at least as many characters as there are bytes: a positive `n` keeps
        // everything, a non-positive `n` leaves nothing.
        return if n > 0 { Some(str_val) } else { Some("") };
    }
    // Lossless: `count < max_len`, and `max_len` is a `usize`.
    let count = count as usize;

    let end_idx = if n >= 0 {
        // End right before the character at position `n`.
        str_val
            .char_indices()
            .nth(count)
            .map_or(max_len, |(idx, _)| idx)
    } else {
        // End right before the `|n|`-th character counted from the end.
        // `count >= 1` here because `n < 0`.
        str_val
            .char_indices()
            .rev()
            .nth(count - 1)
            .map_or(0, |(idx, _)| idx)
    };
    Some(&str_val[..end_idx])
}

/// Returns the trailing part of `opt_str_val` selected by `opt_n`.
///
/// * `n > 0`: the last `n` characters (the whole string if it has fewer).
/// * `n == 0`: the empty string.
/// * `n < 0`: everything except the first `|n|` characters (empty if the string has
///   `|n|` characters or fewer).
///
/// Characters are Unicode scalar values, so slicing always happens on a `char` boundary.
/// Returns `None` if either argument is `None`.
fn tail_binary(opt_str_val: Option<&str>, opt_n: Option<i64>) -> Option<&str> {
    let (str_val, n) = (opt_str_val?, opt_n?);
    if n == 0 {
        return Some("");
    }

    // The byte length is an upper bound on the number of characters.
    let max_len = str_val.len();
    // `unsigned_abs` is well-defined for `i64::MIN`, unlike negation.
    let count = n.unsigned_abs();

    // Compare in `u64` so that a huge `n` is never truncated on 32-bit targets.
    // (`usize` -> `u64` is lossless on every supported platform.)
    if count >= max_len as u64 {
        // Asking for at least as many characters as there are bytes: a positive `n` keeps
        // everything, a negative `n` skips everything.
        return if n > 0 { Some(str_val) } else { Some("") };
    }
    // Lossless: `count < max_len`, and `max_len` is a `usize`.
    let count = count as usize;

    let start_idx = if n > 0 {
        // Start at the `n`-th character counted from the end (`count >= 1` here).
        str_val
            .char_indices()
            .rev()
            .nth(count - 1)
            .map_or(0, |(idx, _)| idx)
    } else {
        // Start at the character right after the first `|n|` characters.
        str_val
            .char_indices()
            .nth(count)
            .map_or(max_len, |(idx, _)| idx)
    };
    Some(&str_val[start_idx..])
}

fn substring_ternary(
opt_str_val: Option<&str>,
opt_offset: Option<i64>,
Expand Down Expand Up @@ -54,30 +120,30 @@ pub(super) fn substring(
) -> StringChunked {
match (ca.len(), offset.len(), length.len()) {
(1, 1, _) => {
// SAFETY: index `0` is in bound.
// SAFETY: `ca` was verified to have at least 1 element.
let str_val = unsafe { ca.get_unchecked(0) };
// SAFETY: index `0` is in bound.
// SAFETY: `offset` was verified to have at least 1 element.
let offset = unsafe { offset.get_unchecked(0) };
unary_elementwise(length, |length| substring_ternary(str_val, offset, length))
.with_name(ca.name())
},
(_, 1, 1) => {
// SAFETY: index `0` is in bound.
// SAFETY: `offset` was verified to have at least 1 element.
let offset = unsafe { offset.get_unchecked(0) };
// SAFETY: index `0` is in bound.
// SAFETY: `length` was verified to have at least 1 element.
let length = unsafe { length.get_unchecked(0) };
unary_elementwise(ca, |str_val| substring_ternary(str_val, offset, length))
},
(1, _, 1) => {
// SAFETY: index `0` is in bound.
// SAFETY: `ca` was verified to have at least 1 element.
let str_val = unsafe { ca.get_unchecked(0) };
// SAFETY: index `0` is in bound.
// SAFETY: `length` was verified to have at least 1 element.
let length = unsafe { length.get_unchecked(0) };
unary_elementwise(offset, |offset| substring_ternary(str_val, offset, length))
.with_name(ca.name())
},
(1, len_b, len_c) if len_b == len_c => {
// SAFETY: index `0` is in bound.
// SAFETY: `ca` was verified to have at least 1 element.
let str_val = unsafe { ca.get_unchecked(0) };
binary_elementwise(offset, length, |offset, length| {
substring_ternary(str_val, offset, length)
Expand Down Expand Up @@ -112,3 +178,35 @@ pub(super) fn substring(
_ => ternary_elementwise(ca, offset, length, substring_ternary),
}
}

/// Applies `head_binary` to `ca` and `n`, broadcasting whichever side has length 1.
///
/// A single-element `n` is checked first, so it wins when both inputs have length 1.
/// In both broadcast cases the result is given the name of `ca`.
pub(super) fn head(ca: &StringChunked, n: &Int64Chunked) -> StringChunked {
    if n.len() == 1 {
        // SAFETY: `n` was verified to have exactly 1 element, so index 0 is in bounds.
        let count = unsafe { n.get_unchecked(0) };
        unary_elementwise(ca, |text| head_binary(text, count)).with_name(ca.name())
    } else if ca.len() == 1 {
        // SAFETY: `ca` was verified to have exactly 1 element, so index 0 is in bounds.
        let text = unsafe { ca.get_unchecked(0) };
        unary_elementwise(n, |count| head_binary(text, count)).with_name(ca.name())
    } else {
        // No broadcasting: combine the two arrays element by element.
        binary_elementwise(ca, n, head_binary)
    }
}

/// Applies `tail_binary` to `ca` and `n`, broadcasting whichever side has length 1.
///
/// A single-element `n` is checked first, so it wins when both inputs have length 1.
/// In both broadcast cases the result is given the name of `ca`.
pub(super) fn tail(ca: &StringChunked, n: &Int64Chunked) -> StringChunked {
    if n.len() == 1 {
        // SAFETY: `n` was verified to have exactly 1 element, so index 0 is in bounds.
        let count = unsafe { n.get_unchecked(0) };
        unary_elementwise(ca, |text| tail_binary(text, count)).with_name(ca.name())
    } else if ca.len() == 1 {
        // SAFETY: `ca` was verified to have exactly 1 element, so index 0 is in bounds.
        let text = unsafe { ca.get_unchecked(0) };
        unary_elementwise(n, |count| tail_binary(text, count)).with_name(ca.name())
    } else {
        // No broadcasting: combine the two arrays element by element.
        binary_elementwise(ca, n, tail_binary)
    }
}
40 changes: 36 additions & 4 deletions crates/polars-plan/src/dsl/function_expr/strings.rs
Expand Up @@ -82,6 +82,8 @@ pub enum StringFunction {
fill_char: char,
},
Slice,
Head,
Tail,
#[cfg(feature = "string_encoding")]
HexEncode,
#[cfg(feature = "binary_encoding")]
Expand Down Expand Up @@ -166,7 +168,7 @@ impl StringFunction {
#[cfg(feature = "binary_encoding")]
Base64Decode(_) => mapper.with_dtype(DataType::Binary),
Uppercase | Lowercase | StripChars | StripCharsStart | StripCharsEnd | StripPrefix
| StripSuffix | Slice => mapper.with_same_dtype(),
| StripSuffix | Slice | Head | Tail => mapper.with_same_dtype(),
#[cfg(feature = "string_pad")]
PadStart { .. } | PadEnd { .. } | ZFill => mapper.with_same_dtype(),
#[cfg(feature = "dtype-struct")]
Expand Down Expand Up @@ -210,6 +212,8 @@ impl Display for StringFunction {
ToInteger { .. } => "to_integer",
#[cfg(feature = "regex")]
Find { .. } => "find",
Head { .. } => "head",
Tail { .. } => "tail",
#[cfg(feature = "extract_jsonpath")]
JsonDecode { .. } => "json_decode",
LenBytes => "len_bytes",
Expand Down Expand Up @@ -345,6 +349,8 @@ impl From<StringFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
#[cfg(feature = "string_to_integer")]
ToInteger(strict) => map_as_slice!(strings::to_integer, strict),
Slice => map_as_slice!(strings::str_slice),
Head => map_as_slice!(strings::str_head),
Tail => map_as_slice!(strings::str_tail),
#[cfg(feature = "string_encoding")]
HexEncode => map!(strings::hex_encode),
#[cfg(feature = "binary_encoding")]
Expand Down Expand Up @@ -894,24 +900,50 @@ pub(super) fn to_integer(s: &[Series], strict: bool) -> PolarsResult<Series> {
ca.to_integer(base.u32()?, strict)
.map(|ok| ok.into_series())
}
pub(super) fn str_slice(s: &[Series]) -> PolarsResult<Series> {

/// Reports whether the given series can be broadcast against each other.
///
/// The target length is the largest length among the series whose length is not 1
/// (or 1 if there is no such series). Returns `true` when every series has either
/// length 1 or exactly that target length.
fn _ensure_lengths(s: &[Series]) -> bool {
    // Find the post-broadcast length, ignoring unit-length series.
    let mut widest: Option<usize> = None;
    for series in s {
        let current = series.len();
        if current != 1 {
            widest = Some(widest.map_or(current, |seen| seen.max(current)));
        }
    }
    let target = widest.unwrap_or(1);

    // Every series must be either broadcastable (length 1) or already at the target length.
    !s.iter().any(|series| {
        let current = series.len();
        current != 1 && current != target
    })
}

/// Slices every string in `s[0]` using the offsets in `s[1]` and the lengths in `s[2]`.
///
/// Returns a `ComputeError` unless all inputs have equal or unit length, and propagates
/// any error from converting `s[0]` to a string array or from the slice itself.
///
/// Panics if `s` holds fewer than three series, because the inputs are indexed directly.
pub(super) fn str_slice(s: &[Series]) -> PolarsResult<Series> {
    // The shared helper performs the broadcast-length validation.
    polars_ensure!(
        _ensure_lengths(s),
        ComputeError: "all series in `str_slice` should have equal or unit length",
    );
    let ca = s[0].str()?;
    let offset = &s[1];
    let length = &s[2];
    Ok(ca.str_slice(offset, length)?.into_series())
}

/// Takes the head of every string in `s[0]`, with the per-element count given by `s[1]`.
///
/// Returns a `ComputeError` unless all inputs have equal or unit length, and propagates
/// any error from converting `s[0]` to a string array or from the head operation itself.
///
/// Panics if `s` holds fewer than two series, because the inputs are indexed directly.
pub(super) fn str_head(s: &[Series]) -> PolarsResult<Series> {
    polars_ensure!(
        _ensure_lengths(s),
        ComputeError: "all series in `str_head` should have equal or unit length",
    );
    let strings = s[0].str()?;
    strings.str_head(&s[1]).map(|out| out.into_series())
}

/// Takes the tail of every string in `s[0]`, with the per-element count given by `s[1]`.
///
/// Returns a `ComputeError` unless all inputs have equal or unit length, and propagates
/// any error from converting `s[0]` to a string array or from the tail operation itself.
///
/// Panics if `s` holds fewer than two series, because the inputs are indexed directly.
pub(super) fn str_tail(s: &[Series]) -> PolarsResult<Series> {
    polars_ensure!(
        _ensure_lengths(s),
        ComputeError: "all series in `str_tail` should have equal or unit length",
    );
    let strings = s[0].str()?;
    strings.str_tail(&s[1]).map(|out| out.into_series())
}

#[cfg(feature = "string_encoding")]
pub(super) fn hex_encode(s: &Series) -> PolarsResult<Series> {
Ok(s.str()?.hex_encode().into_series())
Expand Down
20 changes: 20 additions & 0 deletions crates/polars-plan/src/dsl/string.rs
Expand Up @@ -527,6 +527,26 @@ impl StringNameSpace {
)
}

/// Build an expression that takes the first `n` characters of each string value.
///
/// `n` is itself an expression and is passed as the single extra input of the
/// `StringFunction::Head` function expression.
pub fn head(self, n: Expr) -> Expr {
    let function = FunctionExpr::StringExpr(StringFunction::Head);
    self.0.map_many_private(function, &[n], false, false)
}

/// Build an expression that takes the last `n` characters of each string value.
///
/// `n` is itself an expression and is passed as the single extra input of the
/// `StringFunction::Tail` function expression.
pub fn tail(self, n: Expr) -> Expr {
    let function = FunctionExpr::StringExpr(StringFunction::Tail);
    self.0.map_many_private(function, &[n], false, false)
}

pub fn explode(self) -> Expr {
self.0
.apply_private(FunctionExpr::StringExpr(StringFunction::Explode))
Expand Down
4 changes: 3 additions & 1 deletion py-polars/docs/source/reference/expressions/string.rst
Expand Up @@ -22,6 +22,7 @@ The following methods are available under the `expr.str` attribute.
Expr.str.extract_all
Expr.str.extract_groups
Expr.str.find
Expr.str.head
Expr.str.json_decode
Expr.str.json_extract
Expr.str.json_path_match
Expand All @@ -33,6 +34,7 @@ The following methods are available under the `expr.str` attribute.
Expr.str.n_chars
Expr.str.pad_end
Expr.str.pad_start
Expr.str.parse_int
Expr.str.replace
Expr.str.replace_all
Expr.str.replace_many
Expand All @@ -51,6 +53,7 @@ The following methods are available under the `expr.str` attribute.
Expr.str.strip_prefix
Expr.str.strip_suffix
Expr.str.strptime
Expr.str.tail
Expr.str.to_date
Expr.str.to_datetime
Expr.str.to_decimal
Expand All @@ -60,4 +63,3 @@ The following methods are available under the `expr.str` attribute.
Expr.str.to_time
Expr.str.to_uppercase
Expr.str.zfill
Expr.str.parse_int
4 changes: 3 additions & 1 deletion py-polars/docs/source/reference/series/string.rst
Expand Up @@ -22,6 +22,7 @@ The following methods are available under the `Series.str` attribute.
Series.str.extract_all
Series.str.extract_groups
Series.str.find
Series.str.head
Series.str.json_decode
Series.str.json_extract
Series.str.json_path_match
Expand All @@ -33,6 +34,7 @@ The following methods are available under the `Series.str` attribute.
Series.str.n_chars
Series.str.pad_end
Series.str.pad_start
Series.str.parse_int
Series.str.replace
Series.str.replace_all
Series.str.replace_many
Expand All @@ -51,6 +53,7 @@ The following methods are available under the `Series.str` attribute.
Series.str.strip_prefix
Series.str.strip_suffix
Series.str.strptime
Series.str.tail
Series.str.to_date
Series.str.to_datetime
Series.str.to_decimal
Expand All @@ -60,4 +63,3 @@ The following methods are available under the `Series.str` attribute.
Series.str.to_titlecase
Series.str.to_uppercase
Series.str.zfill
Series.str.parse_int

0 comments on commit 429d3dd

Please sign in to comment.