From e758cde72e93302d00370ed98a4b616f075fd683 Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Mon, 6 Mar 2023 09:20:34 -0500 Subject: [PATCH 1/3] test `get_model_spec()` helper --- tests/testthat/test_translate.R | 34 +++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/testthat/test_translate.R b/tests/testthat/test_translate.R index 24a50bc5f..d4c5bc913 100644 --- a/tests/testthat/test_translate.R +++ b/tests/testthat/test_translate.R @@ -309,4 +309,38 @@ test_that("translate tuning paramter names", { expect_snapshot_error(.model_param_name_key(1)) }) +# ------------------------------------------------------------------------------ + +test_that("get_model_spec helper", { + mod1 <- get_model_spec("linear_reg", "regression", "lm") + + expect_type(mod1, "list") + + expect_type(mod1$libs, "character") + expect_length(mod1$libs, 1) + expect_equal(mod1$libs, "stats") + + expect_type(mod1$fit, "list") + expect_length(mod1$fit, 4) + expect_equal(names(mod1$fit), c("interface", "protect", "func", "defaults")) + expect_type(mod1$pred, "list") + expect_length(mod1$pred, 4) + expect_equal(names(mod1$pred), c("numeric", "conf_int", "pred_int", "raw")) + + expect_type(mod1$pred$numeric, "list") + expect_length(mod1$pred$numeric, 4) + expect_equal(names(mod1$pred$numeric), c("pre", "post", "func", "args")) + + expect_type(mod1$pred$conf_int, "list") + expect_length(mod1$pred$conf_int, 4) + expect_equal(names(mod1$pred$conf_int), c("pre", "post", "func", "args")) + + expect_type(mod1$pred$pred_int, "list") + expect_length(mod1$pred$pred_int, 4) + expect_equal(names(mod1$pred$pred_int), c("pre", "post", "func", "args")) + + expect_type(mod1$pred$raw, "list") + expect_length(mod1$pred$raw, 4) + expect_equal(names(mod1$pred$raw), c("pre", "post", "func", "args")) +}) From 8ab4cc2ec3bb59a2fa317f099140081928961c7c Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Mon, 6 Mar 2023 09:29:35 -0500 Subject: [PATCH 2/3] speed up `get_model_spec()` helper --- R/translate.R | 34 +++++++++++++++------------------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/R/translate.R b/R/translate.R index 4732dab76..13ac285dd 100644 --- a/R/translate.R +++ b/R/translate.R @@ -106,25 +106,21 @@ get_model_spec <- function(model, mode, engine) { env_obj <- grep(model, env_obj, value = TRUE) res <- list() - res$libs <- - rlang::env_get(m_env, paste0(model, "_pkgs")) %>% - dplyr::filter(engine == !!engine) %>% - purrr::pluck("pkg") %>% - purrr::pluck(1) - - res$fit <- - rlang::env_get(m_env, paste0(model, "_fit")) %>% - dplyr::filter(mode == !!mode & engine == !!engine) %>% - dplyr::pull(value) %>% - purrr::pluck(1) - - pred_code <- - rlang::env_get(m_env, paste0(model, "_predict")) %>% - dplyr::filter(mode == !!mode & engine == !!engine) %>% - dplyr::select(-engine, -mode) - - res$pred <- pred_code[["value"]] - names(res$pred) <- pred_code$type + + libs <- rlang::env_get(m_env, paste0(model, "_pkgs")) + libs <- vctrs::vec_slice(libs$pkg, libs$engine == engine) + res$libs <- libs[[1L]] + + fits <- rlang::env_get(m_env, paste0(model, "_fit")) + fits <- vctrs::vec_slice(fits$value, fits$mode == mode & fits$engine == engine) + res$fit <- fits[[1L]] + + preds <- rlang::env_get(m_env, paste0(model, "_predict")) + where <- preds$mode == mode & preds$engine == engine + types <- vctrs::vec_slice(preds$type, where) + values <- vctrs::vec_slice(preds$value, where) + names(values) <- types + res$pred <- values res } From 88ffb018a30fece9f711a8209a4ad47c289beb3e Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Mon, 6 Mar 2023 10:02:20 -0500 Subject: [PATCH 3/3] accommodate 0-length `libs`/`fit` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ``` r boop <- list(a = 1) bench::mark( base = if (length(boop) > 0) {boop[[1]]} else {NULL}, purrr = purrr::pluck(boop, 1), dplyr = dplyr::first(boop), iterations = 100 ) #> # A tibble: 3 × 6 #> expression min median `itr/sec` mem_alloc `gc/sec` #> #> 1 base 122.94ns 163.91ns 3179808. 0B 0 #> 2 purrr 2.01µs 2.52µs 305755. 22.66KB 0 #> 3 dplyr 5.29µs 6.07µs 141894. 8.12MB 0 ``` Created on 2023-03-06 with [reprex v2.0.2](https://reprex.tidyverse.org) --- R/translate.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/translate.R b/R/translate.R index 13ac285dd..e096b5dfe 100644 --- a/R/translate.R +++ b/R/translate.R @@ -109,11 +109,11 @@ get_model_spec <- function(model, mode, engine) { libs <- rlang::env_get(m_env, paste0(model, "_pkgs")) libs <- vctrs::vec_slice(libs$pkg, libs$engine == engine) - res$libs <- libs[[1L]] + res$libs <- if (length(libs) > 0) {libs[[1]]} else {NULL} fits <- rlang::env_get(m_env, paste0(model, "_fit")) fits <- vctrs::vec_slice(fits$value, fits$mode == mode & fits$engine == engine) - res$fit <- fits[[1L]] + res$fit <- if (length(fits) > 0) {fits[[1]]} else {NULL} preds <- rlang::env_get(m_env, paste0(model, "_predict")) where <- preds$mode == mode & preds$engine == engine