From 56a202862f0a5b72d7223d4024d2e5aeb7143701 Mon Sep 17 00:00:00 2001 From: DavisVaughan Date: Wed, 20 Apr 2022 14:33:30 -0400 Subject: [PATCH 1/4] Prefix everywhere we use `new_quosure()` or `empty_env()` We don't import these, so we have to do this. Tests were only working by chance because we have `library(rlang)` in some of the test files! --- tests/testthat/helpers.R | 2 +- tests/testthat/test_svm_liquidsvm.R | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/testthat/helpers.R b/tests/testthat/helpers.R index 5b82da4e5..f37330021 100644 --- a/tests/testthat/helpers.R +++ b/tests/testthat/helpers.R @@ -5,7 +5,7 @@ # need this wrapper. new_empty_quosure <- function(expr) { - new_quosure(expr, env = empty_env()) + rlang::new_quosure(expr, env = rlang::empty_env()) } tune_check <- function() { diff --git a/tests/testthat/test_svm_liquidsvm.R b/tests/testthat/test_svm_liquidsvm.R index cef99ec9d..0739e7f35 100644 --- a/tests/testthat/test_svm_liquidsvm.R +++ b/tests/testthat/test_svm_liquidsvm.R @@ -52,10 +52,10 @@ test_that('engine arguments', { expected = list( x = expr(missing_arg()), y = expr(missing_arg()), - scale = new_quosure(FALSE, env = empty_env()), - predict.prob = new_quosure(TRUE, env = empty_env()), - threads = new_quosure(2, env = empty_env()), - gpus = new_quosure(1, env = empty_env()), + scale = rlang::new_quosure(FALSE, env = rlang::empty_env()), + predict.prob = rlang::new_quosure(TRUE, env = rlang::empty_env()), + threads = rlang::new_quosure(2, env = rlang::empty_env()), + gpus = rlang::new_quosure(1, env = rlang::empty_env()), folds = 1 ) ) From 6fad709960ed26ccf1842a3a46d784342ad494a4 Mon Sep 17 00:00:00 2001 From: DavisVaughan Date: Wed, 20 Apr 2022 14:38:59 -0400 Subject: [PATCH 2/4] Ensure that `fit_xy()` patches the formula environment with weights --- R/case_weights.R | 18 +++++++++ R/convert_data.R | 6 +++ R/fit.R | 14 ++----- R/fit_helpers.R | 2 +- tests/testthat/test-case-weights.R | 61 ++++++++++++++++++++++++++++++ 5 files changed, 90 insertions(+), 11 deletions(-) diff --git a/R/case_weights.R b/R/case_weights.R index 92b759500..2eb8e6bbb 100644 --- a/R/case_weights.R +++ b/R/case_weights.R @@ -47,6 +47,24 @@ weights_to_numeric <- function(x, spec) { x } +patch_formula_environment_with_case_weights <- function(formula, + data, + case_weights) { + # `lm()` and `glm()` and others use the original model function call to + # construct a call for `model.frame()`. That will normally fail because the + # formula has its own environment attached (usually the global environment) + # and it will look there for a vector named 'weights'. To account + # for this, we create a child of the `formula`'s environment and + # stash the `weights` there with the expected name and then + # reassign this as the `formula`'s environment + environment(formula) <- rlang::new_environment( + data = list(data = data, weights = case_weights), + parent = environment(formula) + ) + + formula +} + #' Convert case weights to final from #' #' tidymodels requires case weights to have special classes. To use them in diff --git a/R/convert_data.R b/R/convert_data.R index 7af0c34f3..ef8fa0673 100644 --- a/R/convert_data.R +++ b/R/convert_data.R @@ -252,6 +252,12 @@ if (length(weights) != nrow(x)) { rlang::abort(glue::glue("`weights` should have {nrow(x)} elements")) } + + form <- patch_formula_environment_with_case_weights( + formula = form, + data = x, + case_weights = weights + ) } res <- list( diff --git a/R/fit.R b/R/fit.R index 93f9c4b7e..6cda2e2c0 100644 --- a/R/fit.R +++ b/R/fit.R @@ -146,16 +146,10 @@ fit.model_spec <- wts <- weights_to_numeric(case_weights, object) - # `lm()` and `glm()` and others use the original model function call to - # construct a call for `model.frame()`. That will normally fail because the - # formula has its own environment attached (usually the global environment) - # and it will look there for a vector named 'weights'. To account - # for this, we create a child of the `formula`'s environment and - # stash the `weights` there with the expected name and then - # reassign this as the `formula`'s environment - environment(formula) <- rlang::new_environment( - data = list(data = data, weights = wts), - parent = environment(formula) + formula <- patch_formula_environment_with_case_weights( + formula = formula, + data = data, + case_weights = wts ) eval_env$data <- data diff --git a/R/fit_helpers.R b/R/fit_helpers.R index eae54f9b4..d4fbdf6b8 100644 --- a/R/fit_helpers.R +++ b/R/fit_helpers.R @@ -177,7 +177,7 @@ xy_form <- function(object, env, control, ...) { .convert_xy_to_form_fit( x = env$x, y = env$y, - weights = NULL, + weights = env$weights, y_name = "..y", remove_intercept = remove_intercept ) diff --git a/tests/testthat/test-case-weights.R b/tests/testthat/test-case-weights.R index 4f83d9976..f2bce3656 100644 --- a/tests/testthat/test-case-weights.R +++ b/tests/testthat/test-case-weights.R @@ -24,6 +24,25 @@ test_that('case weights with xy method', { print(C5_bst_wt_fit$fit$call), "weights = weights" ) + + expect_error({ + set.seed(1) + C5_bst_wt_fit <- + boost_tree(trees = 5) %>% + set_engine("C5.0") %>% + set_mode("classification") %>% + fit_xy( + x = two_class_dat[c("A", "B")], + y = two_class_dat$Class, + case_weights = wts + ) + }, + regexp = NA) + + expect_output( + print(C5_bst_wt_fit$fit$call), + "weights = weights" + ) }) @@ -51,6 +70,19 @@ test_that('case weights with xy method - non-standard argument names', { # print(rf_wt_fit$fit$call), # "case\\.weights = weights" # ) + + expect_error({ + set.seed(1) + rf_wt_fit <- + rand_forest(trees = 5) %>% + set_mode("classification") %>% + fit_xy( + x = two_class_dat[c("A", "B")], + y = two_class_dat$Class, + case_weights = wts + ) + }, + regexp = NA) }) test_that('case weights with formula method', { @@ -78,5 +110,34 @@ test_that('case weights with formula method', { expect_equal(coef(lm_wt_fit$fit), coef(lm_sub_fit$fit)) }) +test_that('case weights with formula method that goes through `fit_xy()`', { + + skip_if_not_installed("modeldata") + data("ames", package = "modeldata") + ames$Sale_Price <- log10(ames$Sale_Price) + + set.seed(1) + wts <- runif(nrow(ames)) + wts <- ifelse(wts < 1/5, 0L, 1L) + ames_subset <- ames[wts != 0, ] + wts <- frequency_weights(wts) + + expect_error( + lm_wt_fit <- + linear_reg() %>% + fit_xy( + x = ames[c("Longitude", "Latitude")], + y = ames$Sale_Price, + case_weights = wts + ), + regexp = NA) + lm_sub_fit <- + linear_reg() %>% + fit_xy( + x = ames_subset[c("Longitude", "Latitude")], + y = ames_subset$Sale_Price + ) + expect_equal(coef(lm_wt_fit$fit), coef(lm_sub_fit$fit)) +}) From 56e75ca367141d27390bbf3a95902831fa92cd66 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Thu, 21 Apr 2022 10:02:53 -0400 Subject: [PATCH 3/4] missing roxygen tag --- R/case_weights.R | 1 + man/convert_case_weights.Rd | 2 ++ 2 files changed, 3 insertions(+) diff --git a/R/case_weights.R b/R/case_weights.R index 2eb8e6bbb..a4c60bc9c 100644 --- a/R/case_weights.R +++ b/R/case_weights.R @@ -73,6 +73,7 @@ patch_formula_environment_with_case_weights <- function(formula, #' @param x A vector with class `"hardhat_case_weights"`. #' @param where The location where they will be used: `"parsnip"` or #' `"yardstick"`. +#' @param ... Additional options (not currently used). #' @return A numeric vector or NULL. #' @export convert_case_weights <- function(x, where = "parsnip", ...) { diff --git a/man/convert_case_weights.Rd b/man/convert_case_weights.Rd index 015e0c5cf..619770eb1 100644 --- a/man/convert_case_weights.Rd +++ b/man/convert_case_weights.Rd @@ -17,6 +17,8 @@ convert_case_weights(x, where = "parsnip", ...) \item{where}{The location where they will be used: \code{"parsnip"} or \code{"yardstick"}.} + +\item{...}{Additional options (not currently used).} } \value{ A numeric vector or NULL. From d153bc254d2e1787c06d29108710c90ff168e1d3 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Thu, 21 Apr 2022 12:29:28 -0400 Subject: [PATCH 4/4] avoid deprecated tests --- tests/testthat/test_mlp.R | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/testthat/test_mlp.R b/tests/testthat/test_mlp.R index 0568236a3..60294b953 100644 --- a/tests/testthat/test_mlp.R +++ b/tests/testthat/test_mlp.R @@ -11,6 +11,7 @@ source("helpers.R") test_that('primary arguments', { + skip("reworked in target branch") hidden_units <- mlp(mode = "regression", hidden_units = 4) hidden_units_nnet <- translate(hidden_units %>% set_engine("nnet")) hidden_units_keras <- translate(hidden_units %>% set_engine("keras")) @@ -87,6 +88,7 @@ test_that('primary arguments', { }) test_that('engine arguments', { + skip("reworked in target branch") nnet_hess <- mlp(mode = "classification") %>% set_engine("nnet", Hess = TRUE) expect_equal(translate(nnet_hess)$method$fit$args, list(