Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use check_name() in all steps that produces new columns #1124

Merged
merged 41 commits into from Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
807a84b
use check_name() in step_geodist()
EmilHvitfeldt Mar 30, 2023
1e2bc07
Merged origin/main into use-check_name
EmilHvitfeldt Apr 4, 2023
2e38e1f
properly use check_name() in step_geodist()
EmilHvitfeldt Apr 5, 2023
b6224e3
use check_name() in step_bs()
EmilHvitfeldt Apr 5, 2023
20afd59
test check_name() is used in step_classdist()
EmilHvitfeldt Apr 5, 2023
50ce987
use check_name() in step_date()
EmilHvitfeldt Apr 5, 2023
3f5de94
use check_name() in step_dummy()
EmilHvitfeldt Apr 5, 2023
ca0bab0
use check_name() in step_dummy_multi_choice()
EmilHvitfeldt Apr 5, 2023
4b523ba
use check_name() in step_dummy_extract()
EmilHvitfeldt Apr 5, 2023
5742e28
use check_name() in step_harmonic()
EmilHvitfeldt Apr 5, 2023
1f362e5
use check_name()
EmilHvitfeldt Apr 5, 2023
083a53c
use check_name() in step_ica()
EmilHvitfeldt Apr 5, 2023
12dc925
use check_name() in step_interact()
EmilHvitfeldt Apr 5, 2023
ef6c108
test check_name() usage in step_isomap()
EmilHvitfeldt Apr 5, 2023
8a86731
test check_name() usage in step_kpca()
EmilHvitfeldt Apr 5, 2023
4a07cce
test usage of check_name() in step_kpca_poly()
EmilHvitfeldt Apr 5, 2023
d90b257
test check_name() usage in step_kpca_rbf()
EmilHvitfeldt Apr 5, 2023
438bf8d
use check_name() in step_indicate_na()
EmilHvitfeldt Apr 5, 2023
63ef3ff
use check_name() in step_nnmf()
EmilHvitfeldt Apr 5, 2023
08e28e3
use check_name() in step_nnmf_sparse()
EmilHvitfeldt Apr 5, 2023
97b2499
use check_name() in step_ns()
EmilHvitfeldt Apr 5, 2023
934e717
test check_name() in step_pca()
EmilHvitfeldt Apr 5, 2023
c1713fb
use check_name() in step_pls()
EmilHvitfeldt Apr 5, 2023
e71da52
use check_name() in step_poly()
EmilHvitfeldt Apr 5, 2023
ea178e9
use check_name() in step_poly_bernstein()
EmilHvitfeldt Apr 5, 2023
ac2995a
use check_name() in step_ratio()
EmilHvitfeldt Apr 5, 2023
394553f
use check_name() in step_spline_b()
EmilHvitfeldt Apr 5, 2023
178c57d
use check_name() in step_spline_monotone()
EmilHvitfeldt Apr 5, 2023
37699eb
use check_name() in step_spline_natural()
EmilHvitfeldt Apr 5, 2023
84ac2d8
use check_name() in step_spline_nonnegative()
EmilHvitfeldt Apr 5, 2023
48b5fe4
use check_name() in step_time()
EmilHvitfeldt Apr 5, 2023
4776903
use cli in check_name()
EmilHvitfeldt Apr 5, 2023
92ddcc5
update test for check_name() and step_interact()
EmilHvitfeldt Apr 5, 2023
e790a87
add call argument to check_name()
EmilHvitfeldt Apr 5, 2023
7eda127
update snapshots
EmilHvitfeldt Apr 5, 2023
f9cd48f
Merged origin/main into use-check_name
EmilHvitfeldt Apr 5, 2023
c42d933
update news
EmilHvitfeldt Apr 5, 2023
08db1c2
use check_name() in step_intercept()
EmilHvitfeldt Apr 5, 2023
53b2511
use check_name() in step_spline_convex()
EmilHvitfeldt Apr 5, 2023
fa6fe01
skip tests
EmilHvitfeldt Apr 5, 2023
54b4663
more skip_if_not_installed()
EmilHvitfeldt Apr 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEWS.md
Expand Up @@ -8,6 +8,8 @@

* Steps with tunable arguments now have those arguments listed in the documentation.

* All steps that add new columns will now informatively error if name collision occurs. (#983)

# recipes 1.0.5

* Added `outside` argument to `step_percentile()` to determine different ways of handling values outside the range of the training data.
Expand Down
5 changes: 4 additions & 1 deletion R/bs.R
Expand Up @@ -183,7 +183,10 @@ bake.step_bs <- function(object, new_data, ...) {
strt <- max(cols) + 1
new_data[, orig_var] <- NULL
}
new_data <- bind_cols(new_data, as_tibble(bs_values))
bs_values <- as_tibble(bs_values)
bs_values <- check_name(bs_values, new_data, object, names(bs_values))

new_data <- bind_cols(new_data, bs_values)
new_data
}

