Skip to content

Commit

Permalink
Merge pull request #179 from stan-dev/rvar-proxy-cache
Browse files Browse the repository at this point in the history
cache vec_proxy results as stopgap for r-lib/vctrs#1411
  • Loading branch information
paul-buerkner committed Jul 14, 2021
2 parents e73b53e + c5036ff commit befdfbd
Show file tree
Hide file tree
Showing 11 changed files with 160 additions and 125 deletions.
2 changes: 1 addition & 1 deletion R/as_draws_rvars.R
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ as_draws_rvars.draws_matrix <- function(x, ...) {
out <- .as_draws_rvars(rvars_list, ...)
.nchains <- nchains(x)
for (i in seq_along(out)) {
attr(out[[i]], "nchains") <- .nchains
nchains_rvar(out[[i]]) <- .nchains
}
out
}
Expand Down
8 changes: 8 additions & 0 deletions R/draws-index.R
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,14 @@ nchains.draws_rvars <- function(x) {
nchains.rvar <- function(x) {
attr(x, "nchains") %||% 1L
}
# for internal use only currently: if you are setting the nchains
# attribute on an rvar, ALWAYS use this function so that the proxy
# cache is invalidated
`nchains_rvar<-` <- function(x, value) {
attr(x, "nchains") <- value
invalidate_rvar_cache(x)
}


#' @rdname draws-index
#' @export
Expand Down
2 changes: 1 addition & 1 deletion R/merge_chains.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ merge_chains.draws_list <- function(x, ...) {
#' @rdname merge_chains
#' @export
merge_chains.rvar <- function(x, ...) {
attr(x, "nchains") <- 1L
nchains_rvar(x) <- 1L
x
}

Expand Down
16 changes: 12 additions & 4 deletions R/rvar-.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ new_rvar <- function(x = double(), .nchains = 1L) {
list(),
draws = x,
nchains = .nchains,
class = c("rvar", "vctrs_vctr", "list")
class = c("rvar", "vctrs_vctr", "list"),
cache = new.env(parent = emptyenv())
)
}

Expand Down Expand Up @@ -207,12 +208,13 @@ draws_of <- function(x, with_chains = FALSE) {

if (with_chains) {
draws <- drop_chain_dim(value)
attr(x, "nchains") <- dim(value)[[2]] %||% 1L
nchains_rvar(x) <- dim(value)[[2]] %||% 1L
} else {
draws <- value
}
attr(x, "draws") <- cleanup_draw_dims(draws)

x <- invalidate_rvar_cache(x)
x
}

Expand Down Expand Up @@ -302,7 +304,13 @@ all.equal.rvar <- function(target, current, ...) {
))
}

object_result <- all.equal(unclass(target), unclass(current), ...)
# ignore cache in comparison
.target <- unclass(target)
attr(.target, "cache") <- NULL
.current <- unclass(current)
attr(.current, "cache") <- NULL

