Skip to content

Commit

Permalink
Implement $rle() and $rle_id() (#648)
Browse files Browse the repository at this point in the history
Co-authored-by: eitsupi <ts1s1andn@gmail.com>
  • Loading branch information
etiennebacher and eitsupi authored Jan 3, 2024
1 parent ff0ee81 commit 3bf4f94
Show file tree
Hide file tree
Showing 9 changed files with 166 additions and 61 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

- New methods `$str$reverse()`, `$str$contains_any()`, and `$str$replace_many()`
(#641).
- New methods `$rle()` and `$rle_id()` (#648).

## polars 0.12.0

Expand Down
29 changes: 29 additions & 0 deletions R/expr__expr.R
Original file line number Diff line number Diff line change
Expand Up @@ -3633,3 +3633,32 @@ Expr_replace = function(old, new, default = NULL, return_dtype = NULL) {
.pr$Expr$replace(self, old, new, default, return_dtype) |>
unwrap("in $replace():")
}

#' Get the lengths of runs of identical values
#'
#' @return Expr
#'
#' @examples
#' df = pl$DataFrame(s = c(1, 1, 2, 1, NA, 1, 3, 3))
#' df$select(pl$col("s")$rle())$unnest("s")
Expr_rle = function() {
.pr$Expr$rle(self) |>
unwrap("in $rle():")
}

#' Map values to run IDs
#'
#' Similar to $rle(), but it maps each value to an ID corresponding to the run
#' into which it falls. This is especially useful when you want to define groups
#' by runs of identical values rather than the values themselves. Note that
#' the ID is 0-indexed.
#'
#' @return Expr
#'
#' @examples
#' df = pl$DataFrame(a = c(1, 2, 1, 1, 1, 4))
#' df$with_columns(a_r = pl$col("a")$rle_id())
Expr_rle_id = function() {
.pr$Expr$rle_id(self) |>
unwrap("in $rle_id():")
}
4 changes: 4 additions & 0 deletions R/extendr-wrappers.R
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,10 @@ RPolarsExpr$peak_max <- function() .Call(wrap__RPolarsExpr__peak_max, self)

RPolarsExpr$replace <- function(old, new, default, return_dtype) .Call(wrap__RPolarsExpr__replace, self, old, new, default, return_dtype)

RPolarsExpr$rle <- function() .Call(wrap__RPolarsExpr__rle, self)

RPolarsExpr$rle_id <- function() .Call(wrap__RPolarsExpr__rle_id, self)

RPolarsExpr$list_lengths <- function() .Call(wrap__RPolarsExpr__list_lengths, self)

RPolarsExpr$list_contains <- function(other) .Call(wrap__RPolarsExpr__list_contains, self, other)
Expand Down
18 changes: 18 additions & 0 deletions man/Expr_rle.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 21 additions & 0 deletions man/Expr_rle_id.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ features = [
"reinterpret",
"repeat_by",
"replace",
"rle",
"rolling_window",
"round_series",
"row_hash",
Expand Down
8 changes: 8 additions & 0 deletions src/rust/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,14 @@ impl RPolarsExpr {
.into())
}

pub fn rle(&self) -> RResult<Self> {
Ok(self.0.clone().rle().into())
}

pub fn rle_id(&self) -> RResult<Self> {
Ok(self.0.clone().rle_id().into())
}

//arr/list methods

fn list_lengths(&self) -> Self {
Expand Down
117 changes: 59 additions & 58 deletions tests/testthat/_snaps/after-wrappers.md
Original file line number Diff line number Diff line change
Expand Up @@ -237,22 +237,22 @@
[124] "rank" "rechunk" "reinterpret"
[127] "rep" "rep_extend" "repeat_by"
[130] "replace" "reshape" "reverse"
[133] "rolling" "rolling_max" "rolling_mean"
[136] "rolling_median" "rolling_min" "rolling_quantile"
[139] "rolling_skew" "rolling_std" "rolling_sum"
[142] "rolling_var" "round" "sample"
[145] "search_sorted" "set_sorted" "shift"
[148] "shift_and_fill" "shrink_dtype" "shuffle"
[151] "sign" "sin" "sinh"
[154] "skew" "slice" "sort"
[157] "sort_by" "sqrt" "std"
[160] "str" "struct" "sub"
[163] "sum" "tail" "tan"
[166] "tanh" "to_physical" "to_r"
[169] "to_series" "to_struct" "top_k"
[172] "unique" "unique_counts" "upper_bound"
[175] "value_counts" "var" "where"
[178] "xor"
[133] "rle" "rle_id" "rolling"
[136] "rolling_max" "rolling_mean" "rolling_median"
[139] "rolling_min" "rolling_quantile" "rolling_skew"
[142] "rolling_std" "rolling_sum" "rolling_var"
[145] "round" "sample" "search_sorted"
[148] "set_sorted" "shift" "shift_and_fill"
[151] "shrink_dtype" "shuffle" "sign"
[154] "sin" "sinh" "skew"
[157] "slice" "sort" "sort_by"
[160] "sqrt" "std" "str"
[163] "struct" "sub" "sum"
[166] "tail" "tan" "tanh"
[169] "to_physical" "to_r" "to_series"
[172] "to_struct" "top_k" "unique"
[175] "unique_counts" "upper_bound" "value_counts"
[178] "var" "where" "xor"

---

Expand Down Expand Up @@ -361,48 +361,49 @@
[199] "rem" "rep"
[201] "repeat_by" "replace"
[203] "reshape" "reverse"
[205] "rolling" "rolling_corr"
[207] "rolling_cov" "rolling_max"
[209] "rolling_mean" "rolling_median"
[211] "rolling_min" "rolling_quantile"
[213] "rolling_skew" "rolling_std"
[215] "rolling_sum" "rolling_var"
[217] "round" "sample_frac"
[219] "sample_n" "search_sorted"
[221] "shift" "shift_and_fill"
[223] "shrink_dtype" "shuffle"
[225] "sign" "sin"
[227] "sinh" "skew"
[229] "slice" "sort"
[231] "sort_by" "std"
[233] "str_base64_decode" "str_base64_encode"
[235] "str_concat" "str_contains"
[237] "str_contains_any" "str_count_matches"
[239] "str_ends_with" "str_explode"
[241] "str_extract" "str_extract_all"
[243] "str_hex_decode" "str_hex_encode"
[245] "str_json_decode" "str_json_path_match"
[247] "str_len_bytes" "str_len_chars"
[249] "str_pad_end" "str_pad_start"
[251] "str_parse_int" "str_replace"
[253] "str_replace_all" "str_replace_many"
[255] "str_reverse" "str_slice"
[257] "str_split" "str_split_exact"
[259] "str_splitn" "str_starts_with"
[261] "str_strip_chars" "str_strip_chars_end"
[263] "str_strip_chars_start" "str_to_date"
[265] "str_to_datetime" "str_to_lowercase"
[267] "str_to_time" "str_to_titlecase"
[269] "str_to_uppercase" "str_zfill"
[271] "struct_field_by_name" "struct_rename_fields"
[273] "sub" "sum"
[275] "tail" "tan"
[277] "tanh" "timestamp"
[279] "to_physical" "top_k"
[281] "unique" "unique_counts"
[283] "unique_stable" "upper_bound"
[285] "value_counts" "var"
[287] "xor"
[205] "rle" "rle_id"
[207] "rolling" "rolling_corr"
[209] "rolling_cov" "rolling_max"
[211] "rolling_mean" "rolling_median"
[213] "rolling_min" "rolling_quantile"
[215] "rolling_skew" "rolling_std"
[217] "rolling_sum" "rolling_var"
[219] "round" "sample_frac"
[221] "sample_n" "search_sorted"
[223] "shift" "shift_and_fill"
[225] "shrink_dtype" "shuffle"
[227] "sign" "sin"
[229] "sinh" "skew"
[231] "slice" "sort"
[233] "sort_by" "std"
[235] "str_base64_decode" "str_base64_encode"
[237] "str_concat" "str_contains"
[239] "str_contains_any" "str_count_matches"
[241] "str_ends_with" "str_explode"
[243] "str_extract" "str_extract_all"
[245] "str_hex_decode" "str_hex_encode"
[247] "str_json_decode" "str_json_path_match"
[249] "str_len_bytes" "str_len_chars"
[251] "str_pad_end" "str_pad_start"
[253] "str_parse_int" "str_replace"
[255] "str_replace_all" "str_replace_many"
[257] "str_reverse" "str_slice"
[259] "str_split" "str_split_exact"
[261] "str_splitn" "str_starts_with"
[263] "str_strip_chars" "str_strip_chars_end"
[265] "str_strip_chars_start" "str_to_date"
[267] "str_to_datetime" "str_to_lowercase"
[269] "str_to_time" "str_to_titlecase"
[271] "str_to_uppercase" "str_zfill"
[273] "struct_field_by_name" "struct_rename_fields"
[275] "sub" "sum"
[277] "tail" "tan"
[279] "tanh" "timestamp"
[281] "to_physical" "top_k"
[283] "unique" "unique_counts"
[285] "unique_stable" "upper_bound"
[287] "value_counts" "var"
[289] "xor"

# public and private methods of each class When

Expand Down
28 changes: 25 additions & 3 deletions tests/testthat/test-expr_expr.R
Original file line number Diff line number Diff line change
Expand Up @@ -2626,11 +2626,33 @@ test_that("replace works", {
expect_equal(
df$select(
replaced = pl$col("a")$replace(
old=pl$col("a")$max(),
new=pl$col("b")$sum(),
default=pl$col("b"),
old = pl$col("a")$max(),
new = pl$col("b")$sum(),
default = pl$col("b"),
)
)$to_list(),
list(replaced = c(1.5, 2.5, 5, 10))
)
})

test_that("rle works", {
df = pl$DataFrame(s = c(1, 1, 2, 1, NA, 1, 3, 3))
expect_equal(
df$select(pl$col("s")$rle())$unnest("s")$to_data_frame(),
data.frame(
lengths = c(2, 1, 1, 1, 1, 2),
values = c(1, 2, 1, NA, 1, 3)
)
)
})

test_that("rle_id works", {
df = pl$DataFrame(s = c(1, 1, 2, 1, NA, 1, 3, 3))
expect_equal(
df$with_columns(id = pl$col("s")$rle_id())$to_data_frame(),
data.frame(
s = c(1, 1, 2, 1, NA, 1, 3, 3),
id = c(0, 0, 1, 2, 3, 4, 5, 5)
)
)
})

0 comments on commit 3bf4f94

Please sign in to comment.