Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: The Extract function should filter rows before selecting columns #547

Merged
merged 4 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# polars (development version)

## What's changed

- The Extract function (`[`) for DataFrame can now filter rows using columns that
are not included in the result (#547).
- The Extract function (`[`) for LazyFrame can filter rows with Expressions (#547).

# polars 0.11.0

## BREAKING CHANGES DUE TO RUST-POLARS UPDATE
Expand Down
95 changes: 53 additions & 42 deletions R/s3_methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,34 @@
#'
#' `<Series>[i]` is equivalent to `pl$select(<Series>)[i, , drop = TRUE]`.
#' @rdname S3_extract
#' @param x A [DataFrame][DataFrame_class] or [LazyFrame][LazyFrame_class]
#' @param i Rows to select
#' @param j Columns to select, either by index or by name.
#' @param x A [DataFrame][DataFrame_class], [LazyFrame][LazyFrame_class], or [Series][Series_class]
#' @param i Rows to select. Integer vector, logical vector, or an [Expression][Expr_class].
#' @param j Columns to select. Integer vector, logical vector, character vector, or an [Expression][Expr_class].
#' For LazyFrames, only an Expression can be used.
#' @param drop Convert to a Polars Series if only one column is selected.
#' For LazyFrames, an error is raised if the result has a single column and `drop = TRUE`.
#' @seealso
#' [`<DataFrame>$select()`][DataFrame_select],
#' [`<LazyFrame>$select()`][LazyFrame_select],
#' [`<DataFrame>$filter()`][DataFrame_filter],
#' [`<LazyFrame>$filter()`][LazyFrame_filter]
#' @examples
#' df = pl$DataFrame(data.frame(a = 1:3, b = letters[1:3]))
#' lf = df$lazy()
#'
#' # Select a row
#' df[1, ]
#'
#' # If only `i` is specified, it is treated as `j`
#' # Select a column
#' df[1]
#'
#' # Select a column by name (and convert to a Series)
#' df[, "b"]
#' df[pl$col("a") >= 2, ]
#'
#' # Can use Expression for filtering and column selection
#' lf[pl$col("a") >= 2, pl$col("b")$alias("new"), drop = FALSE] |>
#' as.data.frame()
#' @export
`[.DataFrame` = function(x, i, j, drop = TRUE) {
uw = \(res) unwrap(res, "in `[` (Extract):")
Expand All @@ -32,7 +44,43 @@
drop = !missing(drop) && drop
}

# selecting `j` is usually faster, so we start here.
if (!missing(i) && !isTRUE(only_i)) {
# `i == NULL` means return 0 rows
i = i %||% 0

if (is.atomic(i) && is.vector(i)) {
if (inherits(x, "LazyFrame")) {
Err_plain("Row selection using vector is not supported for LazyFrames.") |> uw()
}

if (is.logical(i)) {
# nrow() not available for LazyFrame
if (inherits(x, "DataFrame") && length(i) != nrow(x)) {
stop(sprintf("`i` must be of length %s.", nrow(x)), call. = FALSE)
}
idx = i
} else if (is.integer(i) || (is.numeric(i) && all(i %% 1 == 0))) {
negative = any(i < 0)
if (isTRUE(negative)) {
if (any(i > 0)) {
Err_plain("Elements of `i` must be all postive or all negative.") |> uw()
}
idx = !seq_len(x$height) %in% abs(i)
} else {
if (any(diff(i) < 0)) {
Err_plain("Elements of `i` must be in increasing order.") |> uw()
}
idx = seq_len(x$height) %in% i
}
}
x = x$filter(pl$lit(idx))
} else if (identical(class(i), "Expr")) {
x = x$filter(i)
} else {
Err_plain("`i` must be an Expr or an atomic vector of class logical or integer.") |> uw()
}
}

if (!missing(j)) {
if (is.atomic(j) && is.vector(j)) {
if (is.logical(j)) {
Expand Down Expand Up @@ -67,43 +115,6 @@
}
}

if (!missing(i) && !isTRUE(only_i)) {
if (inherits(x, "LazyFrame")) {
Err_plain("Row selection using brackets is not supported for LazyFrames.") |> uw()
}

# `i == NULL` means return 0 rows
i = i %||% 0

if (is.atomic(i) && is.vector(i)) {
if (is.logical(i)) {
# nrow() not available for LazyFrame
if (inherits(x, "DataFrame") && length(i) != nrow(x)) {
stop(sprintf("`i` must be of length %s.", nrow(x)), call. = FALSE)
}
idx = i
} else if (is.integer(i) || (is.numeric(i) && all(i %% 1 == 0))) {
negative = any(i < 0)
if (isTRUE(negative)) {
if (any(i > 0)) {
Err_plain("Elements of `i` must be all postive or all negative.") |> uw()
}
idx = !seq_len(x$height) %in% abs(i)
} else {
if (any(diff(i) < 0)) {
Err_plain("Elements of `i` must be in increasing order.") |> uw()
}
idx = seq_len(x$height) %in% i
}
}
x = x$filter(pl$lit(idx))
} else if (identical(class(i), "Expr")) {
x = x$filter(i)
} else {
Err_plain("`i` must be an Expr or an atomic vector of class logical or integer.") |> uw()
}
}

if (drop && x$width == 1L) {
if (inherits(x, "LazyFrame")) {
Err_plain(
Expand Down
22 changes: 17 additions & 5 deletions man/S3_extract.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 11 additions & 1 deletion tests/testthat/test-s3_methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,11 @@ test_that("brackets", {
expect_equal(df["cyl"]$to_data_frame(), mtcars["cyl"], ignore_attr = TRUE)
expect_equal(df[1:3]$to_data_frame(), mtcars[1:3], ignore_attr = TRUE)
expect_equal(df[NULL, ]$to_data_frame(), mtcars[NULL, ], ignore_attr = TRUE)
expect_equal(df[pl$col("cyl") >= 8, ]$to_data_frame(), mtcars[mtcars$cyl >= 8, ], ignore_attr = TRUE)
expect_equal(
df[pl$col("cyl") >= 8, c("disp", "mpg")]$to_data_frame(),
mtcars[mtcars$cyl >= 8, c("disp", "mpg")],
ignore_attr = TRUE
)

df = pl$DataFrame(mtcars)
a = mtcars[-(1:2), -c(1, 3, 6, 9)]
Expand Down Expand Up @@ -239,6 +243,12 @@ test_that("brackets", {
b = mtcars[, c(1, 4, 2)]
expect_equal(a, b, ignore_attr = TRUE)

expect_equal(
lf[pl$col("cyl") >= 8, c("disp", "mpg")]$collect()$to_data_frame(),
mtcars[mtcars$cyl >= 8, c("disp", "mpg")],
ignore_attr = TRUE
)

# Not supported for lazy
expect_error(lf[1:3, ], "not supported")
expect_error(lf[, "cyl"], "not supported")
Expand Down
Loading