Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cache vec_proxy results as stopgap for r-lib/vctrs#1411 #179

Merged
merged 1 commit into from
Jul 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion R/as_draws_rvars.R
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ as_draws_rvars.draws_matrix <- function(x, ...) {
out <- .as_draws_rvars(rvars_list, ...)
.nchains <- nchains(x)
for (i in seq_along(out)) {
attr(out[[i]], "nchains") <- .nchains
nchains_rvar(out[[i]]) <- .nchains
}
out
}
Expand Down
8 changes: 8 additions & 0 deletions R/draws-index.R
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,14 @@ nchains.draws_rvars <- function(x) {
nchains.rvar <- function(x) {
attr(x, "nchains") %||% 1L
}
# for internal use only currently: if you are setting the nchains
# attribute on an rvar, ALWAYS use this function so that the proxy
# cache is invalidated
`nchains_rvar<-` <- function(x, value) {
attr(x, "nchains") <- value
invalidate_rvar_cache(x)
}


#' @rdname draws-index
#' @export
Expand Down
2 changes: 1 addition & 1 deletion R/merge_chains.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ merge_chains.draws_list <- function(x, ...) {
#' @rdname merge_chains
#' @export
merge_chains.rvar <- function(x, ...) {
attr(x, "nchains") <- 1L
nchains_rvar(x) <- 1L
x
}

Expand Down
16 changes: 12 additions & 4 deletions R/rvar-.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ new_rvar <- function(x = double(), .nchains = 1L) {
list(),
draws = x,
nchains = .nchains,
class = c("rvar", "vctrs_vctr", "list")
class = c("rvar", "vctrs_vctr", "list"),
cache = new.env(parent = emptyenv())
)
}

Expand Down Expand Up @@ -207,12 +208,13 @@ draws_of <- function(x, with_chains = FALSE) {

if (with_chains) {
draws <- drop_chain_dim(value)
attr(x, "nchains") <- dim(value)[[2]] %||% 1L
nchains_rvar(x) <- dim(value)[[2]] %||% 1L
} else {
draws <- value
}
attr(x, "draws") <- cleanup_draw_dims(draws)

x <- invalidate_rvar_cache(x)
x
}

Expand Down Expand Up @@ -302,7 +304,13 @@ all.equal.rvar <- function(target, current, ...) {
))
}

object_result <- all.equal(unclass(target), unclass(current), ...)
# ignore cache in comparison
.target <- unclass(target)
attr(.target, "cache") <- NULL
.current <- unclass(current)
attr(.current, "cache") <- NULL

