From f70bb95237447163034254399594d92aabc636e2 Mon Sep 17 00:00:00 2001
From: DavisVaughan
Date: Wed, 24 Nov 2021 11:10:56 -0500
Subject: [PATCH 1/3] Add `expr_contains()` for detecting `:` more robustly

---
 R/blueprint-formula-default.R      | 26 +++++++++++++++++++++++---
 tests/testthat/test-mold-formula.R |  5 +++++
 2 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/R/blueprint-formula-default.R b/R/blueprint-formula-default.R
index 9418c6ac..b8b1dc36 100644
--- a/R/blueprint-formula-default.R
+++ b/R/blueprint-formula-default.R
@@ -883,10 +883,11 @@ detect_interactions <- function(.formula) {
     return(character(0))
   }
 
-  terms_nms <- colnames(terms_matrix)
+  terms_names <- colnames(terms_matrix)
 
   # All interactions (*, ^, %in%) will be expanded to `:`
-  has_interactions <- grepl(":", terms_nms)
+  terms_exprs <- rlang::parse_exprs(terms_names)
+  has_interactions <- map_lgl(terms_exprs, expr_contains, what = as.name(":"))
 
   has_any_interactions <- any(has_interactions)
 
@@ -894,11 +895,30 @@ detect_interactions <- function(.formula) {
     return(character(0))
   }
 
-  bad_terms <- terms_nms[has_interactions]
+  bad_terms <- terms_names[has_interactions]
 
   bad_terms
 }
 
+expr_contains <- function(expr, what) {
+  if (!rlang::is_expression(expr)) {
+    rlang::abort("`expr` must be an expression.")
+  }
+  if (!rlang::is_symbol(what)) {
+    rlang::abort("`what` must be a symbol.")
+  }
+
+  expr_contains_recurse(expr, what)
+}
+expr_contains_recurse <- function(expr, what) {
+  switch (
+    typeof(expr),
+    symbol = identical(expr, what),
+    language = any(map_lgl(expr, expr_contains_recurse, what = what)),
+    FALSE
+  )
+}
+
 extract_original_factorish_names <- function(ptype) {
   where_factorish <- vapply(ptype, is_factorish, logical(1))
 
diff --git a/tests/testthat/test-mold-formula.R b/tests/testthat/test-mold-formula.R
index b978b812..224e06ec 100644
--- a/tests/testthat/test-mold-formula.R
+++ b/tests/testthat/test-mold-formula.R
@@ -503,6 +503,11 @@ test_that("LHS of the formula cannot contain interactions", {
 
 })
 
+test_that("LHS of the formula won't misinterpret `::` as an interaction (#174)", {
+  out <- mold(base::cbind(num_1, num_2) ~ num_3, example_train)
+  expect_identical(ncol(out$outcomes), 2L)
+})
+
 test_that("original predictor and outcome classes are recorded", {
 
   bp <- default_formula_blueprint(composition = "dgCMatrix")

From c31dd81da906ef84b32f454fc9c65fa2a9394a87 Mon Sep 17 00:00:00 2001
From: DavisVaughan
Date: Wed, 24 Nov 2021 11:11:52 -0500
Subject: [PATCH 2/3] Also use `expr_contains()` for RHS factorish interactions

I couldn't actually come up with a case to test for this, but it seems
more consistent and robust anyways
---
 R/blueprint-formula-default.R | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/R/blueprint-formula-default.R b/R/blueprint-formula-default.R
index b8b1dc36..f40a1edd 100644
--- a/R/blueprint-formula-default.R
+++ b/R/blueprint-formula-default.R
@@ -836,14 +836,16 @@ detect_factorish_in_interactions <- function(.terms, .factorish_names) {
 
   # In the factor matrix, only `:` is present to represent interactions,
   # even if something like * or ^ or %in% was used to generate it
-  where_interactions <- grepl(":", colnames(factorish_rows))
+  terms_names <- colnames(factorish_rows)
+  terms_exprs <- rlang::parse_exprs(terms_names)
+  has_interactions <- map_lgl(terms_exprs, expr_contains, what = as.name(":"))
 
-  none_have_interactions <- !any(where_interactions)
+  none_have_interactions <- !any(has_interactions)
   if (none_have_interactions) {
     return(character(0))
   }
 
-  interaction_cols <- factorish_rows[, where_interactions, drop = FALSE]
+  interaction_cols <- factorish_rows[, has_interactions, drop = FALSE]
 
   factorish_is_bad_if_gt_0 <- rowSums(interaction_cols)
   bad_factorish_vals <- factorish_is_bad_if_gt_0[factorish_is_bad_if_gt_0 > 0]

From ac9320594ed401e92f7eb424d3681c0400af026b Mon Sep 17 00:00:00 2001
From: DavisVaughan
Date: Wed, 24 Nov 2021 11:12:07 -0500
Subject: [PATCH 3/3] NEWS bullet

---
 NEWS.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/NEWS.md b/NEWS.md
index 8ea9afc1..4c722dc4 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,5 +1,7 @@
 # hardhat (development version)
 
+* `mold()` no longer misinterprets `::` as an interaction term (#174).
+
 * Added `extract_parameter_dials()` and `extract_parameter_set_dials()` generics
   to extend the family of `extract_*()` generics.
 