diff --git a/R/predict.R b/R/predict.R index 61193fae2..4492cae16 100644 --- a/R/predict.R +++ b/R/predict.R @@ -244,8 +244,8 @@ prepare_data <- function(object, new_data) { get_encoding(class(object$spec)[1]) %>% dplyr::filter(mode == object$spec$mode, engine == object$spec$engine) %>% dplyr::pull(remove_intercept) - if (remove_intercept) { - new_data <- new_data[, colnames(new_data) != "(Intercept)", drop = FALSE] + if (remove_intercept & any(grepl("Intercept", names(new_data)))) { + new_data <- new_data %>% dplyr::select(-dplyr::one_of("(Intercept)")) } switch(