Skip to content

Commit

Permalink
Implement $str$find() (#985)
Browse files Browse the repository at this point in the history
  • Loading branch information
etiennebacher committed Mar 30, 2024
1 parent 1f7c927 commit 5757bde
Show file tree
Hide file tree
Showing 8 changed files with 165 additions and 34 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
- Export the `Duration` datatype (#955).
- New functions `pl$int_range()` and `pl$int_ranges()` (#968).
- New string method `$str$extract_groups()` (#979).
- New string method `$str$find()` (#985).

### Bug fixes

Expand Down
43 changes: 37 additions & 6 deletions R/expr__string.R
Original file line number Diff line number Diff line change
Expand Up @@ -429,24 +429,31 @@ ExprStr_pad_start = function(width, fillchar = " ") {
}


# TODO: Add ExprStr_find to seealso
#' Check if string contains a substring that matches a pattern
#'
#' @details To modify regular expression behaviour (such as case-sensitivity) with flags,
#' use the inline `(?iLmsuxU)` syntax. See the regex crate’s section on
#' [grouping and flags](https://docs.rs/regex/latest/regex/#grouping-and-flags)
#' @details To modify regular expression behaviour (such as case-sensitivity)
#' with flags, use the inline `(?iLmsuxU)` syntax. See the regex crate’s section
#' on [grouping and flags](https://docs.rs/regex/latest/regex/#grouping-and-flags)
#' for additional information about the use of inline expression modifiers.
#'
#' @param pattern A character or something that can be coerced to a string [Expr][Expr_class]
#' of a valid regex pattern, compatible with the [regex crate](https://docs.rs/regex/latest/regex/).
#' @param ... Ignored.
#' @param literal Logical. If `TRUE`, treat `pattern` as a literal string,
#' not as a regular expression. Defaults to `FALSE`.
#' @param strict Logical. If `TRUE` (default), raise an error if the underlying pattern is
#' not a valid regex, otherwise mask out with a null value.
#'
#' @return [Expr][Expr_class] of Boolean data type
#' @seealso
#' - [`<Expr>$str$start_with()`][ExprStr_starts_with]: Check if string values start with a substring.
#' - [`<Expr>$str$ends_with()`][ExprStr_ends_with]: Check if string values end with a substring.
#' - [`$str$starts_with()`][ExprStr_starts_with]: Check if string values
#' start with a substring.
#' - [`$str$ends_with()`][ExprStr_ends_with]: Check if string values end
#' with a substring.
#' - [`$str$find()`][ExprStr_find]: Return the index position of the first
#' substring matching a pattern.
#'
#'
#' @examples
#' # The inline `(?i)` syntax example
#' pl$DataFrame(s = c("AAA", "aAa", "aaa"))$with_columns(
Expand Down Expand Up @@ -966,3 +973,27 @@ ExprStr_extract_groups = function(pattern) {
.pr$Expr$str_extract_groups(self, pattern) |>
unwrap("in str$extract_groups():")
}

#' Return the index position of the first substring matching a pattern
#'
#' For every string value, look for `pattern` and give back the position at
#' which the first match starts. Positions are zero-based, and values without
#' any match yield a null.
#'
#' @inherit ExprStr_contains params details
#' @param literal Logical. If `TRUE`, treat `pattern` as a literal string, not
#' as a regular expression. Defaults to `FALSE`.
#'
#' @return An Expr of data type UInt32
#'
#' @seealso
#' - [`$str$starts_with()`][ExprStr_starts_with]: Check if string values
#' start with a substring.
#' - [`$str$ends_with()`][ExprStr_ends_with]: Check if string values end
#' with a substring.
#' - [`$str$contains()`][ExprStr_contains]: Check if string contains a substring
#' that matches a pattern.
#'
#' @examples
#' pl$DataFrame(s = c("AAA", "aAa", "aaa"))$with_columns(
#'   default_match = pl$col("s")$str$find("Aa"),
#'   insensitive_match = pl$col("s")$str$find("(?i)Aa")
#' )
ExprStr_find = function(pattern, ..., literal = FALSE, strict = TRUE) {
  # Delegate to the Rust binding, then turn its result into a value or an R
  # error tagged with the calling method.
  result = .pr$Expr$str_find(self, pattern, literal, strict)
  unwrap(result, "in str$find():")
}
2 changes: 2 additions & 0 deletions R/extendr-wrappers.R
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,8 @@ RPolarsExpr$str_contains_any <- function(patterns, ascii_case_insensitive) .Call

# NOTE(review): judging by the file name (extendr-wrappers.R) these one-line
# shims are generated from the Rust `impl RPolarsExpr` block -- confirm, and
# regenerate rather than hand-edit if so.
# Each wrapper forwards `self` plus its arguments, unchanged, to the compiled
# routine of the same name via `.Call()`.
RPolarsExpr$str_replace_many <- function(patterns, replace_with, ascii_case_insensitive) .Call(wrap__RPolarsExpr__str_replace_many, self, patterns, replace_with, ascii_case_insensitive)

# Binding used by `<Expr>$str$find()`: `pat` is the pattern, `literal` and
# `strict` are passed through as given by the R-level method.
RPolarsExpr$str_find <- function(pat, literal, strict) .Call(wrap__RPolarsExpr__str_find, self, pat, literal, strict)

RPolarsExpr$bin_contains <- function(lit) .Call(wrap__RPolarsExpr__bin_contains, self, lit)

RPolarsExpr$bin_starts_with <- function(sub) .Call(wrap__RPolarsExpr__bin_starts_with, self, sub)
Expand Down
14 changes: 9 additions & 5 deletions man/ExprStr_contains.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

48 changes: 48 additions & 0 deletions man/ExprStr_find.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions src/rust/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2349,6 +2349,15 @@ impl RPolarsExpr {
.into())
}

// Build an expression that locates the first match of `pat` in each string.
//
// * `pat`     - converted to a Polars expression holding the pattern.
// * `literal` - optional bool; only an explicit `true` selects literal
//               matching, anything else falls back to regex matching.
// * `strict`  - bool forwarded to the regex search; it is not used on the
//               literal path, but it is still converted (and validated) first.
//
// Conversion failures of any argument are propagated to the caller.
pub fn str_find(&self, pat: Robj, literal: Robj, strict: Robj) -> RResult<Self> {
    // Keep the conversions in argument order so the first invalid argument is
    // the one that gets reported.
    let pattern = robj_to!(PLExpr, pat)?;
    let use_literal = robj_to!(Option, bool, literal)?.unwrap_or(false);
    let strict = robj_to!(bool, strict)?;

    let strings = self.0.clone().str();
    let located = if use_literal {
        strings.find_literal(pattern)
    } else {
        strings.find(pattern, strict)
    };
    Ok(located.into())
}
//binary methods
pub fn bin_contains(&self, lit: Robj) -> RResult<Self> {
Ok(self
Expand Down
47 changes: 24 additions & 23 deletions tests/testthat/_snaps/after-wrappers.md
Original file line number Diff line number Diff line change
Expand Up @@ -406,29 +406,30 @@
[267] "str_contains_any" "str_count_matches"
[269] "str_ends_with" "str_explode"
[271] "str_extract" "str_extract_all"
[273] "str_extract_groups" "str_hex_decode"
[275] "str_hex_encode" "str_json_decode"
[277] "str_json_path_match" "str_len_bytes"
[279] "str_len_chars" "str_pad_end"
[281] "str_pad_start" "str_parse_int"
[283] "str_replace" "str_replace_all"
[285] "str_replace_many" "str_reverse"
[287] "str_slice" "str_split"
[289] "str_split_exact" "str_splitn"
[291] "str_starts_with" "str_strip_chars"
[293] "str_strip_chars_end" "str_strip_chars_start"
[295] "str_to_date" "str_to_datetime"
[297] "str_to_lowercase" "str_to_time"
[299] "str_to_titlecase" "str_to_uppercase"
[301] "str_zfill" "struct_field_by_name"
[303] "struct_rename_fields" "sub"
[305] "sum" "tail"
[307] "tan" "tanh"
[309] "timestamp" "to_physical"
[311] "top_k" "unique"
[313] "unique_counts" "unique_stable"
[315] "upper_bound" "value_counts"
[317] "var" "xor"
[273] "str_extract_groups" "str_find"
[275] "str_hex_decode" "str_hex_encode"
[277] "str_json_decode" "str_json_path_match"
[279] "str_len_bytes" "str_len_chars"
[281] "str_pad_end" "str_pad_start"
[283] "str_parse_int" "str_replace"
[285] "str_replace_all" "str_replace_many"
[287] "str_reverse" "str_slice"
[289] "str_split" "str_split_exact"
[291] "str_splitn" "str_starts_with"
[293] "str_strip_chars" "str_strip_chars_end"
[295] "str_strip_chars_start" "str_to_date"
[297] "str_to_datetime" "str_to_lowercase"
[299] "str_to_time" "str_to_titlecase"
[301] "str_to_uppercase" "str_zfill"
[303] "struct_field_by_name" "struct_rename_fields"
[305] "sub" "sum"
[307] "tail" "tan"
[309] "tanh" "timestamp"
[311] "to_physical" "top_k"
[313] "unique" "unique_counts"
[315] "unique_stable" "upper_bound"
[317] "value_counts" "var"
[319] "xor"

# public and private methods of each class When

Expand Down
35 changes: 35 additions & 0 deletions tests/testthat/test-expr_string.R
Original file line number Diff line number Diff line change
Expand Up @@ -832,3 +832,38 @@ test_that("str$extract_groups() works", {
list(url = NULL)
)
})

test_that("str$find() works", {
  # The last row contains the pattern text itself, so it is the only row that
  # can match at offset 0 when the pattern is taken literally.
  df = pl$DataFrame(s = c("AAA", "aAa", "aaa", "(?i)Aa"))
  s = pl$col("s")

  # default (regex) matching, with and without the inline case-insensitive flag
  regex_out = df$select(
    default = s$str$find("Aa"),
    insensitive = s$str$find("(?i)Aa")
  )$to_list()
  expect_identical(
    regex_out,
    list(default = c(NA, 1, NA, 4), insensitive = c(0, 0, 0, 4))
  )

  # arg "literal" works
  literal_out = df$select(
    lit = s$str$find("(?i)Aa", literal = TRUE)
  )$to_list()
  expect_identical(literal_out, list(lit = c(NA, NA, NA, 0)))

  # arg "strict" works: the unterminated flag group is not a valid regex
  bad_pattern = "(?iAa"
  expect_grepl_error(
    df$select(lit = s$str$find(bad_pattern)),
    "unrecognized flag"
  )

  expect_silent(
    df$select(lit = s$str$find(bad_pattern, strict = FALSE))
  )

  # combining "literal" and "strict": a literal pattern is never parsed as a
  # regex, so the invalid syntax must not raise even when strict
  expect_silent(
    df$select(lit = s$str$find(bad_pattern, strict = TRUE, literal = TRUE))
  )
})

0 comments on commit 5757bde

Please sign in to comment.