Skip to content

Commit

Permalink
Implement $item() for DataFrame and Series (#992)
Browse files Browse the repository at this point in the history
  • Loading branch information
etiennebacher committed Mar 30, 2024
1 parent fb57507 commit 054be14
Show file tree
Hide file tree
Showing 8 changed files with 248 additions and 45 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
- New string method `$str$find()` (#985).
- New argument `n` in `$str$replace()` (#987).
- Method `$over()` gains an argument `mapping_strategy` (#984, #988).
- New method `$item()` for `DataFrame` and `Series` (#992).

### Bug fixes

Expand Down
64 changes: 64 additions & 0 deletions R/dataframe__frame.R
Original file line number Diff line number Diff line change
Expand Up @@ -2225,3 +2225,67 @@ DataFrame_partition_by = function(

partitions
}


#' Return the element at the given row/column.
#'
#' If row and column location are not specified, the [DataFrame][DataFrame_class]
#' must have dimensions (1, 1). Otherwise, both `row` and `column` must be
#' provided.
#'
#' @param row Optional row index (0-indexed). Negative values count from the
#' end (`-1` is the last row).
#' @param column Optional column index (0-indexed) or name. Negative indices
#' count from the end (`-1` is the last column).
#'
#' @return A value of length 1
#'
#' @examples
#' df = pl$DataFrame(a = c(1, 2, 3), b = c(4, 5, 6))
#'
#' df$select((pl$col("a") * pl$col("b"))$sum())$item()
#'
#' df$item(1, 1)
#'
#' df$item(2, "b")
#'
#' # negative indices count from the end
#' df$item(-1, -1)
DataFrame_item = function(row = NULL, column = NULL) {
  uw = \(res) unwrap(res, "in $item():")

  row_null = is.null(row)
  col_null = is.null(column)

  # No location given: only valid for a single-cell DataFrame.
  if (row_null && col_null) {
    if (!identical(self$shape, c(1, 1))) {
      Err_plain(
        "Can only call $item() if the DataFrame is of shape (1, 1) or if explicit row/col values are provided."
      ) |> uw()
    }
    out = .pr$DataFrame$select_at_idx(self, 0) |>
      uw() |>
      as.vector()
    return(out)
  }

  # Both were not NULL-checked together above, so exactly one is missing here.
  if (row_null || col_null) {
    Err_plain("Cannot call `$item()` with only one of `row` or `column`.") |>
      uw()
  }

  # Validate scalars up front so that the comparisons below never see
  # zero-length, vector or NA conditions.
  if (!is.numeric(row) || length(row) != 1L || is.na(row)) {
    Err_plain("`row` must be an integer of length 1.") |>
      uw()
  }
  if (!(is.numeric(column) || is.character(column)) ||
    length(column) != 1L || is.na(column)) {
    Err_plain("`column` must be an integer or a string of length 1.") |>
      uw()
  }

  col_names = self$columns
  if (is.numeric(column)) {
    n_cols = length(col_names)
    # Resolve negative indices explicitly: passing them to R's `[` would
    # *drop* elements instead of counting from the end.
    if (column < 0) {
      column = column + n_cols
    }
    if (column < 0 || column >= n_cols) {
      Err_plain("`column` is out of bounds.") |>
        uw()
    }
    column = col_names[column + 1]
  } else if (!column %in% col_names) {
    Err_plain("`column` does not exist.") |>
      uw()
  }

  # First element of the shape is the number of rows.
  n_rows = self$shape[1]
  if (row < 0) {
    row = row + n_rows
  }
  if (row < 0 || row >= n_rows) {
    Err_plain("`row` is out of bounds.") |>
      uw()
  }

  # `row` is now a valid 0-based position; `[` on the Series is 1-based.
  self$get_column(column)[row + 1]$to_r()
}
29 changes: 29 additions & 0 deletions R/series__series.R
Original file line number Diff line number Diff line change
Expand Up @@ -995,3 +995,32 @@ Series_to_lit = function() {
Series_n_unique = function() {
  # Delegate to the internal binding and unwrap its result, tagging any
  # error with the calling method's name.
  .pr$Series$n_unique(self) |>
    unwrap("in $n_unique():")
}

#' Return the element at the given index
#'
#' @param index Index of the item to return (0-indexed; negative values count
#' from the end). If `NULL` (default), the Series must be of length 1.
#'
#' @return A value of length 1
#'
#' @examples
#' s1 = pl$Series(values = 1)
#'
#' s1$item()
#'
#' s2 = pl$Series(values = 9:7)
#'
#' s2$cum_sum()$item(-1)
Series_item = function(index = NULL) {
  uw = \(res) unwrap(res, "in $item():")

  if (is.null(index)) {
    if (self$len() != 1) {
      Err_plain("Can only call $item() if the Series is of length 1.") |>
        uw()
    }
    index = 0
  }

  # Require exactly one non-missing number: a zero-length or NA index would
  # otherwise slip through and yield a result that is not a single value.
  if (!is.numeric(index) || length(index) != 1L || is.na(index)) {
    Err_plain("`index` must be an integer of length 1.") |>
      uw()
  }

  self$gather(index)$to_r()
}
29 changes: 29 additions & 0 deletions man/DataFrame_item.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 26 additions & 0 deletions man/Series_item.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

90 changes: 45 additions & 45 deletions tests/testthat/_snaps/after-wrappers.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,19 +80,19 @@
[9] "equals" "estimated_size" "explode" "fill_nan"
[13] "fill_null" "filter" "first" "flags"
[17] "get_column" "get_columns" "glimpse" "group_by"
[21] "group_by_dynamic" "head" "height" "join"
[25] "join_asof" "last" "lazy" "limit"
[29] "max" "mean" "median" "melt"
[33] "min" "n_chunks" "null_count" "partition_by"
[37] "pivot" "print" "quantile" "rechunk"
[41] "rename" "reverse" "rolling" "sample"
[45] "schema" "select" "shape" "shift"
[49] "shift_and_fill" "slice" "sort" "std"
[53] "sum" "tail" "to_data_frame" "to_list"
[57] "to_series" "to_struct" "transpose" "unique"
[61] "unnest" "var" "width" "with_columns"
[65] "with_row_index" "write_csv" "write_json" "write_ndjson"
[69] "write_parquet"
[21] "group_by_dynamic" "head" "height" "item"
[25] "join" "join_asof" "last" "lazy"
[29] "limit" "max" "mean" "median"
[33] "melt" "min" "n_chunks" "null_count"
[37] "partition_by" "pivot" "print" "quantile"
[41] "rechunk" "rename" "reverse" "rolling"
[45] "sample" "schema" "select" "shape"
[49] "shift" "shift_and_fill" "slice" "sort"
[53] "std" "sum" "tail" "to_data_frame"
[57] "to_list" "to_series" "to_struct" "transpose"
[61] "unique" "unnest" "var" "width"
[65] "with_columns" "with_row_index" "write_csv" "write_json"
[69] "write_ndjson" "write_parquet"

---

Expand Down Expand Up @@ -654,38 +654,38 @@
[82] "is_infinite" "is_last_distinct" "is_nan"
[85] "is_not_nan" "is_not_null" "is_null"
[88] "is_numeric" "is_sorted" "is_unique"
[91] "kurtosis" "last" "len"
[94] "limit" "list" "log"
[97] "log10" "lower_bound" "lt"
[100] "lt_eq" "map_batches" "map_elements"
[103] "max" "mean" "median"
[106] "min" "mod" "mode"
[109] "mul" "n_unique" "name"
[112] "nan_max" "nan_min" "neq"
[115] "neq_missing" "not" "null_count"
[118] "or" "pct_change" "peak_max"
[121] "peak_min" "pow" "print"
[124] "product" "quantile" "rank"
[127] "rechunk" "reinterpret" "rename"
[130] "rep" "rep_extend" "repeat_by"
[133] "replace" "reshape" "reverse"
[136] "rle" "rle_id" "rolling_max"
[139] "rolling_mean" "rolling_median" "rolling_min"
[142] "rolling_quantile" "rolling_skew" "rolling_std"
[145] "rolling_sum" "rolling_var" "round"
[148] "sample" "search_sorted" "set_sorted"
[151] "shape" "shift" "shift_and_fill"
[154] "shrink_dtype" "shuffle" "sign"
[157] "sin" "sinh" "skew"
[160] "slice" "sort" "sort_by"
[163] "sqrt" "std" "str"
[166] "struct" "sub" "sum"
[169] "tail" "tan" "tanh"
[172] "to_frame" "to_list" "to_lit"
[175] "to_physical" "to_r" "to_struct"
[178] "to_vector" "top_k" "unique"
[181] "unique_counts" "upper_bound" "value_counts"
[184] "var" "xor"
[91] "item" "kurtosis" "last"
[94] "len" "limit" "list"
[97] "log" "log10" "lower_bound"
[100] "lt" "lt_eq" "map_batches"
[103] "map_elements" "max" "mean"
[106] "median" "min" "mod"
[109] "mode" "mul" "n_unique"
[112] "name" "nan_max" "nan_min"
[115] "neq" "neq_missing" "not"
[118] "null_count" "or" "pct_change"
[121] "peak_max" "peak_min" "pow"
[124] "print" "product" "quantile"
[127] "rank" "rechunk" "reinterpret"
[130] "rename" "rep" "rep_extend"
[133] "repeat_by" "replace" "reshape"
[136] "reverse" "rle" "rle_id"
[139] "rolling_max" "rolling_mean" "rolling_median"
[142] "rolling_min" "rolling_quantile" "rolling_skew"
[145] "rolling_std" "rolling_sum" "rolling_var"
[148] "round" "sample" "search_sorted"
[151] "set_sorted" "shape" "shift"
[154] "shift_and_fill" "shrink_dtype" "shuffle"
[157] "sign" "sin" "sinh"
[160] "skew" "slice" "sort"
[163] "sort_by" "sqrt" "std"
[166] "str" "struct" "sub"
[169] "sum" "tail" "tan"
[172] "tanh" "to_frame" "to_list"
[175] "to_lit" "to_physical" "to_r"
[178] "to_struct" "to_vector" "top_k"
[181] "unique" "unique_counts" "upper_bound"
[184] "value_counts" "var" "xor"

---

Expand Down
35 changes: 35 additions & 0 deletions tests/testthat/test-dataframe.R
Original file line number Diff line number Diff line change
Expand Up @@ -1409,3 +1409,38 @@ test_that("partition_by", {
df$partition_by("col2", maintain_order = FALSE, include_key = FALSE, as_nested_list = TRUE)
)
})

test_that("$item() works", {
  dat = pl$DataFrame(a = c(1, 2, 3), b = c(4, 5, 6))

  # Without arguments: the frame is first reduced to a single cell.
  dot_product = dat$select((pl$col("a") * pl$col("b"))$sum())
  expect_equal(dot_product$item(), 32)

  # With explicit (row, column): column by position, then by name.
  expect_equal(dat$item(1, 1), 5)
  expect_equal(dat$item(2, "b"), 6)

  # Invalid locations
  expect_grepl_error(dat$item(1, 4), "`column` is out of bounds.")
  expect_grepl_error(dat$item(1, "foo"), "`column` does not exist.")
  expect_grepl_error(dat$item(4, 1), "`row` is out of bounds.")

  # No location on a frame that is not (1, 1)
  expect_grepl_error(dat$item(), "if the DataFrame is of shape")

  # Only one of the two coordinates supplied
  expect_grepl_error(dat$item(1), " with only one of `row` or ")
  expect_grepl_error(dat$item(column = 1), " with only one of `row` or ")
})
19 changes: 19 additions & 0 deletions tests/testthat/test-series.R
Original file line number Diff line number Diff line change
Expand Up @@ -622,3 +622,22 @@ test_that("Positional arguments deprecation", {
)
)
})

test_that("$item() works", {
  single = pl$Series(values = 1)

  expect_equal(single$item(), 1)
  expect_equal(pl$Series(values = 3:1)$cum_sum()$item(-1), 6)

  # Invalid `index`: too long, then wrong type
  bad_index_msg = "`index` must be an integer of length 1"
  expect_grepl_error(single$item(c(0, 0)), bad_index_msg)
  expect_grepl_error(single$item("a"), bad_index_msg)

  # No `index` on a Series with more than one element
  expect_grepl_error(
    pl$Series(values = 1:2)$item(),
    "if the Series is of length 1"
  )
})

0 comments on commit 054be14

Please sign in to comment.