Expand Down
2 changes: 2 additions & 0 deletions R/date.R
Expand Up @@ -286,6 +286,8 @@ bake.step_date <- function(object, new_data, ...) {

names(date_values) <- new_names

date_values <- check_name(date_values, new_data, object, names(date_values))

new_data <- bind_cols(new_data, date_values)
keep_original_cols <- get_keep_original_cols(object)
if (!keep_original_cols) {
Expand Down
5 changes: 4 additions & 1 deletion R/dummy.R
Expand Up @@ -331,7 +331,10 @@ bake.step_dummy <- function(object, new_data, ...) {
## use backticks for nonstandard factor levels here
used_lvl <- gsub(paste0("^\\`?", col_names[i], "\\`?"), "", colnames(indicators))
colnames(indicators) <- object$naming(col_names[i], used_lvl, fac_type == "ordered")
new_data <- bind_cols(new_data, as_tibble(indicators))
indicators <- as_tibble(indicators)
indicators <- check_name(indicators, new_data, object, names(indicators))

new_data <- bind_cols(new_data, indicators)
if (any(!object$preserve, !keep_original_cols)) {
new_data[, col_names[i]] <- NULL
}
Expand Down
5 changes: 4 additions & 1 deletion R/dummy_multi_choice.R
Expand Up @@ -176,8 +176,11 @@ bake.step_dummy_multi_choice <- function(object, new_data, ...) {

used_lvl <- gsub(paste0("^", prefix), "", colnames(indicators))
colnames(indicators) <- object$naming(prefix, used_lvl)
indicators <- as_tibble(indicators)

new_data <- bind_cols(new_data, as_tibble(indicators))
indicators <- check_name(indicators, new_data, object, names(indicators))

new_data <- bind_cols(new_data, indicators)
keep_original_cols <- get_keep_original_cols(object)

if (!keep_original_cols) {
Expand Down
2 changes: 2 additions & 0 deletions R/extract.R
Expand Up @@ -246,6 +246,8 @@ bake.step_dummy_extract <- function(object, new_data, ...) {
used_lvl <- gsub(paste0("^", col_names[i]), "", colnames(indicators))
colnames(indicators) <- object$naming(col_names[i], used_lvl)

indicators <- check_name(indicators, new_data, object, names(indicators))

new_data <- bind_cols(new_data, indicators)

if (!keep_original_cols) {
Expand Down
16 changes: 8 additions & 8 deletions R/geodist.R
Expand Up @@ -157,11 +157,6 @@ prep.step_geodist <- function(x, training, info = NULL, ...) {
}
check_type(training[, lat_name], types = c("double", "integer"))


if (any(names(training) == x$name)) {
rlang::abort("'", x$name, "' is already used in the data.")
}

step_geodist_new(
lon = x$lon,
lat = x$lat,
Expand Down Expand Up @@ -238,11 +233,16 @@ bake.step_geodist <- function(object, new_data, ...) {
}

if (object$log) {
new_data[, object$name] <- log(dist_vals)
} else {
new_data[, object$name] <- dist_vals
dist_vals <- log(dist_vals)
}

geo_data <- tibble(dist_vals)
names(geo_data) <- object$name

geo_data <- check_name(geo_data, new_data, object, newname = object$name)

new_data <- bind_cols(new_data, geo_data)

new_data
}

Expand Down
3 changes: 3 additions & 0 deletions R/harmonic.R
Expand Up @@ -312,6 +312,9 @@ bake.step_harmonic <- function(object, new_data, ...) {
seq_len(n_frequency)
)
res <- as_tibble(res)

res <- check_name(res, new_data, object, names(res))

new_data <- bind_cols(new_data, res)
}

Expand Down
1 change: 1 addition & 0 deletions R/holiday.R
Expand Up @@ -141,6 +141,7 @@ bake.step_holiday <- function(object, new_data, ...) {
names(tmp) <- paste(object$columns[i], names(tmp), sep = "_")
tmp <- purrr::map_dfc(tmp, vec_cast, integer())

tmp <- check_name(tmp, new_data, object, names(tmp))
new_data <- bind_cols(new_data, tmp)
}

Expand Down
4 changes: 3 additions & 1 deletion R/ica.R
Expand Up @@ -199,7 +199,9 @@ bake.step_ica <- function(object, new_data, ...) {
comps <- comps %*% object$res$K %*% object$res$W
comps <- comps[, seq_len(object$num_comp), drop = FALSE]
colnames(comps) <- names0(ncol(comps), object$prefix)
new_data <- bind_cols(new_data, as_tibble(comps))
comps <- as_tibble(comps)
comps <- check_name(comps, new_data, object)
new_data <- bind_cols(new_data, comps)
keep_original_cols <- get_keep_original_cols(object)

if (!keep_original_cols) {
Expand Down
4 changes: 3 additions & 1 deletion R/interactions.R
Expand Up @@ -232,7 +232,9 @@ bake.step_interact <- function(object, new_data, ...) {
}
colnames(out) <-
gsub(":", object$sep, unlist(lapply(res, colnames)))
new_data <- bind_cols(new_data, as_tibble(out))
out <- as_tibble(out)
out <- check_name(out, new_data, object, names(out))
new_data <- bind_cols(new_data, out)
new_data
}

Expand Down
5 changes: 4 additions & 1 deletion R/intercept.R
Expand Up @@ -92,7 +92,10 @@ prep.step_intercept <- function(x, training, info = NULL, ...) {

#' @export
bake.step_intercept <- function(object, new_data, ...) {
tibble::add_column(new_data, !!object$name := object$value, .before = TRUE)
intercept <- tibble(!!object$name := rep(object$value, nrow(new_data)))
intercept <- check_name(intercept, new_data, object, names(intercept))
new_data <- bind_cols(intercept, new_data)
new_data
}

print.step_intercept <-
Expand Down
3 changes: 2 additions & 1 deletion R/isomap.R
Expand Up @@ -204,8 +204,9 @@ bake.step_isomap <- function(object, new_data, ...) {
)@data
})
comps <- comps[, seq_len(object$num_terms), drop = FALSE]
comps <- as_tibble(comps)
comps <- check_name(comps, new_data, object)
new_data <- bind_cols(new_data, as_tibble(comps))
new_data <- bind_cols(new_data, comps)
keep_original_cols <- get_keep_original_cols(object)
if (!keep_original_cols) {
new_data <- new_data[, !(colnames(new_data) %in% isomap_vars), drop = FALSE]
Expand Down
3 changes: 2 additions & 1 deletion R/kpca.R
Expand Up @@ -166,8 +166,9 @@ bake.step_kpca <- function(object, new_data, ...) {
comps <- rlang::eval_tidy(cl)
comps <- comps[, seq_len(object$num_comp), drop = FALSE]
colnames(comps) <- names0(ncol(comps), object$prefix)
comps <- as_tibble(comps)
comps <- check_name(comps, new_data, object)
new_data <- bind_cols(new_data, as_tibble(comps))
new_data <- bind_cols(new_data, comps)
keep_original_cols <- get_keep_original_cols(object)

if (!keep_original_cols) {
Expand Down
3 changes: 2 additions & 1 deletion R/kpca_poly.R
Expand Up @@ -172,8 +172,9 @@ bake.step_kpca_poly <- function(object, new_data, ...) {
comps <- rlang::eval_tidy(cl)
comps <- comps[, seq_len(object$num_comp), drop = FALSE]
colnames(comps) <- names0(ncol(comps), object$prefix)
comps <- as_tibble(comps)
comps <- check_name(comps, new_data, object)
new_data <- bind_cols(new_data, as_tibble(comps))
new_data <- bind_cols(new_data, comps)
keep_original_cols <- get_keep_original_cols(object)

if (!keep_original_cols) {
Expand Down
3 changes: 2 additions & 1 deletion R/kpca_rbf.R
Expand Up @@ -160,8 +160,9 @@ bake.step_kpca_rbf <- function(object, new_data, ...) {
comps <- rlang::eval_tidy(cl)
comps <- comps[, seq_len(object$num_comp), drop = FALSE]
colnames(comps) <- names0(ncol(comps), object$prefix)
comps <- as_tibble(comps)
comps <- check_name(comps, new_data, object)
new_data <- bind_cols(new_data, as_tibble(comps))
new_data <- bind_cols(new_data, comps)
keep_original_cols <- get_keep_original_cols(object)

if (!keep_original_cols) {
Expand Down
21 changes: 12 additions & 9 deletions R/misc.R
Expand Up @@ -500,24 +500,27 @@ simple_terms <- function(x, ...) {
#' in the trained object.
#' @param names A logical determining if the names should be set using
#' the names function (TRUE) or colnames function (FALSE).
#' @param call The execution environment of a currently running function, e.g.
#' `caller_env()`. The function will be mentioned in error messages as the
#' source of the error. See the call argument of [rlang::abort()] for more
#' information.
#' @export
#' @keywords internal
check_name <- function(res, new_data, object, newname = NULL, names = FALSE) {
check_name <- function(res, new_data, object, newname = NULL, names = FALSE,
call = caller_env()) {
if (is.null(newname)) {
newname <- names0(ncol(res), object$prefix)
}
new_data_names <- colnames(new_data)
intersection <- new_data_names %in% newname
if (any(intersection)) {
rlang::abort(
paste0(
"Name collision occured in `",
class(object)[1],
"`. The following variable names already exists: ",
paste0(new_data_names[intersection], collapse = ", "),
"."
)
nms <- new_data_names[intersection]
cli::cli_abort(
c("Name collision occured. The following variable names already exists:",
i = " {nms}"),
call = call
)

}
if (names) {
names(res) <- newname
Expand Down
4 changes: 3 additions & 1 deletion R/naindicate.R
Expand Up @@ -111,7 +111,9 @@ bake.step_indicate_na <- function(object, new_data, ...) {
cols <- tibble::new_tibble(cols, nrow = nrow(new_data))
cols <- dplyr::rename_with(cols, ~ vec_paste0(object$prefix, "_", .x))

new_data <- dplyr::bind_cols(new_data, cols)
cols <- check_name(cols, new_data, object, names(cols))

new_data <- bind_cols(new_data, cols)
new_data
}

Expand Down
4 changes: 3 additions & 1 deletion R/nnmf.R
Expand Up @@ -188,7 +188,9 @@ bake.step_nnmf <- function(object, new_data, ...) {
object$res@apply(dimred_data(new_data[, nnmf_vars, drop = FALSE]))@data
comps <- comps[, seq_len(object$num_comp), drop = FALSE]
colnames(comps) <- names0(ncol(comps), object$prefix)
new_data <- bind_cols(new_data, as_tibble(comps))
comps <- as_tibble(comps)
comps <- check_name(comps, new_data, object)
new_data <- bind_cols(new_data, comps)
keep_original_cols <- get_keep_original_cols(object)

if (!keep_original_cols) {
Expand Down
4 changes: 3 additions & 1 deletion R/nnmf_sparse.R
Expand Up @@ -198,7 +198,9 @@ bake.step_nnmf_sparse <- function(object, new_data, ...) {
proj_data <- as.matrix(new_data[, object$res$x_vars, drop = FALSE])
proj_data <- proj_data %*% object$res$w
colnames(proj_data) <- names0(ncol(proj_data), object$prefix)
new_data <- bind_cols(new_data, as_tibble(proj_data))
proj_data <- as_tibble(proj_data)
proj_data <- check_name(proj_data, new_data, object)
new_data <- bind_cols(new_data, proj_data)
keep_original_cols <- get_keep_original_cols(object)

if (!keep_original_cols) {
Expand Down
4 changes: 3 additions & 1 deletion R/ns.R
Expand Up @@ -173,7 +173,9 @@ bake.step_ns <- function(object, new_data, ...) {
strt <- max(cols) + 1
new_data[, orig_var] <- NULL
}
new_data <- bind_cols(new_data, as_tibble(ns_values))
ns_values <- as_tibble(ns_values)
ns_values <- check_name(ns_values, new_data, object, names(ns_values))
new_data <- bind_cols(new_data, ns_values)
new_data
}

Expand Down
3 changes: 2 additions & 1 deletion R/pls.R
Expand Up @@ -381,9 +381,10 @@ bake.step_pls <- function(object, new_data, ...) {
}

names(comps) <- names0(ncol(comps), object$prefix)
comps <- as_tibble(comps)
comps <- check_name(comps, new_data, object)

new_data <- bind_cols(new_data, as_tibble(comps))
new_data <- bind_cols(new_data, comps)
keep_original_cols <- get_keep_original_cols(object)

# Old pls never preserved original columns,
Expand Down
3 changes: 2 additions & 1 deletion R/poly.R
Expand Up @@ -171,7 +171,8 @@ bake.step_poly <- function(object, new_data, ...) {
new_tbl[i_new_names] <- new_cols
}

new_data <- dplyr::bind_cols(new_data, new_tbl)
new_tbl <- check_name(new_tbl, new_data, object, names(new_tbl))
new_data <- bind_cols(new_data, new_tbl)
new_data <- dplyr::select(new_data, -dplyr::all_of(col_names))
new_data
}
Expand Down
1 change: 1 addition & 0 deletions R/poly_bernstein.R
Expand Up @@ -151,6 +151,7 @@ bake.step_poly_bernstein <- function(object, new_data, ...) {
orig_names <- names(object$results)
if (length(orig_names) > 0) {
new_cols <- purrr::map2_dfc(object$results, new_data[, orig_names], spline2_apply)
new_cols <- check_name(new_cols, new_data, object, names(new_cols))
new_data <- bind_cols(new_data, new_cols)
keep_original_cols <- get_keep_original_cols(object)
if (!keep_original_cols) {
Expand Down
3 changes: 2 additions & 1 deletion R/ratio.R
Expand Up @@ -159,9 +159,10 @@ bake.step_ratio <- function(object, new_data, ...) {

res <- tibble::new_tibble(res, nrow = nrow(new_data))

keep_original_cols <- get_keep_original_cols(object)
res <- check_name(res, new_data, object, names(res))
new_data <- bind_cols(new_data, res)

keep_original_cols <- get_keep_original_cols(object)
if (!keep_original_cols) {
union_cols <- union(object$columns$top, object$columns$bottom)
new_data <- new_data[, !(colnames(new_data) %in% union_cols), drop = FALSE]
Expand Down
1 change: 1 addition & 0 deletions R/spline_b.R
Expand Up @@ -168,6 +168,7 @@ bake.step_spline_b <- function(object, new_data, ...) {
orig_names <- names(object$results)
if (length(orig_names) > 0) {
new_cols <- purrr::map2_dfc(object$results, new_data[, orig_names], spline2_apply)
new_cols <- check_name(new_cols, new_data, object, names(new_cols))
new_data <- bind_cols(new_data, new_cols)
keep_original_cols <- get_keep_original_cols(object)
if (!keep_original_cols) {
Expand Down
1 change: 1 addition & 0 deletions R/spline_convex.R
Expand Up @@ -159,6 +159,7 @@ bake.step_spline_convex <- function(object, new_data, ...) {
orig_names <- names(object$results)
if (length(orig_names) > 0) {
new_cols <- purrr::map2_dfc(object$results, new_data[, orig_names], spline2_apply)
new_cols <- check_name(new_cols, new_data, object, names(new_cols))
new_data <- bind_cols(new_data, new_cols)
keep_original_cols <- get_keep_original_cols(object)
if (!keep_original_cols) {
Expand Down
1 change: 1 addition & 0 deletions R/spline_monotone.R
Expand Up @@ -160,6 +160,7 @@ bake.step_spline_monotone <- function(object, new_data, ...) {
orig_names <- names(object$results)
if (length(orig_names) > 0) {
new_cols <- purrr::map2_dfc(object$results, new_data[, orig_names], spline2_apply)
new_cols <- check_name(new_cols, new_data, object, names(new_cols))
new_data <- bind_cols(new_data, new_cols)
keep_original_cols <- get_keep_original_cols(object)
if (!keep_original_cols) {
Expand Down
1 change: 1 addition & 0 deletions R/spline_natural.R
Expand Up @@ -151,6 +151,7 @@ bake.step_spline_natural <- function(object, new_data, ...) {
orig_names <- names(object$results)
if (length(orig_names) > 0) {
new_cols <- purrr::map2_dfc(object$results, new_data[, orig_names], spline2_apply)
new_cols <- check_name(new_cols, new_data, object, names(new_cols))
new_data <- bind_cols(new_data, new_cols)
keep_original_cols <- get_keep_original_cols(object)
if (!keep_original_cols) {
Expand Down
1 change: 1 addition & 0 deletions R/spline_nonnegative.R
Expand Up @@ -171,6 +171,7 @@ bake.step_spline_nonnegative <- function(object, new_data, ...) {
orig_names <- names(object$results)
if (length(orig_names) > 0) {
new_cols <- purrr::map2_dfc(object$results, new_data[, orig_names], spline2_apply)
new_cols <- check_name(new_cols, new_data, object, names(new_cols))
new_data <- bind_cols(new_data, new_cols)
keep_original_cols <- get_keep_original_cols(object)
if (!keep_original_cols) {
Expand Down
1 change: 1 addition & 0 deletions R/time.R
Expand Up @@ -143,6 +143,7 @@ bake.step_time <- function(object, new_data, ...) {
)

names(time_values) <- glue("{column}_{names(time_values)}")
time_values <- check_name(time_values, new_data, object, names(time_values))
new_data <- bind_cols(new_data, time_values)
}

Expand Down
14 changes: 13 additions & 1 deletion man/check_name.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.