Skip to content

Commit

Permalink
Faster multi_scale (#34)
Browse files Browse the repository at this point in the history
* Make ww_multi_scale() faster

* Update snaps

* Only import necessary functions

* Improve documentation

* Fix NOTE
  • Loading branch information
mikemahoney218 committed Apr 24, 2023
1 parent d5c1b07 commit 4010b96
Show file tree
Hide file tree
Showing 35 changed files with 1,279 additions and 792 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Imports:
hardhat,
Matrix,
purrr,
rlang,
rlang (>= 1.1.0),
sf (>= 1.0-0),
spdep (>= 1.1-9),
stats,
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
`data`. This is a bit faster than passing `SpatRaster` objects to `truth` and
`estimate`, as extraction is only done once per grid rather than twice, but
does not easily support passing R functions to `aggregation_function`.

* The `sf` method for `ww_multi_scale()` is now _much_ faster (and more memory
efficient).

# waywiser 0.3.0

Expand Down
141 changes: 79 additions & 62 deletions R/multi_scale.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
#'
#' If `data` is `NULL`, then `truth` and `estimate` should both be `SpatRaster`
#' objects, as created via [terra::rast()]. These rasters will then be
#' aggregated to each grid using [exactextractr::exact_extract()].
#' aggregated to each grid using [exactextractr::exact_extract()]. If `data`
#' is a `SpatRaster` object, then `truth` and `estimate` should be indices to
#' select the appropriate layers of the raster via [terra::subset()].
#'
#' Grids are calculated using the bounding box of `truth`, under the assumption
#' that you may have extrapolated into regions which do not have matching "true"
Expand All @@ -31,6 +33,14 @@
#' then `cellsize` will be automatically adjusted to create the requested
#' number of cells.
#'
#' Grids are created by mapping over each argument passed via `...`
#' simultaneously, in a similar manner to [mapply()] or [purrr::pmap()]. This
#' means that, for example, passing `n = list(c(1, 2))` will create a single
#' 1x2 grid, while passing `n = c(1, 2)` will create a 1x1 grid _and_ a 2x2
#' grid. It also means that arguments will be recycled using R's standard
#' vector recycling rules, so that passing `n = c(1, 2)` and `square = FALSE`
#' will create two separate grids of hexagons.
#'
#' This function can be used for geographic or projected coordinate reference
#' systems and expects 2D data.
#'
Expand Down Expand Up @@ -380,93 +390,100 @@ ww_multi_scale.sf <- function(

grid_list <- handle_grids(data, grids, autoexpand_grid, ...)

grid_list$grid_intersections <- purrr::map(
data$.grid_idx <- seq_len(nrow(data))
out <- purrr::map2_dfr(
grid_list$grids,
function(grid) {
out <- sf::st_intersects(grid, data)
out[purrr::map_lgl(out, function(x) !identical(x, integer(0)))]
}
)
grid_list$grid_arg_idx,
function(grid, grid_args_idx) {
grid_args <- grid_list[["grid_args"]][grid_args_idx, ]

.notes <- purrr::map(
grid_list$grid_intersections,
function(idx) {
missing <- setdiff(seq_len(nrow(data)), unlist(idx))
grid <- sf::st_as_sf(grid)

note <- character(0)
data_crs <- sf::st_crs(data)
grid_crs <- sf::st_crs(grid)
# If both have a CRS, reproject
if (!is.na(data_crs) && !is.na(grid_crs)) {
grid <- sf::st_transform(grid, data_crs)
# if only data has CRS, assume grid in same
} else if (!is.na(data_crs)) {
grid <- sf::st_set_crs(grid, data_crs)
}
# if neither has a CRS, ignore (so, implicitly assume grid is in same)

grid$grid_cell_idx <- seq_len(nrow(grid))
grid_matches <- sf::st_join(
grid,
data[".grid_idx"],
left = FALSE
)
grid_matches <- sf::st_drop_geometry(grid_matches)

missing <- setdiff(data[[".grid_idx"]], grid_matches[[".grid_idx"]])
note <- character(0)
if (length(missing) > 0) {
note <- "Some observations were not within any grid cell, and as such were not used in any assessments. Their row numbers are in the `missing_indices` column."
missing <- list(missing)
} else {
missing <- list()
}

tibble::tibble(
notes_tibble <- tibble::tibble(
note = note,
missing_indices = missing
)
}
)

if (any(purrr::map_lgl(.notes, function(x) nrow(x) > 0))) {
rlang::warn(
c(
"Some observations were not within any grid cell, and as such were not used in any assessments.",
i = "See the `.notes` column for details."
matched_data <- dplyr::left_join(
data,
grid_matches,
by = dplyr::join_by(.grid_idx)
)
)
}

grid_list$grid_intersections <- purrr::map(
grid_list$grid_intersections,
function(idx_list) {
out <- purrr::map_dfr(
idx_list,
function(idx) {
dplyr::summarise(
data[idx, , drop = FALSE],
.truth = rlang::exec(.env[["aggregation_function"]], {{ truth }}),
.truth_count = sum(!is.na({{ truth }})),
.estimate = rlang::exec(.env[["aggregation_function"]], {{ estimate }}),
.estimate_count = sum(!is.na({{ estimate }})),
.groups = "keep"
)
}
matched_data <- sf::st_drop_geometry(matched_data)
matched_data <- matched_data[!is.na(matched_data[["grid_cell_idx"]]), ]
matched_data <- dplyr::group_by(
matched_data,
dplyr::across(dplyr::all_of(c(dplyr::group_vars(data), "grid_cell_idx")))
)
matched_data <- dplyr::summarise(
matched_data,
.truth = rlang::exec(.env[["aggregation_function"]], {{ truth }}),
.truth_count = sum(!is.na({{ truth }})),
.estimate = rlang::exec(.env[["aggregation_function"]], {{ estimate }}),
.estimate_count = sum(!is.na({{ estimate }})),
.groups = "drop"
)

if (dplyr::is_grouped_df(data)) {
dplyr::group_by(out, !!!dplyr::groups(data))
} else {
out
matched_data <- dplyr::group_by(matched_data, !!!dplyr::groups(data))
}
}
)

purrr::pmap_dfr(
list(
dat = grid_list$grid_intersections,
grid = grid_list$grids,
grid_arg = grid_list$grid_arg_idx,
.notes = .notes
),
function(dat, grid, grid_arg, .notes) {
out <- metrics(dat, .truth, .estimate, na_rm = na_rm)
out <- metrics(matched_data, .truth, .estimate, na_rm = na_rm)
out["grid_cell_idx"] <- NULL
out[attr(out, "sf_column")] <- NULL
out$.grid_args <- list(grid_list$grid_args[grid_arg, ])
out$.grid <- list(
suppressMessages( # We want to ignore a "names repair" message here
sf::st_join(
sf::st_as_sf(grid),
dat,
sf::st_contains
)
)
out$.grid_args <- list(grid_args)
.grid <- dplyr::left_join(
grid,
matched_data,
by = dplyr::join_by(grid_cell_idx)
)
out$.notes <- list(.notes)
.grid["grid_cell_idx"] <- NULL
out$.grid <- list(.grid)
out$.notes <- list(notes_tibble)
out

}
)

if (any(purrr::map_lgl(out[[".notes"]], function(x) nrow(x) > 0))) {
rlang::warn(
c(
"Some observations were not within any grid cell, and as such were not used in any assessments.",
i = "See the `.notes` column for details."
)
)
}

out

}

handle_metrics <- function(metrics) {
Expand Down
2 changes: 1 addition & 1 deletion R/waywiser-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
#' @importFrom rlang enquo .data .env
#' @importFrom stats predict complete.cases na.fail
#' @importFrom yardstick new_numeric_metric
utils::globalVariables(c(".truth", ".estimate"))
utils::globalVariables(c(".truth", ".estimate", ".grid_idx", "grid_cell_idx"))
## usethis namespace: end
NULL
25 changes: 19 additions & 6 deletions man/ww_multi_scale.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

41 changes: 25 additions & 16 deletions tests/testthat/_snaps/area_of_applicability.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@

Code
ww_area_of_applicability(y ~ ., train, test, importance)
Error <rlang_error>
Missing values in training data.
Condition
Error in `create_aoa()`:
! Missing values in training data.
i Either process your data to fix NA values, or set `na_rm = TRUE`.

---
Expand All @@ -26,8 +27,9 @@

Code
ww_area_of_applicability(train[2:11], test[2:11], importance)
Error <rlang_error>
Missing values in training data.
Condition
Error in `create_aoa()`:
! Missing values in training data.
i Either process your data to fix NA values, or set `na_rm = TRUE`.

---
Expand All @@ -45,8 +47,9 @@
Code
ww_area_of_applicability(as.matrix(train[2:11]), as.matrix(test[2:11]),
importance)
Error <rlang_error>
Missing values in training data.
Condition
Error in `create_aoa()`:
! Missing values in training data.
i Either process your data to fix NA values, or set `na_rm = TRUE`.

---
Expand All @@ -64,7 +67,8 @@

Code
ww_area_of_applicability(comb_rset_no_y, importance = importance)
Error <purrr_error_indexed>
Condition
Error in `purrr::map()`:
i In index: 1.
Caused by error in `create_aoa()`:
! Missing values in training data.
Expand All @@ -84,7 +88,8 @@

Code
ww_area_of_applicability(comb_rset, recipes::recipe(y ~ ., train), importance = importance)
Error <purrr_error_indexed>
Condition
Error in `purrr::map()`:
i In index: 1.
Caused by error in `create_aoa()`:
! Missing values in training data.
Expand Down Expand Up @@ -126,29 +131,33 @@

Code
ww_area_of_applicability(y ~ ., train, test[1:10], importance)
Error <rlang_error>
Some columns in `training` were not present in `testing` (or `new_data`).
Condition
Error in `check_di_testing()`:
! Some columns in `training` were not present in `testing` (or `new_data`).

---

Code
ww_area_of_applicability(y ~ ., train, test, na_rm = c(TRUE, FALSE), importance)
Error <rlang_error>
Only one value can be passed to `na_rm`.
Condition
Error in `create_aoa()`:
! Only one value can be passed to `na_rm`.

---

Code
ww_area_of_applicability(y ~ ., train, test, head(importance, -1))
Error <rlang_error>
All predictors must have an importance value in `importance`.
Condition
Error in `ww_area_of_applicability()`:
! All predictors must have an importance value in `importance`.

---

Code
ww_area_of_applicability(y ~ ., train[1:10], test[1:10], importance)
Error <rlang_error>
All variables with an importance value in `importance` must be included as predictors.
Condition
Error in `ww_area_of_applicability()`:
! All variables with an importance value in `importance` must be included as predictors.

# normal use

Expand Down

0 comments on commit 4010b96

Please sign in to comment.