Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Summary group #160

Merged
merged 4 commits into from May 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 3 additions & 2 deletions NAMESPACE
Expand Up @@ -18,8 +18,6 @@ S3method("dim_names<-",vctrs_rray)
S3method("dimnames<-",vctrs_rray)
S3method("names<-",vctrs_rray)
S3method("|",vctrs_rray)
S3method(all,vctrs_rray)
S3method(any,vctrs_rray)
S3method(aperm,vctrs_rray)
S3method(as.array,vctrs_rray)
S3method(as.double,vctrs_rray)
Expand Down Expand Up @@ -140,6 +138,8 @@ S3method(vec_type2.vctrs_rray_lgl,vctrs_rray_dbl)
S3method(vec_type2.vctrs_rray_lgl,vctrs_rray_int)
S3method(vec_type2.vctrs_rray_lgl,vctrs_rray_lgl)
S3method(vec_type2.vctrs_rray_lgl,vctrs_unspecified)
S3method(xtfrm,vctrs_rray)
S3method(xtfrm,vctrs_rray_lgl)
export("%b*%")
export("%b+%")
export("%b-%")
Expand Down Expand Up @@ -334,6 +334,7 @@ importFrom(vctrs,vec_duplicate_id)
importFrom(vctrs,vec_math)
importFrom(vctrs,vec_math_base)
importFrom(vctrs,vec_na)
importFrom(vctrs,vec_proxy_compare)
importFrom(vctrs,vec_ptype_abbr)
importFrom(vctrs,vec_ptype_full)
importFrom(vctrs,vec_restore)
Expand Down
1 change: 1 addition & 0 deletions R/aaa.R
Expand Up @@ -22,6 +22,7 @@
#' @importFrom vctrs new_rcrd
#' @importFrom vctrs field
#' @importFrom vctrs vec_split
#' @importFrom vctrs vec_proxy_compare
#'
#' @importFrom vctrs vec_ptype_full
#' @importFrom vctrs vec_ptype_abbr
Expand Down
7 changes: 7 additions & 0 deletions R/compat-vctrs-math.R
Expand Up @@ -60,6 +60,13 @@ rray_math_unary_op_switch <- function(fun) {
"is.infinite" = rray_is_infinite,
"is.finite" = rray_is_finite,

# summary
"all" = rray_all_vctrs_wrapper,
"any" = rray_any_vctrs_wrapper,
"range" = rray_range_vctrs_wrapper,
"prod" = rray_prod_vctrs_wrapper,
"sum" = rray_sum_vctrs_wrapper,

glubort("Unary math function not known: {fun}.")
)
}
Expand Down
41 changes: 0 additions & 41 deletions R/logical.R
Expand Up @@ -11,9 +11,6 @@
#' @param ... A single rray. An error is currently thrown if more than one
#' input is passed here.
#'
#' @param na.rm Should `NA` values be removed? Currently only `FALSE` is
#' allowed.
#'
#' @param axes An integer vector specifying the axes to reduce over.
#' `1` reduces the number of rows to `1`, performing the reduction along the
#' way. `2` does the same, but with the columns, and so on for higher
Expand Down Expand Up @@ -117,25 +114,6 @@ rray_logical_not <- function(x) {

# ------------------------------------------------------------------------------

# `any()` should always be on a flattened version of the input to maintain
# backwards compat with base R. `rray_any()` should only handle 1 input
# but should be able to look along an axis.
# TODO - Think more about this.

#' @rdname rray-logical
#' @export
`any.vctrs_rray` <- function(..., na.rm = FALSE) {

if (!identical(na.rm, FALSE)) {
abort("`na.rm` currently must be `FALSE` in `any()` for rrays.")
}

x <- map(list2(...), as.vector)
x <- vec_c(!!! x)

vec_math_base("any", x)
}

#' @rdname rray-logical
#' @export
rray_any <- function(x, axes = NULL) {
Expand All @@ -162,25 +140,6 @@ rray_any <- function(x, axes = NULL) {

# ------------------------------------------------------------------------------

# `all()` should always be on a flattened version of the input to maintain
# backwards compat with base R. `rray_all()` should only handle 1 input
# but should be able to look along an axis.
# TODO - Think more about this.

#' @rdname rray-logical
#' @export
`all.vctrs_rray` <- function(..., na.rm = FALSE) {

if (!identical(na.rm, FALSE)) {
abort("`na.rm` currently must be `FALSE` in `all()` for rrays.")
}

x <- map(list2(...), as.vector)
x <- vec_c(!!! x)

vec_math_base("all", x)
}

#' @rdname rray-logical
#' @export
rray_all <- function(x, axes = NULL) {
Expand Down
58 changes: 58 additions & 0 deletions R/summary-group.R
@@ -0,0 +1,58 @@
# Summary generics are:
# - all(), any()
# - sum(), prod()
# - min(), max()
# - range()

# All of them dispatch on the first argument, as described in `?Summary`

# ------------------------------------------------------------------------------

# - vctrs:::min.vctrs_vctr() does the right thing because of `xtfrm.vctrs_vctr()`
# - vctrs:::max.vctrs_vctr() does the right thing because of `xtfrm.vctrs_vctr()`

# However, vctrs only automatically implements `xtfrm()` for integer and double
# arrays, so logical ones need special treatment. Because of this, just
# implement a simple xtfrm() method that calls `vec_proxy_compare()` like vctrs

#' @export
xtfrm.vctrs_rray <- function(x) {
vec_proxy_compare(x)
}

#' @export
xtfrm.vctrs_rray_lgl <- function(x) {
rray_cast_inner(vec_proxy_compare(x), integer())
}

# ------------------------------------------------------------------------------

# This is a base R compatible version of `all()` and `any()`.
# It is used in vec_math() dispatch

# Note that `vctrs:::Summary.vctrs_vctr()` is how this is passed through,
# and `na.rm = TRUE` no matter what there!

rray_all_vctrs_wrapper <- function(x, na.rm) {
vec_math_base("all", vec_data(x), na.rm = na.rm)
}

rray_any_vctrs_wrapper <- function(x, na.rm) {
vec_math_base("any", vec_data(x), na.rm = na.rm)
}

# ------------------------------------------------------------------------------

rray_range_vctrs_wrapper <- function(x, na.rm) {
vec_math_base("range", vec_data(x), na.rm = na.rm)
}

# ------------------------------------------------------------------------------

rray_prod_vctrs_wrapper <- function(x, na.rm) {
vec_math_base("prod", vec_data(x), na.rm = na.rm)
}

rray_sum_vctrs_wrapper <- function(x, na.rm) {
vec_math_base("sum", vec_data(x), na.rm = na.rm)
}
15 changes: 3 additions & 12 deletions man/rray-logical.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

166 changes: 166 additions & 0 deletions tests/testthat/test-summary-group.R
@@ -0,0 +1,166 @@
# ------------------------------------------------------------------------------
context("test-xtfrm")

test_that("xtfrm() returns proxy objects", {
x <- rray(1:2, c(2, 1))
expect_equal(xtfrm(x), vec_data(x))

x <- rray(c(1, 2), c(2, 1))
expect_equal(xtfrm(x), vec_data(x))
})

test_that("xtfrm() works for 3D", {
x <- rray(1:6, c(2, 1, 3))
expect_equal(xtfrm(x), vec_data(x))
})

test_that("xtfrm() for logicals returns integers", {
x <- rray(c(TRUE, FALSE), c(2, 1))
expect_equal(xtfrm(x), new_matrix(c(1, 0), c(2, 1)))
})

# ------------------------------------------------------------------------------
context("test-min")

test_that("`min()` returns a length 1 vector for 1D", {
expect_equal(min(rray(5:1)), rray(1L))
expect_equal(min(rray(5:1 + 0)), rray(1))
})

test_that("`min()` returns a length 1 vector for 2D", {
x <- rray(c(2, 4, 5, 2), c(2, 2))
expect_equal(
min(x),
rray(2)
)
})

test_that("`min()` returns a length 1 vector for 3D", {
x <- rray(c(2, 4, 5, 2), c(2, 1, 2))
expect_equal(
min(x),
rray(2)
)
})

test_that("vctrs `min()` ignores input in `...`", {
expect_equal(min(rray(2), 1), rray(2))
})

# TODO - Add tests after this is fixed
# https://github.com/r-lib/vctrs/pull/329

# test_that("NAs are removed", {
# min(rray(c(NA, 2)), na.rm = TRUE)
# })

# ------------------------------------------------------------------------------
context("test-max")

test_that("`max()` returns a length 1 vector for 1D", {
expect_equal(max(rray(5:1)), rray(5L))
expect_equal(max(rray(5:1 + 0)), rray(5))
})

test_that("`max()` returns a length 1 vector for 2D", {
x <- rray(c(2, 4, 5, 2), c(2, 2))
expect_equal(
max(x),
rray(5)
)
})

test_that("`max()` returns a length 1 vector for 3D", {
x <- rray(c(2, 4, 5, 2), c(2, 1, 2))
expect_equal(
max(x),
rray(5)
)
})

test_that("vctrs `max()` ignores input in `...`", {
expect_equal(max(rray(2), 1), rray(2))
})

# TODO - Add tests after this is fixed
# https://github.com/r-lib/vctrs/pull/329

# test_that("NAs are removed", {
# max(rray(c(NA, 2)), na.rm = TRUE)
# })

# ------------------------------------------------------------------------------
context("test-base-any")

test_that("returns a single value with shaped arrays", {
expect_equal(any(rray(c(TRUE, FALSE), c(2, 2))), TRUE)
expect_equal(any(rray(c(FALSE, FALSE), c(2, 2))), FALSE)
})

test_that("always uses `na.rm = TRUE`", {
expect_equal(any(rray(c(NA, 1L)), na.rm = FALSE), TRUE)
expect_equal(any(rray(c(NA, 0L)), na.rm = FALSE), FALSE)
})

# ------------------------------------------------------------------------------
context("test-base-all")

test_that("returns a single value with shaped arrays", {
expect_equal(all(rray(c(TRUE, FALSE), c(2, 2))), FALSE)
expect_equal(all(rray(c(TRUE, TRUE), c(2, 2))), TRUE)
})

test_that("always uses `na.rm = TRUE`", {
expect_equal(all(rray(c(NA, 1L, 0L)), na.rm = FALSE), FALSE)
expect_equal(all(rray(c(NA, 1L, 1L)), na.rm = FALSE), TRUE)
})

# ------------------------------------------------------------------------------
context("test-base-range")

test_that("returns same values as base R", {
x <- rray(c(TRUE, FALSE), c(2, 2))
expect_equal(range(x), range(vec_data(x)))
expect_equal(range(x, x), range(vec_data(x), vec_data(x)))
expect_equal(range(x, 5), range(vec_data(x), 5))
})

test_that("always uses `na.rm = TRUE`", {
expect_equal(range(rray(c(NA, 1L)), na.rm = FALSE), c(1L, 1L))
})

# ------------------------------------------------------------------------------
context("test-base-prod")

test_that("returns same values as base R", {
x <- rray(c(5, 6), c(2, 2))
expect_equal(prod(x), prod(vec_data(x)))
expect_equal(prod(x, x), prod(vec_data(x), vec_data(x)))
})

test_that("broadcasts input using vctrs", {
x <- rray(c(5, 6), c(2, 2))
expect_equal(prod(x, 5), prod(x, matrix(5, c(1, 2))))
})

test_that("always uses `na.rm = TRUE`", {
expect_equal(prod(rray(c(NA, 1L)), na.rm = FALSE), 1)
})

# ------------------------------------------------------------------------------
context("test-base-sum")

test_that("returns same values as base R", {
x <- rray(c(5, 6), c(2, 2))
expect_equal(sum(x), sum(vec_data(x)))
expect_equal(sum(x, x), sum(vec_data(x), vec_data(x)))
})

test_that("broadcasts input using vctrs", {
x <- rray(c(5, 6), c(2, 2))
expect_equal(sum(x, 5), sum(x, matrix(5, c(1, 2))))
})

test_that("always uses `na.rm = TRUE`", {
expect_equal(sum(rray(c(NA, 1L)), na.rm = FALSE), 1)
})