Skip to content

Commit

Permalink
Merge pull request #160 from DavisVaughan/summary-group
Browse files Browse the repository at this point in the history
Summary group
  • Loading branch information
DavisVaughan committed May 10, 2019
2 parents 0d6ed9b + cc70750 commit 2a2985e
Show file tree
Hide file tree
Showing 7 changed files with 238 additions and 55 deletions.
5 changes: 3 additions & 2 deletions NAMESPACE
Expand Up @@ -18,8 +18,6 @@ S3method("dim_names<-",vctrs_rray)
S3method("dimnames<-",vctrs_rray)
S3method("names<-",vctrs_rray)
S3method("|",vctrs_rray)
S3method(all,vctrs_rray)
S3method(any,vctrs_rray)
S3method(aperm,vctrs_rray)
S3method(as.array,vctrs_rray)
S3method(as.double,vctrs_rray)
Expand Down Expand Up @@ -140,6 +138,8 @@ S3method(vec_type2.vctrs_rray_lgl,vctrs_rray_dbl)
S3method(vec_type2.vctrs_rray_lgl,vctrs_rray_int)
S3method(vec_type2.vctrs_rray_lgl,vctrs_rray_lgl)
S3method(vec_type2.vctrs_rray_lgl,vctrs_unspecified)
S3method(xtfrm,vctrs_rray)
S3method(xtfrm,vctrs_rray_lgl)
export("%b*%")
export("%b+%")
export("%b-%")
Expand Down Expand Up @@ -334,6 +334,7 @@ importFrom(vctrs,vec_duplicate_id)
importFrom(vctrs,vec_math)
importFrom(vctrs,vec_math_base)
importFrom(vctrs,vec_na)
importFrom(vctrs,vec_proxy_compare)
importFrom(vctrs,vec_ptype_abbr)
importFrom(vctrs,vec_ptype_full)
importFrom(vctrs,vec_restore)
Expand Down
1 change: 1 addition & 0 deletions R/aaa.R
Expand Up @@ -22,6 +22,7 @@
#' @importFrom vctrs new_rcrd
#' @importFrom vctrs field
#' @importFrom vctrs vec_split
#' @importFrom vctrs vec_proxy_compare
#'
#' @importFrom vctrs vec_ptype_full
#' @importFrom vctrs vec_ptype_abbr
Expand Down
7 changes: 7 additions & 0 deletions R/compat-vctrs-math.R
Expand Up @@ -60,6 +60,13 @@ rray_math_unary_op_switch <- function(fun) {
"is.infinite" = rray_is_infinite,
"is.finite" = rray_is_finite,

# summary
"all" = rray_all_vctrs_wrapper,
"any" = rray_any_vctrs_wrapper,
"range" = rray_range_vctrs_wrapper,
"prod" = rray_prod_vctrs_wrapper,
"sum" = rray_sum_vctrs_wrapper,

glubort("Unary math function not known: {fun}.")
)
}
Expand Down
41 changes: 0 additions & 41 deletions R/logical.R
Expand Up @@ -11,9 +11,6 @@
#' @param ... A single rray. An error is currently thrown if more than one
#' input is passed here.
#'
#' @param na.rm Should `NA` values be removed? Currently only `FALSE` is
#' allowed.
#'
#' @param axes An integer vector specifying the axes to reduce over.
#' `1` reduces the number of rows to `1`, performing the reduction along the
#' way. `2` does the same, but with the columns, and so on for higher
Expand Down Expand Up @@ -117,25 +114,6 @@ rray_logical_not <- function(x) {

# ------------------------------------------------------------------------------

# `any()` should always be on a flattened version of the input to maintain
# backwards compat with base R. `rray_any()` should only handle 1 input
# but should be able to look along an axis.
# TODO - Think more about this.

#' @rdname rray-logical
#' @export
`any.vctrs_rray` <- function(..., na.rm = FALSE) {

if (!identical(na.rm, FALSE)) {
abort("`na.rm` currently must be `FALSE` in `any()` for rrays.")
}

x <- map(list2(...), as.vector)
x <- vec_c(!!! x)

vec_math_base("any", x)
}

#' @rdname rray-logical
#' @export
rray_any <- function(x, axes = NULL) {
Expand All @@ -162,25 +140,6 @@ rray_any <- function(x, axes = NULL) {

# ------------------------------------------------------------------------------

# `all()` should always be on a flattened version of the input to maintain
# backwards compat with base R. `rray_all()` should only handle 1 input
# but should be able to look along an axis.
# TODO - Think more about this.

#' @rdname rray-logical
#' @export
`all.vctrs_rray` <- function(..., na.rm = FALSE) {

if (!identical(na.rm, FALSE)) {
abort("`na.rm` currently must be `FALSE` in `all()` for rrays.")
}

x <- map(list2(...), as.vector)
x <- vec_c(!!! x)

vec_math_base("all", x)
}

#' @rdname rray-logical
#' @export
rray_all <- function(x, axes = NULL) {
Expand Down
58 changes: 58 additions & 0 deletions R/summary-group.R
@@ -0,0 +1,58 @@
# Summary generics are:
# - all(), any()
# - sum(), prod()
# - min(), max()
# - range()

# All of them dispatch on the first argument, as described in `?Summary`

# ------------------------------------------------------------------------------

# - vctrs:::min.vctrs_vctr() does the right thing because of `xtfrm.vctrs_vctr()`
# - vctrs:::max.vctrs_vctr() does the right thing because of `xtfrm.vctrs_vctr()`

# However, vctrs only automatically implements `xtfrm()` for integer and double
# arrays, so logical ones need special treatment. Because of this, just
# implement a simple xtfrm() method that calls `vec_proxy_compare()` like vctrs

#' @export
xtfrm.vctrs_rray <- function(x) {
vec_proxy_compare(x)
}

#' @export
xtfrm.vctrs_rray_lgl <- function(x) {
rray_cast_inner(vec_proxy_compare(x), integer())
}

# ------------------------------------------------------------------------------

# This is a base R compatible version of `all()` and `any()`.
# It is used in vec_math() dispatch

# Note that `vctrs:::Summary.vctrs_vctr()` is how this is passed through,
# and `na.rm = TRUE` no matter what there!

rray_all_vctrs_wrapper <- function(x, na.rm) {
vec_math_base("all", vec_data(x), na.rm = na.rm)
}

rray_any_vctrs_wrapper <- function(x, na.rm) {
vec_math_base("any", vec_data(x), na.rm = na.rm)
}

# ------------------------------------------------------------------------------

rray_range_vctrs_wrapper <- function(x, na.rm) {
vec_math_base("range", vec_data(x), na.rm = na.rm)
}

# ------------------------------------------------------------------------------

rray_prod_vctrs_wrapper <- function(x, na.rm) {
vec_math_base("prod", vec_data(x), na.rm = na.rm)
}

rray_sum_vctrs_wrapper <- function(x, na.rm) {
vec_math_base("sum", vec_data(x), na.rm = na.rm)
}
15 changes: 3 additions & 12 deletions man/rray-logical.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

166 changes: 166 additions & 0 deletions tests/testthat/test-summary-group.R
@@ -0,0 +1,166 @@
# ------------------------------------------------------------------------------
context("test-xtfrm")

test_that("xtfrm() returns proxy objects", {
x <- rray(1:2, c(2, 1))
expect_equal(xtfrm(x), vec_data(x))

x <- rray(c(1, 2), c(2, 1))
expect_equal(xtfrm(x), vec_data(x))
})

test_that("xtfrm() works for 3D", {
x <- rray(1:6, c(2, 1, 3))
expect_equal(xtfrm(x), vec_data(x))
})

test_that("xtfrm() for logicals returns integers", {
x <- rray(c(TRUE, FALSE), c(2, 1))
expect_equal(xtfrm(x), new_matrix(c(1, 0), c(2, 1)))
})

# ------------------------------------------------------------------------------
context("test-min")

test_that("`min()` returns a length 1 vector for 1D", {
expect_equal(min(rray(5:1)), rray(1L))
expect_equal(min(rray(5:1 + 0)), rray(1))
})

test_that("`min()` returns a length 1 vector for 2D", {
x <- rray(c(2, 4, 5, 2), c(2, 2))
expect_equal(
min(x),
rray(2)
)
})

test_that("`min()` returns a length 1 vector for 3D", {
x <- rray(c(2, 4, 5, 2), c(2, 1, 2))
expect_equal(
min(x),
rray(2)
)
})

test_that("vctrs `min()` ignores input in `...`", {
expect_equal(min(rray(2), 1), rray(2))
})

# TODO - Add tests after this is fixed
# https://github.com/r-lib/vctrs/pull/329

# test_that("NAs are removed", {
# min(rray(c(NA, 2)), na.rm = TRUE)
# })

# ------------------------------------------------------------------------------
context("test-max")

test_that("`max()` returns a length 1 vector for 1D", {
expect_equal(max(rray(5:1)), rray(5L))
expect_equal(max(rray(5:1 + 0)), rray(5))
})

test_that("`max()` returns a length 1 vector for 2D", {
x <- rray(c(2, 4, 5, 2), c(2, 2))
expect_equal(
max(x),
rray(5)
)
})

test_that("`max()` returns a length 1 vector for 3D", {
x <- rray(c(2, 4, 5, 2), c(2, 1, 2))
expect_equal(
max(x),
rray(5)
)
})

test_that("vctrs `max()` ignores input in `...`", {
expect_equal(max(rray(2), 1), rray(2))
})

# TODO - Add tests after this is fixed
# https://github.com/r-lib/vctrs/pull/329

# test_that("NAs are removed", {
# max(rray(c(NA, 2)), na.rm = TRUE)
# })

# ------------------------------------------------------------------------------
context("test-base-any")

test_that("returns a single value with shaped arrays", {
expect_equal(any(rray(c(TRUE, FALSE), c(2, 2))), TRUE)
expect_equal(any(rray(c(FALSE, FALSE), c(2, 2))), FALSE)
})

test_that("always uses `na.rm = TRUE`", {
expect_equal(any(rray(c(NA, 1L)), na.rm = FALSE), TRUE)
expect_equal(any(rray(c(NA, 0L)), na.rm = FALSE), FALSE)
})

# ------------------------------------------------------------------------------
context("test-base-all")

test_that("returns a single value with shaped arrays", {
expect_equal(all(rray(c(TRUE, FALSE), c(2, 2))), FALSE)
expect_equal(all(rray(c(TRUE, TRUE), c(2, 2))), TRUE)
})

test_that("always uses `na.rm = TRUE`", {
expect_equal(all(rray(c(NA, 1L, 0L)), na.rm = FALSE), FALSE)
expect_equal(all(rray(c(NA, 1L, 1L)), na.rm = FALSE), TRUE)
})

# ------------------------------------------------------------------------------
context("test-base-range")

test_that("returns same values as base R", {
x <- rray(c(TRUE, FALSE), c(2, 2))
expect_equal(range(x), range(vec_data(x)))
expect_equal(range(x, x), range(vec_data(x), vec_data(x)))
expect_equal(range(x, 5), range(vec_data(x), 5))
})

test_that("always uses `na.rm = TRUE`", {
expect_equal(range(rray(c(NA, 1L)), na.rm = FALSE), c(1L, 1L))
})

# ------------------------------------------------------------------------------
context("test-base-prod")

test_that("returns same values as base R", {
x <- rray(c(5, 6), c(2, 2))
expect_equal(prod(x), prod(vec_data(x)))
expect_equal(prod(x, x), prod(vec_data(x), vec_data(x)))
})

test_that("broadcasts input using vctrs", {
x <- rray(c(5, 6), c(2, 2))
expect_equal(prod(x, 5), prod(x, matrix(5, c(1, 2))))
})

test_that("always uses `na.rm = TRUE`", {
expect_equal(prod(rray(c(NA, 1L)), na.rm = FALSE), 1)
})

# ------------------------------------------------------------------------------
context("test-base-sum")

test_that("returns same values as base R", {
x <- rray(c(5, 6), c(2, 2))
expect_equal(sum(x), sum(vec_data(x)))
expect_equal(sum(x, x), sum(vec_data(x), vec_data(x)))
})

test_that("broadcasts input using vctrs", {
x <- rray(c(5, 6), c(2, 2))
expect_equal(sum(x, 5), sum(x, matrix(5, c(1, 2))))
})

test_that("always uses `na.rm = TRUE`", {
expect_equal(sum(rray(c(NA, 1L)), na.rm = FALSE), 1)
})

0 comments on commit 2a2985e

Please sign in to comment.