From 429d3ddc16b3ef6fce7b6425c8024bfe85d5a880 Mon Sep 17 00:00:00 2001 From: Marshall Date: Sat, 13 Apr 2024 11:02:01 -0400 Subject: [PATCH] feat: Add `str.head` and `str.tail` (#14425) --- .../src/chunked_array/strings/namespace.rs | 23 +++ .../src/chunked_array/strings/substring.rs | 112 +++++++++++++- .../src/dsl/function_expr/strings.rs | 40 ++++- crates/polars-plan/src/dsl/string.rs | 20 +++ .../source/reference/expressions/string.rst | 4 +- .../docs/source/reference/series/string.rst | 4 +- py-polars/polars/expr/string.py | 146 ++++++++++++++++++ py-polars/polars/series/string.py | 120 ++++++++++++++ py-polars/src/expr/general.rs | 9 +- py-polars/src/expr/string.rs | 8 + .../unit/namespaces/string/test_string.py | 122 +++++++++++++++ 11 files changed, 591 insertions(+), 17 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/strings/namespace.rs b/crates/polars-ops/src/chunked_array/strings/namespace.rs index 3b709a7c8de6..637336f00ea2 100644 --- a/crates/polars-ops/src/chunked_array/strings/namespace.rs +++ b/crates/polars-ops/src/chunked_array/strings/namespace.rs @@ -612,6 +612,29 @@ pub trait StringNameSpaceImpl: AsString { Ok(substring::substring(ca, offset.i64()?, length.u64()?)) } + + /// Slice the first `n` values of the string. + /// + /// Determines a substring starting at the beginning of the string up to offset `n` of each + /// element in `array`. `n` can be negative, in which case the slice ends `n` characters from + /// the end of the string. + fn str_head(&self, n: &Series) -> PolarsResult { + let ca = self.as_string(); + let n = n.strict_cast(&DataType::Int64)?; + + Ok(substring::head(ca, n.i64()?)) + } + + /// Slice the last `n` values of the string. + /// + /// Determines a substring starting at offset `n` of each element in `array`. `n` can be + /// negative, in which case the slice begins `n` characters from the start of the string. + fn str_tail(&self, n: &Series) -> PolarsResult { + let ca = self.as_string(); + let n = n.strict_cast(&DataType::Int64)?; + + Ok(substring::tail(ca, n.i64()?)) + } } impl StringNameSpaceImpl for StringChunked {} diff --git a/crates/polars-ops/src/chunked_array/strings/substring.rs b/crates/polars-ops/src/chunked_array/strings/substring.rs index b2e69b57317f..43bddd8d10ca 100644 --- a/crates/polars-ops/src/chunked_array/strings/substring.rs +++ b/crates/polars-ops/src/chunked_array/strings/substring.rs @@ -1,6 +1,72 @@ use polars_core::prelude::arity::{binary_elementwise, ternary_elementwise, unary_elementwise}; use polars_core::prelude::{Int64Chunked, StringChunked, UInt64Chunked}; +fn head_binary(opt_str_val: Option<&str>, opt_n: Option) -> Option<&str> { + if let (Some(str_val), Some(n)) = (opt_str_val, opt_n) { + // `max_len` is guaranteed to be at least the total number of characters. + let max_len = str_val.len(); + if n == 0 { + Some("") + } else { + let end_idx = if n > 0 { + if n as usize >= max_len { + return opt_str_val; + } + // End after the nth codepoint. + str_val + .char_indices() + .nth(n as usize) + .map(|(idx, _)| idx) + .unwrap_or(max_len) + } else { + // End after the nth codepoint from the end. + str_val + .char_indices() + .rev() + .nth((-n - 1) as usize) + .map(|(idx, _)| idx) + .unwrap_or(0) + }; + Some(&str_val[..end_idx]) + } + } else { + None + } +} + +fn tail_binary(opt_str_val: Option<&str>, opt_n: Option) -> Option<&str> { + if let (Some(str_val), Some(n)) = (opt_str_val, opt_n) { + // `max_len` is guaranteed to be at least the total number of characters. + let max_len = str_val.len(); + if n == 0 { + Some("") + } else { + let start_idx = if n > 0 { + if n as usize >= max_len { + return opt_str_val; + } + // Start from nth codepoint from the end + str_val + .char_indices() + .rev() + .nth((n - 1) as usize) + .map(|(idx, _)| idx) + .unwrap_or(0) + } else { + // Start after the nth codepoint + str_val + .char_indices() + .nth((-n) as usize) + .map(|(idx, _)| idx) + .unwrap_or(max_len) + }; + Some(&str_val[start_idx..]) + } + } else { + None + } +} + fn substring_ternary( opt_str_val: Option<&str>, opt_offset: Option, @@ -54,30 +120,30 @@ pub(super) fn substring( ) -> StringChunked { match (ca.len(), offset.len(), length.len()) { (1, 1, _) => { - // SAFETY: index `0` is in bound. + // SAFETY: `ca` was verified to have least 1 element. let str_val = unsafe { ca.get_unchecked(0) }; - // SAFETY: index `0` is in bound. + // SAFETY: `offset` was verified to have at least 1 element. let offset = unsafe { offset.get_unchecked(0) }; unary_elementwise(length, |length| substring_ternary(str_val, offset, length)) .with_name(ca.name()) }, (_, 1, 1) => { - // SAFETY: index `0` is in bound. + // SAFETY: `offset` was verified to have at least 1 element. let offset = unsafe { offset.get_unchecked(0) }; - // SAFETY: index `0` is in bound. + // SAFETY: `length` was verified to have at least 1 element. let length = unsafe { length.get_unchecked(0) }; unary_elementwise(ca, |str_val| substring_ternary(str_val, offset, length)) }, (1, _, 1) => { - // SAFETY: index `0` is in bound. + // SAFETY: `ca` was verified to have at least 1 element. let str_val = unsafe { ca.get_unchecked(0) }; - // SAFETY: index `0` is in bound. + // SAFETY: `length` was verified to have at least 1 element. let length = unsafe { length.get_unchecked(0) }; unary_elementwise(offset, |offset| substring_ternary(str_val, offset, length)) .with_name(ca.name()) }, (1, len_b, len_c) if len_b == len_c => { - // SAFETY: index `0` is in bound. + // SAFETY: `ca` was verified to have at least 1 element. let str_val = unsafe { ca.get_unchecked(0) }; binary_elementwise(offset, length, |offset, length| { substring_ternary(str_val, offset, length) @@ -112,3 +178,35 @@ pub(super) fn substring( _ => ternary_elementwise(ca, offset, length, substring_ternary), } } + +pub(super) fn head(ca: &StringChunked, n: &Int64Chunked) -> StringChunked { + match (ca.len(), n.len()) { + (_, 1) => { + // SAFETY: `n` was verified to have at least 1 element. + let n = unsafe { n.get_unchecked(0) }; + unary_elementwise(ca, |str_val| head_binary(str_val, n)).with_name(ca.name()) + }, + (1, _) => { + // SAFETY: `ca` was verified to have at least 1 element. + let str_val = unsafe { ca.get_unchecked(0) }; + unary_elementwise(n, |n| head_binary(str_val, n)).with_name(ca.name()) + }, + _ => binary_elementwise(ca, n, head_binary), + } +} + +pub(super) fn tail(ca: &StringChunked, n: &Int64Chunked) -> StringChunked { + match (ca.len(), n.len()) { + (_, 1) => { + // SAFETY: `n` was verified to have at least 1 element. + let n = unsafe { n.get_unchecked(0) }; + unary_elementwise(ca, |str_val| tail_binary(str_val, n)).with_name(ca.name()) + }, + (1, _) => { + // SAFETY: `ca` was verified to have at least 1 element. + let str_val = unsafe { ca.get_unchecked(0) }; + unary_elementwise(n, |n| tail_binary(str_val, n)).with_name(ca.name()) + }, + _ => binary_elementwise(ca, n, tail_binary), + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/strings.rs b/crates/polars-plan/src/dsl/function_expr/strings.rs index 62f81865c22c..4c1c77304bea 100644 --- a/crates/polars-plan/src/dsl/function_expr/strings.rs +++ b/crates/polars-plan/src/dsl/function_expr/strings.rs @@ -82,6 +82,8 @@ pub enum StringFunction { fill_char: char, }, Slice, + Head, + Tail, #[cfg(feature = "string_encoding")] HexEncode, #[cfg(feature = "binary_encoding")] @@ -166,7 +168,7 @@ impl StringFunction { #[cfg(feature = "binary_encoding")] Base64Decode(_) => mapper.with_dtype(DataType::Binary), Uppercase | Lowercase | StripChars | StripCharsStart | StripCharsEnd | StripPrefix - | StripSuffix | Slice => mapper.with_same_dtype(), + | StripSuffix | Slice | Head | Tail => mapper.with_same_dtype(), #[cfg(feature = "string_pad")] PadStart { .. } | PadEnd { .. } | ZFill => mapper.with_same_dtype(), #[cfg(feature = "dtype-struct")] @@ -210,6 +212,8 @@ impl Display for StringFunction { ToInteger { .. } => "to_integer", #[cfg(feature = "regex")] Find { .. } => "find", + Head { .. } => "head", + Tail { .. } => "tail", #[cfg(feature = "extract_jsonpath")] JsonDecode { .. } => "json_decode", LenBytes => "len_bytes", @@ -345,6 +349,8 @@ impl From for SpecialEq> { #[cfg(feature = "string_to_integer")] ToInteger(strict) => map_as_slice!(strings::to_integer, strict), Slice => map_as_slice!(strings::str_slice), + Head => map_as_slice!(strings::str_head), + Tail => map_as_slice!(strings::str_tail), #[cfg(feature = "string_encoding")] HexEncode => map!(strings::hex_encode), #[cfg(feature = "binary_encoding")] @@ -894,7 +900,8 @@ pub(super) fn to_integer(s: &[Series], strict: bool) -> PolarsResult { ca.to_integer(base.u32()?, strict) .map(|ok| ok.into_series()) } -pub(super) fn str_slice(s: &[Series]) -> PolarsResult { + +fn _ensure_lengths(s: &[Series]) -> bool { // Calculate the post-broadcast length and ensure everything is consistent. let len = s .iter() @@ -902,9 +909,14 @@ pub(super) fn str_slice(s: &[Series]) -> PolarsResult { .filter(|l| *l != 1) .max() .unwrap_or(1); + s.iter() + .all(|series| series.len() == 1 || series.len() == len) +} + +pub(super) fn str_slice(s: &[Series]) -> PolarsResult { polars_ensure!( - s.iter().all(|series| series.len() == 1 || series.len() == len), - ComputeError: "all series in `str_slice` should have equal or unit length" + _ensure_lengths(s), + ComputeError: "all series in `str_slice` should have equal or unit length", ); let ca = s[0].str()?; let offset = &s[1]; @@ -912,6 +924,26 @@ pub(super) fn str_slice(s: &[Series]) -> PolarsResult { Ok(ca.str_slice(offset, length)?.into_series()) } +pub(super) fn str_head(s: &[Series]) -> PolarsResult { + polars_ensure!( + _ensure_lengths(s), + ComputeError: "all series in `str_head` should have equal or unit length", + ); + let ca = s[0].str()?; + let n = &s[1]; + Ok(ca.str_head(n)?.into_series()) +} + +pub(super) fn str_tail(s: &[Series]) -> PolarsResult { + polars_ensure!( + _ensure_lengths(s), + ComputeError: "all series in `str_tail` should have equal or unit length", + ); + let ca = s[0].str()?; + let n = &s[1]; + Ok(ca.str_tail(n)?.into_series()) +} + #[cfg(feature = "string_encoding")] pub(super) fn hex_encode(s: &Series) -> PolarsResult { Ok(s.str()?.hex_encode().into_series()) diff --git a/crates/polars-plan/src/dsl/string.rs b/crates/polars-plan/src/dsl/string.rs index 15c3db4cc463..e462caa71bd7 100644 --- a/crates/polars-plan/src/dsl/string.rs +++ b/crates/polars-plan/src/dsl/string.rs @@ -527,6 +527,26 @@ impl StringNameSpace { ) } + /// Take the first `n` characters of the string values. + pub fn head(self, n: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::StringExpr(StringFunction::Head), + &[n], + false, + false, + ) + } + + /// Take the last `n` characters of the string values. + pub fn tail(self, n: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::StringExpr(StringFunction::Tail), + &[n], + false, + false, + ) + } + pub fn explode(self) -> Expr { self.0 .apply_private(FunctionExpr::StringExpr(StringFunction::Explode)) diff --git a/py-polars/docs/source/reference/expressions/string.rst b/py-polars/docs/source/reference/expressions/string.rst index 831edce162b0..1b4159814b81 100644 --- a/py-polars/docs/source/reference/expressions/string.rst +++ b/py-polars/docs/source/reference/expressions/string.rst @@ -22,6 +22,7 @@ The following methods are available under the `expr.str` attribute. Expr.str.extract_all Expr.str.extract_groups Expr.str.find + Expr.str.head Expr.str.json_decode Expr.str.json_extract Expr.str.json_path_match @@ -33,6 +34,7 @@ The following methods are available under the `expr.str` attribute. Expr.str.n_chars Expr.str.pad_end Expr.str.pad_start + Expr.str.parse_int Expr.str.replace Expr.str.replace_all Expr.str.replace_many @@ -51,6 +53,7 @@ The following methods are available under the `expr.str` attribute. Expr.str.strip_prefix Expr.str.strip_suffix Expr.str.strptime + Expr.str.tail Expr.str.to_date Expr.str.to_datetime Expr.str.to_decimal @@ -60,4 +63,3 @@ The following methods are available under the `expr.str` attribute. Expr.str.to_time Expr.str.to_uppercase Expr.str.zfill - Expr.str.parse_int diff --git a/py-polars/docs/source/reference/series/string.rst b/py-polars/docs/source/reference/series/string.rst index fbbe261e92f7..93bc038619bb 100644 --- a/py-polars/docs/source/reference/series/string.rst +++ b/py-polars/docs/source/reference/series/string.rst @@ -22,6 +22,7 @@ The following methods are available under the `Series.str` attribute. Series.str.extract_all Series.str.extract_groups Series.str.find + Series.str.head Series.str.json_decode Series.str.json_extract Series.str.json_path_match @@ -33,6 +34,7 @@ The following methods are available under the `Series.str` attribute. Series.str.n_chars Series.str.pad_end Series.str.pad_start + Series.str.parse_int Series.str.replace Series.str.replace_all Series.str.replace_many @@ -51,6 +53,7 @@ The following methods are available under the `Series.str` attribute. Series.str.strip_prefix Series.str.strip_suffix Series.str.strptime + Series.str.tail Series.str.to_date Series.str.to_datetime Series.str.to_decimal @@ -60,4 +63,3 @@ The following methods are available under the `Series.str` attribute. Series.str.to_titlecase Series.str.to_uppercase Series.str.zfill - Series.str.parse_int diff --git a/py-polars/polars/expr/string.py b/py-polars/polars/expr/string.py index 034148d1d4fd..8de24511c00a 100644 --- a/py-polars/polars/expr/string.py +++ b/py-polars/polars/expr/string.py @@ -2205,6 +2205,152 @@ def slice( length = parse_as_expression(length) return wrap_expr(self._pyexpr.str_slice(offset, length)) + def head(self, n: int | IntoExprColumn) -> Expr: + """ + Return the first n characters of each string in a String Series. + + Parameters + ---------- + n + Length of the slice (integer or expression). Negative indexing is supported; + see note (2) below. + + Returns + ------- + Expr + Expression of data type :class:`String`. + + Notes + ----- + 1) The `n` input is defined in terms of the number of characters in the (UTF8) + string. A character is defined as a `Unicode scalar value`_. A single + character is represented by a single byte when working with ASCII text, and a + maximum of 4 bytes otherwise. + + .. _Unicode scalar value: https://www.unicode.org/glossary/#unicode_scalar_value + + 2) When the `n` input is negative, `head` returns characters up to the `n`th + from the end of the string. For example, if `n = -3`, then all characters + except the last three are returned. + + 3) If the length of the string has fewer than `n` characters, the full string is + returned. + + Examples + -------- + Return up to the first 5 characters: + + >>> df = pl.DataFrame({"s": ["pear", None, "papaya", "dragonfruit"]}) + >>> df.with_columns(pl.col("s").str.head(5).alias("s_head_5")) + shape: (4, 2) + ┌─────────────┬──────────┐ + │ s ┆ s_head_5 │ + │ --- ┆ --- │ + │ str ┆ str │ + ╞═════════════╪══════════╡ + │ pear ┆ pear │ + │ null ┆ null │ + │ papaya ┆ papay │ + │ dragonfruit ┆ drago │ + └─────────────┴──────────┘ + + Return characters determined by column `n`: + + >>> df = pl.DataFrame( + ... { + ... "s": ["pear", None, "papaya", "dragonfruit"], + ... "n": [3, 4, -2, -5], + ... } + ... ) + >>> df.with_columns(pl.col("s").str.head("n").alias("s_head_n")) + shape: (4, 3) + ┌─────────────┬─────┬──────────┐ + │ s ┆ n ┆ s_head_n │ + │ --- ┆ --- ┆ --- │ + │ str ┆ i64 ┆ str │ + ╞═════════════╪═════╪══════════╡ + │ pear ┆ 3 ┆ pea │ + │ null ┆ 4 ┆ null │ + │ papaya ┆ -2 ┆ papa │ + │ dragonfruit ┆ -5 ┆ dragon │ + └─────────────┴─────┴──────────┘ + """ + n = parse_as_expression(n) + return wrap_expr(self._pyexpr.str_head(n)) + + def tail(self, n: int | IntoExprColumn) -> Expr: + """ + Return the last n characters of each string in a String Series. + + Parameters + ---------- + n + Length of the slice (integer or expression). Negative indexing is supported; + see note (2) below. + + Returns + ------- + Expr + Expression of data type :class:`String`. + + Notes + ----- + 1) The `n` input is defined in terms of the number of characters in the (UTF8) + string. A character is defined as a `Unicode scalar value`_. A single + character is represented by a single byte when working with ASCII text, and a + maximum of 4 bytes otherwise. + + .. _Unicode scalar value: https://www.unicode.org/glossary/#unicode_scalar_value + + 2) When the `n` input is negative, `tail` returns characters starting from the + `n`th from the beginning of the string. For example, if `n = -3`, then all + characters except the first three are returned. + + 3) If the length of the string has fewer than `n` characters, the full string is + returned. + + Examples + -------- + Return up to the last 5 characters: + + >>> df = pl.DataFrame({"s": ["pear", None, "papaya", "dragonfruit"]}) + >>> df.with_columns(pl.col("s").str.tail(5).alias("s_tail_5")) + shape: (4, 2) + ┌─────────────┬──────────┐ + │ s ┆ s_tail_5 │ + │ --- ┆ --- │ + │ str ┆ str │ + ╞═════════════╪══════════╡ + │ pear ┆ pear │ + │ null ┆ null │ + │ papaya ┆ apaya │ + │ dragonfruit ┆ fruit │ + └─────────────┴──────────┘ + + Return characters determined by column `n`: + + >>> df = pl.DataFrame( + ... { + ... "s": ["pear", None, "papaya", "dragonfruit"], + ... "n": [3, 4, -2, -5], + ... } + ... ) + >>> df.with_columns(pl.col("s").str.tail("n").alias("s_tail_n")) + shape: (4, 3) + ┌─────────────┬─────┬──────────┐ + │ s ┆ n ┆ s_tail_n │ + │ --- ┆ --- ┆ --- │ + │ str ┆ i64 ┆ str │ + ╞═════════════╪═════╪══════════╡ + │ pear ┆ 3 ┆ ear │ + │ null ┆ 4 ┆ null │ + │ papaya ┆ -2 ┆ paya │ + │ dragonfruit ┆ -5 ┆ nfruit │ + └─────────────┴─────┴──────────┘ + """ + n = parse_as_expression(n) + return wrap_expr(self._pyexpr.str_tail(n)) + def explode(self) -> Expr: """ Returns a column with a separate row for every string character. diff --git a/py-polars/polars/series/string.py b/py-polars/polars/series/string.py index c14a32c82e31..c21217477dfb 100644 --- a/py-polars/polars/series/string.py +++ b/py-polars/polars/series/string.py @@ -1658,6 +1658,126 @@ def slice( ] """ + def head(self, n: int | IntoExprColumn) -> Series: + """ + Return the first n characters of each string in a String Series. + + Parameters + ---------- + n + Length of the slice (integer or expression). Negative indexing is supported; + see note (2) below. + + Returns + ------- + Series + Series of data type :class:`String`. + + Notes + ----- + 1) The `n` input is defined in terms of the number of characters in the (UTF8) + string. A character is defined as a `Unicode scalar value`_. A single + character is represented by a single byte when working with ASCII text, and a + maximum of 4 bytes otherwise. + + .. _Unicode scalar value: https://www.unicode.org/glossary/#unicode_scalar_value + + 2) When `n` is negative, `head` returns characters up to the `n`th from the end + of the string. For example, if `n = -3`, then all characters except the last + three are returned. + + 3) If the length of the string has fewer than `n` characters, the full string is + returned. + + Examples + -------- + Return up to the first 5 characters. + + >>> s = pl.Series(["pear", None, "papaya", "dragonfruit"]) + >>> s.str.head(5) + shape: (4,) + Series: '' [str] + [ + "pear" + null + "papay" + "drago" + ] + + Return up to the 3rd character from the end. + + >>> s = pl.Series(["pear", None, "papaya", "dragonfruit"]) + >>> s.str.head(-3) + shape: (4,) + Series: '' [str] + [ + "p" + null + "pap" + "dragonfr" + ] + """ + + def tail(self, n: int | IntoExprColumn) -> Series: + """ + Return the last n characters of each string in a String Series. + + Parameters + ---------- + n + Length of the slice (integer or expression). Negative indexing is supported; + see note (2) below. + + Returns + ------- + Series + Series of data type :class:`String`. + + Notes + ----- + 1) The `n` input is defined in terms of the number of characters in the (UTF8) + string. A character is defined as a `Unicode scalar value`_. A single + character is represented by a single byte when working with ASCII text, and a + maximum of 4 bytes otherwise. + + .. _Unicode scalar value: https://www.unicode.org/glossary/#unicode_scalar_value + + 2) When `n` is negative, `tail` returns characters starting from the `n`th from + the beginning of the string. For example, if `n = -3`, then all characters + except the first three are returned. + + 3) If the length of the string has fewer than `n` characters, the full string is + returned. + + Examples + -------- + Return up to the last 5 characters: + + >>> s = pl.Series(["pear", None, "papaya", "dragonfruit"]) + >>> s.str.tail(5) + shape: (4,) + Series: '' [str] + [ + "pear" + null + "apaya" + "fruit" + ] + + Return from the 3rd character to the end: + + >>> s = pl.Series(["pear", None, "papaya", "dragonfruit"]) + >>> s.str.tail(-3) + shape: (4,) + Series: '' [str] + [ + "r" + null + "aya" + "gonfruit" + ] + """ + def explode(self) -> Series: """ Returns a column with a separate row for every string character. diff --git a/py-polars/src/expr/general.rs b/py-polars/src/expr/general.rs index 45eb9ef616e5..e39d935fa569 100644 --- a/py-polars/src/expr/general.rs +++ b/py-polars/src/expr/general.rs @@ -439,16 +439,17 @@ impl PyExpr { fn gather_every(&self, n: usize, offset: usize) -> Self { self.inner.clone().gather_every(n, offset).into() } - fn tail(&self, n: usize) -> Self { - self.inner.clone().tail(Some(n)).into() + + fn slice(&self, offset: Self, length: Self) -> Self { + self.inner.clone().slice(offset.inner, length.inner).into() } fn head(&self, n: usize) -> Self { self.inner.clone().head(Some(n)).into() } - fn slice(&self, offset: Self, length: Self) -> Self { - self.inner.clone().slice(offset.inner, length.inner).into() + fn tail(&self, n: usize) -> Self { + self.inner.clone().tail(Some(n)).into() } fn append(&self, other: Self, upcast: bool) -> Self { diff --git a/py-polars/src/expr/string.rs b/py-polars/src/expr/string.rs index 5f870c204994..687643f71fa9 100644 --- a/py-polars/src/expr/string.rs +++ b/py-polars/src/expr/string.rs @@ -102,6 +102,14 @@ impl PyExpr { .into() } + fn str_head(&self, n: Self) -> Self { + self.inner.clone().str().head(n.inner).into() + } + + fn str_tail(&self, n: Self) -> Self { + self.inner.clone().str().tail(n.inner).into() + } + fn str_explode(&self) -> Self { self.inner.clone().str().explode().into() } diff --git a/py-polars/tests/unit/namespaces/string/test_string.py b/py-polars/tests/unit/namespaces/string/test_string.py index ce7e2e4a54b0..07765257a133 100644 --- a/py-polars/tests/unit/namespaces/string/test_string.py +++ b/py-polars/tests/unit/namespaces/string/test_string.py @@ -46,6 +46,128 @@ def test_str_slice_expr() -> None: df.select(pl.col("a").str.slice(0, -1)) +@pytest.mark.parametrize( + ("input", "n", "output"), + [ + (["012345", "", None], 0, ["", "", None]), + (["012345", "", None], 2, ["01", "", None]), + (["012345", "", None], -2, ["0123", "", None]), + (["012345", "", None], 100, ["012345", "", None]), + (["012345", "", None], -100, ["", "", None]), + ], +) +def test_str_head(input: list[str], n: int, output: list[str]) -> None: + assert pl.Series(input).str.head(n).to_list() == output + + +@pytest.mark.parametrize( + ("input", "n", "output"), + [ + ("你好世界", 0, ""), + ("你好世界", 2, "你好"), + ("你好世界", 999, "你好世界"), + ("你好世界", -1, "你好世"), + ("你好世界", -2, "你好"), + ("你好世界", -999, ""), + ], +) +def test_str_head_codepoints(input: str, n: int, output: str) -> None: + assert pl.Series([input]).str.head(n).to_list() == [output] + + +def test_str_head_expr() -> None: + s = "012345" + df = pl.DataFrame( + {"a": [s, s, s, s, s, s, "", None], "n": [0, 2, -2, 100, -100, None, 3, -2]} + ) + out = df.select( + n_expr=pl.col("a").str.head("n"), + n_pos2=pl.col("a").str.head(2), + n_neg2=pl.col("a").str.head(-2), + n_pos100=pl.col("a").str.head(100), + n_pos_neg100=pl.col("a").str.head(-100), + n_pos_0=pl.col("a").str.head(0), + str_lit=pl.col("a").str.head(pl.lit(2)), + lit_expr=pl.lit(s).str.head("n"), + lit_n=pl.lit(s).str.head(2), + ) + expected = pl.DataFrame( + { + "n_expr": ["", "01", "0123", "012345", "", None, "", None], + "n_pos2": ["01", "01", "01", "01", "01", "01", "", None], + "n_neg2": ["0123", "0123", "0123", "0123", "0123", "0123", "", None], + "n_pos100": [s, s, s, s, s, s, "", None], + "n_pos_neg100": ["", "", "", "", "", "", "", None], + "n_pos_0": ["", "", "", "", "", "", "", None], + "str_lit": ["01", "01", "01", "01", "01", "01", "", None], + "lit_expr": ["", "01", "0123", "012345", "", None, "012", "0123"], + "lit_n": ["01", "01", "01", "01", "01", "01", "01", "01"], + } + ) + assert_frame_equal(out, expected) + + +@pytest.mark.parametrize( + ("input", "n", "output"), + [ + (["012345", "", None], 0, ["", "", None]), + (["012345", "", None], 2, ["45", "", None]), + (["012345", "", None], -2, ["2345", "", None]), + (["012345", "", None], 100, ["012345", "", None]), + (["012345", "", None], -100, ["", "", None]), + ], +) +def test_str_tail(input: list[str], n: int, output: list[str]) -> None: + assert pl.Series(input).str.tail(n).to_list() == output + + +@pytest.mark.parametrize( + ("input", "n", "output"), + [ + ("你好世界", 0, ""), + ("你好世界", 2, "世界"), + ("你好世界", 999, "你好世界"), + ("你好世界", -1, "好世界"), + ("你好世界", -2, "世界"), + ("你好世界", -999, ""), + ], +) +def test_str_tail_codepoints(input: str, n: int, output: str) -> None: + assert pl.Series([input]).str.tail(n).to_list() == [output] + + +def test_str_tail_expr() -> None: + s = "012345" + df = pl.DataFrame( + {"a": [s, s, s, s, s, s, "", None], "n": [0, 2, -2, 100, -100, None, 3, -2]} + ) + out = df.select( + n_expr=pl.col("a").str.tail("n"), + n_pos2=pl.col("a").str.tail(2), + n_neg2=pl.col("a").str.tail(-2), + n_pos100=pl.col("a").str.tail(100), + n_pos_neg100=pl.col("a").str.tail(-100), + n_pos_0=pl.col("a").str.tail(0), + str_lit=pl.col("a").str.tail(pl.lit(2)), + lit_expr=pl.lit(s).str.tail("n"), + lit_n=pl.lit(s).str.tail(2), + ) + expected = pl.DataFrame( + { + "n_expr": ["", "45", "2345", "012345", "", None, "", None], + "n_pos2": ["45", "45", "45", "45", "45", "45", "", None], + "n_neg2": ["2345", "2345", "2345", "2345", "2345", "2345", "", None], + "n_pos100": [s, s, s, s, s, s, "", None], + "n_pos_neg100": ["", "", "", "", "", "", "", None], + "n_pos_0": ["", "", "", "", "", "", "", None], + "str_lit": ["45", "45", "45", "45", "45", "45", "", None], + "lit_expr": ["", "45", "2345", "012345", "", None, "345", "2345"], + "lit_n": ["45", "45", "45", "45", "45", "45", "45", "45"], + } + ) + assert_frame_equal(out, expected) + + def test_str_slice_multibyte() -> None: ref = "你好世界" s = pl.Series([ref])