object_result <- all.equal(.target, .current, ...)
if (!isTRUE(object_result)) {
result = c(result, object_result)
}
Expand Down Expand Up @@ -392,7 +400,7 @@ conform_rvar_nchains <- function(rvars) {
.nchains <- Reduce(nchains2_common, nchains_or_null) %||% 1L

for (i in seq_along(rvars)) {
attr(rvars[[i]], "nchains") <- .nchains
nchains_rvar(rvars[[i]]) <- .nchains
}

rvars
Expand Down
31 changes: 25 additions & 6 deletions R/rvar-cast.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ as_rvar <- function(x, dim = NULL, dimnames = NULL, nchains = NULL) {
.ndraws <- ndraws(out)
nchains <- as_one_integer(nchains)
check_nchains_compat_with_ndraws(nchains, .ndraws)
attr(out, "nchains") <- nchains
nchains_rvar(out) <- nchains
}

out
Expand Down Expand Up @@ -131,15 +131,29 @@ as_tibble.rvar <- function(x, ...) {

# vctrs proxy / restore --------------------------------------------------------

invalidate_rvar_cache = function(x) {
attr(x, "cache") <- new.env(parent = emptyenv())
x
}

#' @importFrom vctrs vec_proxy vec_chop
#' @export
vec_proxy.rvar = function(x, ...) {
# TODO: probably could do something more efficient here and for restore
.draws = draws_of(x)
out <- vec_chop(aperm(.draws, c(2, 1, seq_along(dim(.draws))[c(-1,-2)])))
for (i in seq_along(out)) {
attr(out[[i]], "nchains") <- nchains(x)
# In the meantime, using caching to help with algorithms that call vec_proxy
# repeatedly. See https://github.com/r-lib/vctrs/issues/1411

out <- attr(x, "cache")$vec_proxy
if (is.null(out)) {
# proxy is not in the cache, calculate it and store it in the cache
.draws = draws_of(x)
out <- vec_chop(aperm(.draws, c(2, 1, seq_along(dim(.draws))[c(-1,-2)])))
for (i in seq_along(out)) {
attr(out[[i]], "nchains") <- nchains(x)
}
attr(x, "cache")$vec_proxy <- out
}

out
}

Expand Down Expand Up @@ -173,7 +187,12 @@ vec_restore.rvar <- function(x, ...) {
nchains_or_null <- lapply(x, function(x) if (dim(x)[[2]] %||% 1 == 1) NULL else attr(x, "nchains"))
.nchains <- Reduce(nchains2_common, nchains_or_null) %||% 1L

new_rvar(.draws, .nchains = .nchains)
out <- new_rvar(.draws, .nchains = .nchains)

# since we've already spent time calculating it, save the proxy in the cache
attr(out, "cache")$vec_proxy <- x

out
}


Expand Down
2 changes: 1 addition & 1 deletion R/rvar-print.R
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ str.rvar <- function(
}
}
str_attr(attributes(draws_of(object)), "draws_of(*)", c("names", "dim", "dimnames", "class"))
str_attr(attributes(object), "*", c("draws", "names", "dim", "dimnames", "class", "nchains"))
str_attr(attributes(object), "*", c("draws", "names", "dim", "dimnames", "class", "nchains", "cache"))
}

invisible(NULL)
Expand Down
2 changes: 1 addition & 1 deletion R/subset_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ subset_dims <- function(x, ...) {
slice_index <- chain_ids %in% chain
for (i in seq_along(x)) {
draws_of(x[[i]]) <- vec_slice(draws_of(x[[i]]), slice_index)
attr(x[[i]], "nchains") <- nchains
nchains_rvar(x[[i]]) <- nchains
}
}
if (!is.null(iteration)) {
Expand Down
6 changes: 3 additions & 3 deletions tests/testthat/test-rvar-.R
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ test_that("rvars work in tibbles", {
x = rvar_from_array(x_array)
df = tibble::tibble(x, y = x + 1)

expect_identical(df$x, x)
expect_identical(df$y, rvar_from_array(x_array + 1))
expect_identical(dplyr::mutate(df, z = x)$z, x)
expect_equal(df$x, x)
expect_equal(df$y, rvar_from_array(x_array + 1))
expect_equal(dplyr::mutate(df, z = x)$z, x)

expect_equal(dplyr::mutate(df, z = x * 2)$z, rvar_from_array(x_array * 2))
expect_equal(
Expand Down
90 changes: 45 additions & 45 deletions tests/testthat/test-rvar-math.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,29 @@ test_that("math operators works", {
y_array = array(c(2:13,12:1), dim = c(4,2,3))
y = new_rvar(y_array)

expect_identical(log(x), new_rvar(log(x_array)))
expect_equal(log(x), new_rvar(log(x_array)))

expect_identical(-x, new_rvar(-x_array))
expect_equal(-x, new_rvar(-x_array))

expect_identical(x + 2, new_rvar(x_array + 2))
expect_identical(2 + x, new_rvar(x_array + 2))
expect_identical(x + y, new_rvar(x_array + y_array))
expect_equal(x + 2, new_rvar(x_array + 2))
expect_equal(2 + x, new_rvar(x_array + 2))
expect_equal(x + y, new_rvar(x_array + y_array))

expect_identical(x - 2, new_rvar(x_array - 2))
expect_identical(2 - x, new_rvar(2 - x_array))
expect_identical(x - y, new_rvar(x_array - y_array))
expect_equal(x - 2, new_rvar(x_array - 2))
expect_equal(2 - x, new_rvar(2 - x_array))
expect_equal(x - y, new_rvar(x_array - y_array))

expect_identical(x * 2, new_rvar(x_array * 2))
expect_identical(2 * x, new_rvar(x_array * 2))
expect_identical(x * y, new_rvar(x_array * y_array))
expect_equal(x * 2, new_rvar(x_array * 2))
expect_equal(2 * x, new_rvar(x_array * 2))
expect_equal(x * y, new_rvar(x_array * y_array))

expect_identical(x / 2, new_rvar(x_array / 2))
expect_identical(2 / x, new_rvar(2 / x_array))
expect_identical(x / y, new_rvar(x_array / y_array))
expect_equal(x / 2, new_rvar(x_array / 2))
expect_equal(2 / x, new_rvar(2 / x_array))
expect_equal(x / y, new_rvar(x_array / y_array))

expect_identical(x ^ 2, new_rvar((x_array) ^ 2))
expect_identical(2 ^ x, new_rvar(2 ^ (x_array)))
expect_identical(x ^ y, new_rvar(x_array ^ y_array))
expect_equal(x ^ 2, new_rvar((x_array) ^ 2))
expect_equal(2 ^ x, new_rvar(2 ^ (x_array)))
expect_equal(x ^ y, new_rvar(x_array ^ y_array))

# ensure broadcasting of constants retains shape
z2 <- new_rvar(array(1, dim = c(1,1)))
Expand All @@ -42,13 +42,13 @@ test_that("logical operators work", {
x = as_rvar(x_array)
y = as_rvar(y_array)

expect_identical(x | y_array, as_rvar(x_array | y_array))
expect_identical(y_array | x, as_rvar(x_array | y_array))
expect_identical(x | y, as_rvar(x_array | y_array))
expect_equal(x | y_array, as_rvar(x_array | y_array))
expect_equal(y_array | x, as_rvar(x_array | y_array))
expect_equal(x | y, as_rvar(x_array | y_array))

expect_identical(x & y_array, as_rvar(x_array & y_array))
expect_identical(y_array & x, as_rvar(x_array & y_array))
expect_identical(x & y, as_rvar(x_array & y_array))
expect_equal(x & y_array, as_rvar(x_array & y_array))
expect_equal(y_array & x, as_rvar(x_array & y_array))
expect_equal(x & y, as_rvar(x_array & y_array))
})

test_that("comparison operators work", {
Expand All @@ -57,29 +57,29 @@ test_that("comparison operators work", {
y_array = array(c(2:13,12:1), dim = c(4,2,3))
y = new_rvar(y_array)

expect_identical(x < 5, new_rvar(x_array < 5))
expect_identical(5 < x, new_rvar(5 < x_array))
expect_identical(x < y, new_rvar(x_array < y_array))
expect_equal(x < 5, new_rvar(x_array < 5))
expect_equal(5 < x, new_rvar(5 < x_array))
expect_equal(x < y, new_rvar(x_array < y_array))

expect_identical(x <= 5, new_rvar(x_array <= 5))
expect_identical(5 <= x, new_rvar(5 <= x_array))
expect_identical(x <= y, new_rvar(x_array <= y_array))
expect_equal(x <= 5, new_rvar(x_array <= 5))
expect_equal(5 <= x, new_rvar(5 <= x_array))
expect_equal(x <= y, new_rvar(x_array <= y_array))

expect_identical(x > 5, new_rvar(x_array > 5))
expect_identical(5 > x, new_rvar(5 > x_array))
expect_identical(x > y, new_rvar(x_array > y_array))
expect_equal(x > 5, new_rvar(x_array > 5))
expect_equal(5 > x, new_rvar(5 > x_array))
expect_equal(x > y, new_rvar(x_array > y_array))

expect_identical(x >= 5, new_rvar(x_array >= 5))
expect_identical(5 >= x, new_rvar(5 >= x_array))
expect_identical(x >= y, new_rvar(x_array >= y_array))
expect_equal(x >= 5, new_rvar(x_array >= 5))
expect_equal(5 >= x, new_rvar(5 >= x_array))
expect_equal(x >= y, new_rvar(x_array >= y_array))

expect_identical(x == 5, new_rvar(x_array == 5))
expect_identical(5 == x, new_rvar(5 == x_array))
expect_identical(x == y, new_rvar(x_array == y_array))
expect_equal(x == 5, new_rvar(x_array == 5))
expect_equal(5 == x, new_rvar(5 == x_array))
expect_equal(x == y, new_rvar(x_array == y_array))

expect_identical(x != 5, new_rvar(x_array != 5))
expect_identical(5 != x, new_rvar(5 != x_array))
expect_identical(x != y, new_rvar(x_array != y_array))
expect_equal(x != 5, new_rvar(x_array != 5))
expect_equal(5 != x, new_rvar(5 != x_array))
expect_equal(x != y, new_rvar(x_array != y_array))
})

test_that("functions in the Math generic with extra arguments work", {
Expand Down Expand Up @@ -131,7 +131,7 @@ test_that("matrix multiplication works", {
x_array[3,,] %*% y_array[3,,],
x_array[4,,] %*% y_array[4,,]
))
expect_identical(x %**% y, xy_ref)
expect_equal(x %**% y, xy_ref)


x_array = array(1:6, dim = c(2,3))
Expand All @@ -143,20 +143,20 @@ test_that("matrix multiplication works", {
x_array[1,] %*% y_array[1,],
x_array[2,] %*% y_array[2,]
))
expect_identical(x %**% y, xy_ref)
expect_equal(x %**% y, xy_ref)

# automatic promotion to row/col vector of numeric vectors
x_meany_ref = new_rvar(abind::abind(along = 0,
x_array[1,] %*% colMeans(y_array),
x_array[2,] %*% colMeans(y_array)
))
expect_identical(x %**% colMeans(y_array), x_meany_ref)
expect_equal(x %**% colMeans(y_array), x_meany_ref)

meanx_y_ref = new_rvar(abind::abind(along = 0,
colMeans(x_array) %*% y_array[1,],
colMeans(x_array) %*% y_array[2,]
))
expect_identical(colMeans(x_array) %**% y, meanx_y_ref)
expect_equal(colMeans(x_array) %**% y, meanx_y_ref)

# dimension name preservation
m1 <- as_rvar(diag(1:3))
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test-rvar-rfun.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ test_that("rdo works", {
x_array[3,,] %*% y_array[3,,],
x_array[4,,] %*% y_array[4,,]
))
expect_identical(rdo(x %*% y), xy_ref)
expect_equal(rdo(x %*% y), xy_ref)
})

test_that("rfun works", {
Expand All @@ -25,7 +25,7 @@ test_that("rfun works", {
x_array[3,,] %*% y_array[3,,],
x_array[4,,] %*% y_array[4,,]
))
expect_identical(rfun(function(a,b) a %*% b)(x, y), xy_ref)
expect_equal(rfun(function(a,b) a %*% b)(x, y), xy_ref)
})

test_that("rvar_rng works", {
Expand Down
Loading