From bb5e98e17202abeecf076a8ca4578804d7846d6d Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Thu, 30 Jul 2020 11:05:18 -0400 Subject: [PATCH 1/2] make the removal work with spark --- R/predict.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/predict.R b/R/predict.R index 61193fae2..6bc7253bb 100644 --- a/R/predict.R +++ b/R/predict.R @@ -245,7 +245,7 @@ prepare_data <- function(object, new_data) { dplyr::filter(mode == object$spec$mode, engine == object$spec$engine) %>% dplyr::pull(remove_intercept) if (remove_intercept) { - new_data <- new_data[, colnames(new_data) != "(Intercept)", drop = FALSE] + new_data <- new_data %>% dplyr::select(-dplyr::one_of("(Intercept)")) } switch( From 57d3ca589edbca0b8a66b9ddc2c1607088269cd0 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Thu, 30 Jul 2020 11:13:21 -0400 Subject: [PATCH 2/2] check for intercept before trying to remove --- R/predict.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/predict.R b/R/predict.R index 6bc7253bb..4492cae16 100644 --- a/R/predict.R +++ b/R/predict.R @@ -244,7 +244,7 @@ prepare_data <- function(object, new_data) { get_encoding(class(object$spec)[1]) %>% dplyr::filter(mode == object$spec$mode, engine == object$spec$engine) %>% dplyr::pull(remove_intercept) - if (remove_intercept) { + if (remove_intercept & any(grepl("Intercept", names(new_data)))) { new_data <- new_data %>% dplyr::select(-dplyr::one_of("(Intercept)")) }