object_result <- all.equal(.target, .current, ...)
if (!isTRUE(object_result)) {
result = c(result, object_result)
}
Expand Down Expand Up @@ -392,7 +400,7 @@ conform_rvar_nchains <- function(rvars) {
.nchains <- Reduce(nchains2_common, nchains_or_null) %||% 1L

for (i in seq_along(rvars)) {
attr(rvars[[i]], "nchains") <- .nchains
nchains_rvar(rvars[[i]]) <- .nchains
}

rvars
Expand Down
31 changes: 25 additions & 6 deletions R/rvar-cast.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ as_rvar <- function(x, dim = NULL, dimnames = NULL, nchains = NULL) {
.ndraws <- ndraws(out)
nchains <- as_one_integer(nchains)
check_nchains_compat_with_ndraws(nchains, .ndraws)
attr(out, "nchains") <- nchains
nchains_rvar(out) <- nchains
}

out
Expand Down Expand Up @@ -131,15 +131,29 @@ as_tibble.rvar <- function(x, ...) {

# vctrs proxy / restore --------------------------------------------------------

invalidate_rvar_cache = function(x) {
attr(x, "cache") <- new.env(parent = emptyenv())
x
}

#' @importFrom vctrs vec_proxy vec_chop
#' @export
vec_proxy.rvar = function(x, ...) {
# TODO: probably could do something more efficient here and for restore
.draws = draws_of(x)
out <- vec_chop(aperm(.draws, c(2, 1, seq_along(dim(.draws))[c(-1,-2)])))
for (i in seq_along(out)) {
attr(out[[i]], "nchains") <- nchains(x)
# In the meantime, using caching to help with algorithms that call vec_proxy
# repeatedly. See https://github.com/r-lib/vctrs/issues/1411

out <- attr(x, "cache")$vec_proxy
if (is.null(out)) {
# proxy is not in the cache, calculate it and store it in the cache
.draws = draws_of(x)
out <- vec_chop(aperm(.draws, c(2, 1, seq_along(dim(.draws))[c(-1,-2)])))
for (i in seq_along(out)) {
attr(out[[i]], "nchains") <- nchains(x)
}
attr(x, "cache")$vec_proxy <- out
}

out
}

Expand Down Expand Up @@ -173,7 +187,12 @@ vec_restore.rvar <- function(x, ...) {
nchains_or_null <- lapply(x, function(x) if (dim(x)[[2]] %||% 1 == 1) NULL else attr(x, "nchains"))
.nchains <- Reduce(nchains2_common, nchains_or_null) %||% 1L

new_rvar(.draws, .nchains = .nchains)
out <- new_rvar(.draws, .nchains = .nchains)

# since we've already spent time calculating it, save the proxy in the cache
attr(out, "cache")$vec_proxy <- x

out
}


Expand Down
2 changes: 1 addition & 1 deletion R/rvar-print.R
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ str.rvar <- function(
}
}
str_attr(attributes(draws_of(object)), "draws_of(*)", c("names", "dim", "dimnames", "class"))
str_attr(attributes(object), "*", c("draws", "names", "dim", "dimnames", "class", "nchains"))
str_attr(attributes(object), "*", c("draws", "names", "dim", "dimnames", "class", "nchains", "cache"))
}

invisible(NULL)
Expand Down
2 changes: 1 addition & 1 deletion R/subset_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ subset_dims <- function(x, ...) {
slice_index <- chain_ids %in% chain
for (i in seq_along(x)) {
draws_of(x[[i]]) <- vec_slice(draws_of(x[[i]]), slice_index)
attr(x[[i]], "nchains") <- nchains
nchains_rvar(x[[i]]) <- nchains
}
}
if (!is.null(iteration)) {
Expand Down
6 changes: 3 additions & 3 deletions tests/testthat/test-rvar-.R
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ test_that("rvars work in tibbles", {
x = rvar_from_array(x_array)
df = tibble::tibble(x, y = x + 1)

expect_identical(df$x, x)
expect_identical(df$y, rvar_from_array(x_array + 1))
expect_identical(dplyr::mutate(df, z = x)$z, x)
expect_equal(df$x, x)
expect_equal(df$y, rvar_from_array(x_array + 1))
expect_equal(dplyr::mutate(df, z = x)$z, x)

expect_equal(dplyr::mutate(df, z = x * 2)$z, rvar_from_array(x_array * 2))
expect_equal(
Expand Down
90 changes: 45 additions & 45 deletions tests/testthat/test-rvar-math.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,29 @@ test_that("math operators works", {
y_array = array(c(2:13,12:1), dim = c(4,2,3))
y = new_rvar(y_array)

expect_identical(log(x), new_rvar(log(x_array)))
expect_equal(log(x), new_rvar(log(x_array)))

expect_identical(-x, new_rvar(-x_array))
expect_equal(-x, new_rvar(-x_array))

expect_identical(x + 2, new_rvar(x_array + 2))
expect_identical(2 + x, new_rvar(x_array + 2))
expect_identical(x + y, new_rvar(x_array + y_array))
expect_equal(x + 2, new_rvar(x_array + 2))
expect_equal(2 + x, new_rvar(x_array + 2))
expect_equal(x + y, new_rvar(x_array + y_array))

expect_identical(x - 2, new_rvar(x_array - 2))
expect_identical(2 - x, new_rvar(2 - x_array))
expect_identical(x - y, new_rvar(x_array - y_array))
expect_equal(x - 2, new_rvar(x_array - 2))
expect_equal(2 - x, new_rvar(2 - x_array))
expect_equal(x - y, new_rvar(x_array - y_array))

expect_identical(x * 2, new_rvar(x_array * 2))
expect_identical(2 * x, new_rvar(x_array * 2))
expect_identical(x * y, new_rvar(x_array * y_array))
expect_equal(x * 2, new_rvar(x_array * 2))
expect_equal(2 * x, new_rvar(x_array * 2))
expect_equal(x * y, new_rvar(x_array * y_array))

expect_identical(x / 2, new_rvar(x_array / 2))
expect_identical(2 / x, new_rvar(2 / x_array))
expect_identical(x / y, new_rvar(x_array / y_array))
expect_equal(x / 2, new_rvar(x_array / 2))
expect_equal(2 / x, new_rvar(2 / x_array))
expect_equal(x / y, new_rvar(x_array / y_array))

expect_identical(x ^ 2, new_rvar((x_array) ^ 2))
expect_identical(2 ^ x, new_rvar(2 ^ (x_array)))
expect_identical(x ^ y, new_rvar(x_array ^ y_array))
expect_equal(x ^ 2, new_rvar((x_array) ^ 2))
expect_equal(2 ^ x, new_rvar(2 ^ (x_array)))
expect_equal(x ^ y, new_rvar(x_array ^ y_array))

# ensure broadcasting of constants retains shape
z2 <- new_rvar(array(1, dim = c(1,1)))
Expand All @@ -42,13 +42,13 @@ test_that("logical operators work", {
x = as_rvar(x_array)
y = as_rvar(y_array)

expect_identical(x | y_array, as_rvar(x_array | y_array))
expect_identical(y_array | x, as_rvar(x_array | y_array))
expect_identical(x | y, as_rvar(x_array | y_array))
expect_equal(x | y_array, as_rvar(x_array | y_array))
expect_equal(y_array | x, as_rvar(x_array | y_array))
expect_equal(x | y, as_rvar(x_array | y_array))

expect_identical(x & y_array, as_rvar(x_array & y_array))
expect_identical(y_array & x, as_rvar(x_array & y_array))
expect_identical(x & y, as_rvar(x_array & y_array))
expect_equal(x & y_array, as_rvar(x_array & y_array))
expect_equal(y_array & x, as_rvar(x_array & y_array))
expect_equal(x & y, as_rvar(x_array & y_array))
})

test_that("comparison operators work", {
Expand All @@ -57,29 +57,29 @@ test_that("comparison operators work", {
y_array = array(c(2:13,12:1), dim = c(4,2,3))
y = new_rvar(y_array)

expect_identical(x < 5, new_rvar(x_array < 5))
expect_identical(5 < x, new_rvar(5 < x_array))
expect_identical(x < y, new_rvar(x_array < y_array))
expect_equal(x < 5, new_rvar(x_array < 5))
expect_equal(5 < x, new_rvar(5 < x_array))
expect_equal(x < y, new_rvar(x_array < y_array))

expect_identical(x <= 5, new_rvar(x_array <= 5))
expect_identical(5 <= x, new_rvar(5 <= x_array))
expect_identical(x <= y, new_rvar(x_array <= y_array))
expect_equal(x <= 5, new_rvar(x_array <= 5))
expect_equal(5 <= x, new_rvar(5 <= x_array))
expect_equal(x <= y, new_rvar(x_array <= y_array))

expect_identical(x > 5, new_rvar(x_array > 5))
expect_identical(5 > x, new_rvar(5 > x_array))
expect_identical(x > y, new_rvar(x_array > y_array))
expect_equal(x > 5, new_rvar(x_array > 5))
expect_equal(5 > x, new_rvar(5 > x_array))
expect_equal(x > y, new_rvar(x_array > y_array))

expect_identical(x >= 5, new_rvar(x_array >= 5))
expect_identical(5 >= x, new_rvar(5 >= x_array))
expect_identical(x >= y, new_rvar(x_array >= y_array))
expect_equal(x >= 5, new_rvar(x_array >= 5))
expect_equal(5 >= x, new_rvar(5 >= x_array))
expect_equal(x >= y, new_rvar(x_array >= y_array))

expect_identical(x == 5, new_rvar(x_array == 5))
expect_identical(5 == x, new_rvar(5 == x_array))
expect_identical(x == y, new_rvar(x_array == y_array))
expect_equal(x == 5, new_rvar(x_array == 5))
expect_equal(5 == x, new_rvar(5 == x_array))
expect_equal(x == y, new_rvar(x_array == y_array))

expect_identical(x != 5, new_rvar(x_array != 5))
expect_identical(5 != x, new_rvar(5 != x_array))
expect_identical(x != y, new_rvar(x_array != y_array))
expect_equal(x != 5, new_rvar(x_array != 5))
expect_equal(5 != x, new_rvar(5 != x_array))
expect_equal(x != y, new_rvar(x_array != y_array))
})

test_that("functions in the Math generic with extra arguments work", {
Expand Down Expand Up @@ -131,7 +131,7 @@ test_that("matrix multiplication works", {
x_array[3,,] %*% y_array[3,,],
x_array[4,,] %*% y_array[4,,]
))
expect_identical(x %**% y, xy_ref)
expect_equal(x %**% y, xy_ref)


x_array = array(1:6, dim = c(2,3))
Expand All @@ -143,20 +143,20 @@ test_that("matrix multiplication works", {
x_array[1,] %*% y_array[1,],
x_array[2,] %*% y_array[2,]
))
expect_identical(x %**% y, xy_ref)
expect_equal(x %**% y, xy_ref)

# automatic promotion to row/col vector of numeric vectors
x_meany_ref = new_rvar(abind::abind(along = 0,
x_array[1,] %*% colMeans(y_array),
x_array[2,] %*% colMeans(y_array)
))
expect_identical(x %**% colMeans(y_array), x_meany_ref)
expect_equal(x %**% colMeans(y_array), x_meany_ref)

meanx_y_ref = new_rvar(abind::abind(along = 0,
colMeans(x_array) %*% y_array[1,],
colMeans(x_array) %*% y_array[2,]
))
expect_identical(colMeans(x_array) %**% y, meanx_y_ref)
expect_equal(colMeans(x_array) %**% y, meanx_y_ref)

# dimension name preservation
m1 <- as_rvar(diag(1:3))
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test-rvar-rfun.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ test_that("rdo works", {
x_array[3,,] %*% y_array[3,,],
x_array[4,,] %*% y_array[4,,]
))
expect_identical(rdo(x %*% y), xy_ref)
expect_equal(rdo(x %*% y), xy_ref)
})

test_that("rfun works", {
Expand All @@ -25,7 +25,7 @@ test_that("rfun works", {
x_array[3,,] %*% y_array[3,,],
x_array[4,,] %*% y_array[4,,]
))
expect_identical(rfun(function(a,b) a %*% b)(x, y), xy_ref)
expect_equal(rfun(function(a,b) a %*% b)(x, y), xy_ref)
})

test_that("rvar_rng works", {
Expand Down
Loading

0 comments on commit befdfbd

Please sign in to comment.