Skip to content

Commit

Permalink
Add set_ operations for list columns (#712)
Browse files Browse the repository at this point in the history
  • Loading branch information
etiennebacher committed Jan 19, 2024
1 parent 7e6830a commit 1c7feb5
Show file tree
Hide file tree
Showing 13 changed files with 399 additions and 76 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
- New methods `$list$any()` and `$list$all()` (#709).
- New function `pl$from_epoch()` to convert a Unix timestamp to a date(time)
variable (#708).
- New methods for the `list` subnamespace: `$set_union()`, `$set_intersection()`,
`$set_difference()`, `$set_symmetric_difference()` (#712).

## polars 0.12.2

Expand Down
83 changes: 83 additions & 0 deletions R/expr__list.R
Original file line number Diff line number Diff line change
Expand Up @@ -441,3 +441,86 @@ ExprList_all = function() .pr$Expr$list_all(self)
#' )
#' df$with_columns(any = pl$col("a")$list$any())
ExprList_any = function() .pr$Expr$list_any(self)

#' Get the union of two list variables
#'
#' @param other Other list variable. Can be an Expr or something coercible to an
#' Expr.
#'
#' @details
#' Note that the datatypes inside the list must have a common supertype. For
#' example, the first column can be `list[i32]` and the second one can be
#' `list[i8]` because it can be cast to `list[i32]`. However, the second column
#' cannot be e.g `list[f32]`.
#'
#' @return Expr
#' @examples
#' df = pl$DataFrame(
#' a = list(1:3, NA_integer_, c(NA_integer_, 3L), 5:7),
#' b = list(2:4, 3L, c(3L, 4L, NA_integer_), c(6L, 8L))
#' )
#'
#' df$with_columns(union = pl$col("a")$list$set_union("b"))
ExprList_set_union = function(other) {
.pr$Expr$list_set_operation(self, other, "union") |>
unwrap("in $list$set_union():")
}

#' Get the intersection of two list variables
#'
#' @inherit ExprList_set_union params details return
#'
#' @examples
#' df = pl$DataFrame(
#' a = list(1:3, NA_integer_, c(NA_integer_, 3L), 5:7),
#' b = list(2:4, 3L, c(3L, 4L, NA_integer_), c(6L, 8L))
#' )
#'
#' df$with_columns(intersection = pl$col("a")$list$set_intersection("b"))
ExprList_set_intersection = function(other) {
.pr$Expr$list_set_operation(self, other, "intersection") |>
unwrap("in $list$set_intersection():")
}

#' Get the difference of two list variables
#'
#' This returns the "asymmetric difference", meaning only the elements of the
#' first list that are not in the second list. To get all elements that are in
#' only one of the two lists, use
#' [`$set_symmetric_difference()`][ExprList_set_symmetric_difference].
#'
#' @inherit ExprList_set_union params details return
#'
#' @examples
#' df = pl$DataFrame(
#' a = list(1:3, NA_integer_, c(NA_integer_, 3L), 5:7),
#' b = list(2:4, 3L, c(3L, 4L, NA_integer_), c(6L, 8L))
#' )
#'
#' df$with_columns(difference = pl$col("a")$list$set_difference("b"))
ExprList_set_difference = function(other) {
.pr$Expr$list_set_operation(self, other, "difference") |>
unwrap("in $list$set_difference():")
}

#' Get the symmetric difference of two list variables
#'
#' This returns all elements that are in only one of the two lists. To get only
#' elements that are in the first list but not in the second one, use
#' [`$set_difference()`][ExprList_set_difference].
#'
#' @inherit ExprList_set_union params details return
#'
#' @examples
#' df = pl$DataFrame(
#' a = list(1:3, NA_integer_, c(NA_integer_, 3L), 5:7),
#' b = list(2:4, 3L, c(3L, 4L, NA_integer_), c(6L, 8L))
#' )
#'
#' df$with_columns(
#' symmetric_difference = pl$col("a")$list$set_symmetric_difference("b")
#' )
ExprList_set_symmetric_difference = function(other) {
.pr$Expr$list_set_operation(self, other, "symmetric_difference") |>
unwrap("in $list$set_symmetric_difference():")
}
2 changes: 2 additions & 0 deletions R/extendr-wrappers.R
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,8 @@ RPolarsExpr$list_all <- function() .Call(wrap__RPolarsExpr__list_all, self)

RPolarsExpr$list_any <- function() .Call(wrap__RPolarsExpr__list_any, self)

RPolarsExpr$list_set_operation <- function(other, operation) .Call(wrap__RPolarsExpr__list_set_operation, self, other, operation)

RPolarsExpr$dt_truncate <- function(every, offset) .Call(wrap__RPolarsExpr__dt_truncate, self, every, offset)

RPolarsExpr$dt_round <- function(every, offset) .Call(wrap__RPolarsExpr__dt_round, self, every, offset)
Expand Down
35 changes: 35 additions & 0 deletions man/ExprList_set_difference.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

32 changes: 32 additions & 0 deletions man/ExprList_set_intersection.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 36 additions & 0 deletions man/ExprList_set_symmetric_difference.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

32 changes: 32 additions & 0 deletions man/ExprList_set_union.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ features = [
"list_any_all",
"list_eval",
"list_gather",
"list_sets",
"list_to_struct",
"log",
"meta",
Expand Down
15 changes: 14 additions & 1 deletion src/rust/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::CONFIG;
use extendr_api::{extendr, prelude::*, rprintln, Deref, DerefMut, Rinternals};
use pl::PolarsError as pl_error;
use pl::{
BinaryNameSpaceImpl, Duration, DurationMethods, IntoSeries, RollingGroupOptions,
BinaryNameSpaceImpl, Duration, DurationMethods, IntoSeries, RollingGroupOptions, SetOperation,
StringNameSpaceImpl, TemporalMethods,
};
use polars::lazy::dsl;
Expand Down Expand Up @@ -1168,6 +1168,19 @@ impl RPolarsExpr {
self.0.clone().list().any().into()
}

fn list_set_operation(&self, other: Robj, operation: Robj) -> RResult<Self> {
let other = robj_to!(PLExprCol, other)?;
let operation = robj_to!(SetOperation, operation)?;
let e = self.0.clone().list();
Ok(match operation {
SetOperation::Intersection => e.set_intersection(other),
SetOperation::Difference => e.set_difference(other),
SetOperation::Union => e.union(other),
SetOperation::SymmetricDifference => e.set_symmetric_difference(other),
}
.into())
}

//datetime methods

pub fn dt_truncate(&self, every: Robj, offset: Robj) -> RResult<Self> {
Expand Down
13 changes: 13 additions & 0 deletions src/rust/src/rdatatype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,19 @@ pub fn robj_to_closed_window(robj: Robj) -> RResult<pl::ClosedWindow> {
}
}

pub fn robj_to_set_operation(robj: Robj) -> RResult<pl::SetOperation> {
use pl::SetOperation as SO;
match robj_to_rchoice(robj)?.as_str() {
"union" => Ok(SO::Union),
"intersection" => Ok(SO::Intersection),
"difference" => Ok(SO::Difference),
"symmetric_difference" => Ok(SO::SymmetricDifference),
s => rerr().bad_val(format!(
"SetOperation choice ['{s}'] should be one of 'union', 'intersection', 'difference', 'symmetric_difference'"
)),
}
}

pub fn robj_to_label(robj: Robj) -> RResult<pl::Label> {
use pl::Label;
match robj_to_rchoice(robj)?.as_str() {
Expand Down
3 changes: 3 additions & 0 deletions src/rust/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,9 @@ macro_rules! robj_to_inner {
(ClosedWindow, $a:ident) => {
$crate::rdatatype::robj_to_closed_window($a)
};
(SetOperation, $a:ident) => {
$crate::rdatatype::robj_to_set_operation($a)
};
(Label, $a:ident) => {
$crate::rdatatype::robj_to_label($a)
};
Expand Down
Loading

0 comments on commit 1c7feb5

Please sign in to comment.