diff --git a/NEWS.md b/NEWS.md index 7977fc733..dff384d84 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,6 +4,7 @@ * `surv_reg()` is now defunct and will error if called. Please use `survival_reg()` instead (#1206). +* Enable parsnip to work with xgboost version > 2.0.0.0. (#1227) # parsnip 1.3.3 diff --git a/R/boost_tree.R b/R/boost_tree.R index dc0e288d0..beb11b46f 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -271,6 +271,7 @@ xgb_train <- function( event_level = c("first", "second"), ... ) { + rlang::check_installed("xgboost") event_level <- rlang::arg_match(event_level, c("first", "second")) others <- list(...) @@ -340,31 +341,70 @@ xgb_train <- function( others <- process_others(others, arg_list) + if (utils::packageVersion("xgboost") >= "2.0.0.0") { + if (!is.null(num_class) && num_class > 2) { + arg_list$num_class <- num_class + } + + param_names <- names( + formals( + getFromNamespace("xgb.params", ns = "xgboost") + ) + ) + + if (any(param_names %in% names(others))) { + elements <- param_names[param_names %in% names(others)] + + for (element in elements) { + arg_list[[element]] <- others[[element]] + others[[element]] <- NULL + } + } + + if (is.null(arg_list$objective)) { + if (is.numeric(y)) { + arg_list$objective <- "reg:squarederror" + } else { + if (num_class == 2) { + arg_list$objective <- "binary:logistic" + } else { + arg_list$objective <- "multi:softprob" + } + } + } + } + main_args <- c( list( data = quote(x$data), - watchlist = quote(x$watchlist), params = arg_list, nrounds = nrounds, early_stopping_rounds = early_stop ), others ) + if (utils::packageVersion("xgboost") >= "2.0.0.0") { + main_args$evals <- quote(x$watchlist) + } else { + main_args$watchlist <- quote(x$watchlist) + } - if (is.null(main_args$objective)) { - if (is.numeric(y)) { - main_args$objective <- "reg:squarederror" - } else { - if (num_class == 2) { - main_args$objective <- "binary:logistic" + if (utils::packageVersion("xgboost") < "2.0.0.0") { + if 
(is.null(main_args$objective)) { + if (is.numeric(y)) { + main_args$objective <- "reg:squarederror" } else { - main_args$objective <- "multi:softprob" + if (num_class == 2) { + main_args$objective <- "binary:logistic" + } else { + main_args$objective <- "multi:softprob" + } } } - } - if (!is.null(num_class) && num_class > 2) { - main_args$num_class <- num_class + if (!is.null(num_class) && num_class > 2) { + main_args$num_class <- num_class + } } call <- make_call(fun = "xgb.train", ns = "xgboost", main_args) @@ -471,6 +511,7 @@ as_xgb_data <- function( event_level = "first", ... ) { + rlang::check_installed("xgboost") lvls <- levels(y) n <- nrow(x) @@ -506,21 +547,52 @@ as_xgb_data <- function( watch_list <- list(validation = val_data) info_list <- list(label = y[trn_index]) - if (!is.null(weights)) { - info_list$weight <- weights[trn_index] + if (utils::packageVersion("xgboost") >= "2.0.0.0") { + if (!is.null(weights)) { + dat <- xgboost::xgb.DMatrix( + data = x[trn_index, , drop = FALSE], + missing = NA, + label = y[trn_index], + weight = weights[trn_index] + ) + } else { + dat <- xgboost::xgb.DMatrix( + data = x[trn_index, , drop = FALSE], + missing = NA, + label = y[trn_index] + ) + } + } else { + if (!is.null(weights)) { + info_list$weight <- weights[trn_index] + } + dat <- xgboost::xgb.DMatrix( + data = x[trn_index, , drop = FALSE], + missing = NA, + info = info_list + ) } - dat <- xgboost::xgb.DMatrix( - data = x[trn_index, , drop = FALSE], - missing = NA, - info = info_list - ) } else { - info_list <- list(label = y) - if (!is.null(weights)) { - info_list$weight <- weights + if (utils::packageVersion("xgboost") >= "2.0.0.0") { + if (!is.null(weights)) { + dat <- xgboost::xgb.DMatrix( + x, + missing = NA, + label = y, + weight = weights + ) + } else { + dat <- xgboost::xgb.DMatrix(x, missing = NA, label = y) + } + watch_list <- list(training = dat) + } else { + info_list <- list(label = y) + if (!is.null(weights)) { + info_list$weight <- weights + } + dat 
<- xgboost::xgb.DMatrix(x, missing = NA, info = info_list) + watch_list <- list(training = dat) } - dat <- xgboost::xgb.DMatrix(x, missing = NA, info = info_list) - watch_list <- list(training = dat) } } else { dat <- xgboost::setinfo(x, "label", y) @@ -579,12 +651,21 @@ multi_predict._xgb.Booster <- } xgb_by_tree <- function(tree, object, new_data, type, ...) { - pred <- xgb_predict( - object$fit, - new_data = new_data, - iterationrange = c(1, tree + 1), - ntreelimit = NULL - ) + rlang::check_installed("xgboost") + if (utils::packageVersion("xgboost") >= "2.0.0.0") { + pred <- xgb_predict( + object$fit, + new_data = new_data, + iterationrange = c(1, tree + 1) + ) + } else { + pred <- xgb_predict( + object$fit, + new_data = new_data, + iterationrange = c(1, tree + 1), + ntreelimit = NULL + ) + } # switch based on prediction type if (object$spec$mode == "regression") { diff --git a/man/rmd/boost_tree_xgboost.Rmd b/man/rmd/boost_tree_xgboost.Rmd index 486861d26..0fbcf4dcd 100644 --- a/man/rmd/boost_tree_xgboost.Rmd +++ b/man/rmd/boost_tree_xgboost.Rmd @@ -3,7 +3,7 @@ #| include: false ``` -`r descr_models("boost_tree", "xgboost")` +`r descr_models("boost_tree", "xgboost")`. Note that in late 2025, a new version of xgboost was released with differences in its interface and model objects. This version of parsnip should work with either version. ## Tuning Parameters diff --git a/man/rmd/boost_tree_xgboost.md b/man/rmd/boost_tree_xgboost.md index be989b87b..4f1ae86d6 100644 --- a/man/rmd/boost_tree_xgboost.md +++ b/man/rmd/boost_tree_xgboost.md @@ -1,7 +1,7 @@ -For this engine, there are multiple modes: classification and regression +For this engine, there are multiple modes: classification and regression. Note that in late 2025, a new version of xgboost was released with differences in its interface and model objects. This version of parsnip should work with either version. 
## Tuning Parameters diff --git a/tests/testthat/test-boost_tree_xgboost.R b/tests/testthat/test-boost_tree_xgboost.R index 0841f1faa..cdc60849c 100644 --- a/tests/testthat/test-boost_tree_xgboost.R +++ b/tests/testthat/test-boost_tree_xgboost.R @@ -8,6 +8,24 @@ hpc_xgboost <- boost_tree(trees = 2, mode = "classification") |> set_engine("xgboost") +extract_xgb_param <- function(x, param) { + if (utils::packageVersion("xgboost") >= "2.0.0.0") { + res <- attr(extract_fit_engine(x), "params")[[param]] + } else { + res <- extract_fit_engine(x)$param[[param]] + } + res +} + +extract_xgb_evaluation_log <- function(x) { + if (utils::packageVersion("xgboost") >= "2.0.0.0") { + res <- attr(extract_fit_engine(x), "evaluation_log") + } else { + res <- extract_fit_engine(x)[["evaluation_log"]] + } + res +} + # ------------------------------------------------------------------------------ test_that('xgboost execution, classification', { @@ -59,13 +77,21 @@ test_that('xgboost execution, classification', { ) }) - expect_equal(res_f$fit$evaluation_log, res_xy$fit$evaluation_log) - expect_equal(res_f_wts$fit$evaluation_log, res_xy_wts$fit$evaluation_log) + expect_equal( + extract_xgb_evaluation_log(res_f), + extract_xgb_evaluation_log(res_xy) + ) + expect_equal( + extract_xgb_evaluation_log(res_f_wts), + extract_xgb_evaluation_log(res_xy_wts) + ) # Check to see if the case weights had an effect expect_true( - !isTRUE(all.equal(res_f$fit$evaluation_log, res_f_wts$fit$evaluation_log)) + !isTRUE(all.equal( + extract_xgb_evaluation_log(res_f), + extract_xgb_evaluation_log(res_f_wts) + )) ) - expect_true(has_multi_predict(res_xy)) expect_equal(multi_predict_args(res_xy), "trees") @@ -209,10 +235,7 @@ test_that('xgboost regression prediction', { ) expect_equal(form_pred, predict(form_fit, new_data = mtcars[1:8, -1])$.pred) - expect_equal( - extract_fit_engine(form_fit)$params$objective, - "reg:squarederror" - ) + expect_equal(extract_xgb_param(form_fit, "objective"), "reg:squarederror") 
}) @@ -228,10 +251,7 @@ test_that('xgboost alternate objective', { set_mode("regression") xgb_fit <- spec |> fit(mpg ~ ., data = mtcars) - expect_equal( - extract_fit_engine(xgb_fit)$params$objective, - "reg:pseudohubererror" - ) + expect_equal(extract_xgb_param(xgb_fit, "objective"), "reg:pseudohubererror") expect_no_error(xgb_preds <- predict(xgb_fit, new_data = mtcars[1, ])) expect_s3_class(xgb_preds, "data.frame") @@ -333,7 +353,7 @@ test_that('validation sets', { ) expect_equal( - colnames(extract_fit_engine(reg_fit)$evaluation_log)[2], + colnames(extract_xgb_evaluation_log(reg_fit))[2], "validation_rmse" ) @@ -345,7 +365,7 @@ test_that('validation sets', { ) expect_equal( - colnames(extract_fit_engine(reg_fit)$evaluation_log)[2], + colnames(extract_xgb_evaluation_log(reg_fit))[2], "validation_mae" ) @@ -357,7 +377,7 @@ test_that('validation sets', { ) expect_equal( - colnames(extract_fit_engine(reg_fit)$evaluation_log)[2], + colnames(extract_xgb_evaluation_log(reg_fit))[2], "training_mae" ) @@ -387,12 +407,29 @@ test_that('early stopping', { fit(mpg ~ ., data = mtcars[-(1:4), ]) ) + extract_xgb_nitter <- function(x) { + if (utils::packageVersion("xgboost") >= "2.0.0.0") { + res <- nrow(attr(extract_fit_engine(x), "evaluation_log")) + } else { + res <- extract_fit_engine(x)$niter + } + res + } + extract_xgb_best_iteration <- function(x) { + if (utils::packageVersion("xgboost") >= "2.0.0.0") { + res <- attr(extract_fit_engine(x), "early_stop")$best_iteration + } else { + res <- extract_fit_engine(x)$best_iteration + } + res + } + expect_equal( - extract_fit_engine(reg_fit)$niter - - extract_fit_engine(reg_fit)$best_iteration, + extract_xgb_nitter(reg_fit) - + extract_xgb_best_iteration(reg_fit), 5 ) - expect_true(extract_fit_engine(reg_fit)$niter < 200) + expect_true(extract_xgb_nitter(reg_fit) < 200) expect_no_condition( reg_fit <- @@ -535,16 +572,29 @@ test_that('xgboost data and sparse matrices', { from_mat$fit$handle <- NULL
from_sparse$fit$handle <- NULL - expect_equal( - extract_fit_engine(from_df), - extract_fit_engine(from_mat), - ignore_function_env = TRUE - ) - expect_equal( - extract_fit_engine(from_df), - extract_fit_engine(from_sparse), - ignore_function_env = TRUE - ) + if (utils::packageVersion("xgboost") >= "2.0.0.0") { + expect_equal( + attributes(extract_fit_engine(from_df)), + attributes(extract_fit_engine(from_mat)), + ignore_function_env = TRUE + ) + expect_equal( + attributes(extract_fit_engine(from_df)), + attributes(extract_fit_engine(from_sparse)), + ignore_function_env = TRUE + ) + } else { + expect_equal( + extract_fit_engine(from_df), + extract_fit_engine(from_mat), + ignore_function_env = TRUE + ) + expect_equal( + extract_fit_engine(from_df), + extract_fit_engine(from_sparse), + ignore_function_env = TRUE + ) + } # case weights added expect_no_condition( @@ -591,14 +641,20 @@ test_that('argument checks for data dimensions', { xy_fit <- spec |> fit_xy(x = penguins_dummy, y = penguins$species, control = ctrl) ) - expect_equal(extract_fit_engine(f_fit)$params$colsample_bynode, 1) expect_equal( - extract_fit_engine(f_fit)$params$min_child_weight, + extract_xgb_param(f_fit, "colsample_bynode"), + 1 + ) + expect_equal( + extract_xgb_param(f_fit, "min_child_weight"), nrow(penguins) ) - expect_equal(extract_fit_engine(xy_fit)$params$colsample_bynode, 1) expect_equal( - extract_fit_engine(xy_fit)$params$min_child_weight, + extract_xgb_param(xy_fit, "colsample_bynode"), + 1 + ) + expect_equal( + extract_xgb_param(xy_fit, "min_child_weight"), nrow(penguins) ) }) @@ -633,15 +689,27 @@ test_that("fit and prediction with `event_level`", { xgbmat_train_1 <- xgb.DMatrix(data = train_x, label = train_y_1) set.seed(24) - fit_xgb_1 <- xgboost::xgb.train( - data = xgbmat_train_1, - nrounds = 10, - watchlist = list("training" = xgbmat_train_1), - objective = "binary:logistic", - eval_metric = "auc", - verbose = 0 - ) - + if (utils::packageVersion("xgboost") >= "2.0.0.0") { + 
fit_xgb_1 <- xgboost::xgb.train( + params = list( + objective = "binary:logistic", + eval_metric = "auc" + ), + data = xgbmat_train_1, + nrounds = 10, + evals = list("training" = xgbmat_train_1), + verbose = 0 + ) + } else { + fit_xgb_1 <- xgboost::xgb.train( + data = xgbmat_train_1, + nrounds = 10, + watchlist = list("training" = xgbmat_train_1), + objective = "binary:logistic", + eval_metric = "auc", + verbose = 0 + ) + } expect_equal( extract_fit_engine(fit_p_1)$evaluation_log, fit_xgb_1$evaluation_log @@ -661,14 +729,27 @@ test_that("fit and prediction with `event_level`", { xgbmat_train_2 <- xgb.DMatrix(data = train_x, label = train_y_2) set.seed(24) - fit_xgb_2 <- xgboost::xgb.train( - data = xgbmat_train_2, - nrounds = 10, - watchlist = list("training" = xgbmat_train_2), - objective = "binary:logistic", - eval_metric = "auc", - verbose = 0 - ) + if (utils::packageVersion("xgboost") >= "2.0.0.0") { + fit_xgb_2 <- xgboost::xgb.train( + params = list( + eval_metric = "auc", + objective = "binary:logistic" + ), + data = xgbmat_train_2, + nrounds = 10, + evals = list("training" = xgbmat_train_2), + verbose = 0 + ) + } else { + fit_xgb_2 <- xgboost::xgb.train( + data = xgbmat_train_2, + nrounds = 10, + watchlist = list("training" = xgbmat_train_2), + objective = "binary:logistic", + eval_metric = "auc", + verbose = 0 + ) + } expect_equal( extract_fit_engine(fit_p_2)$evaluation_log, @@ -691,9 +772,9 @@ test_that("count/proportion parameters", { set_engine("xgboost") |> set_mode("regression") |> fit(mpg ~ ., data = mtcars) - expect_equal(extract_fit_engine(fit1)$params$colsample_bytree, 1) + expect_equal(extract_xgb_param(fit1, "colsample_bytree"), 1) expect_equal( - extract_fit_engine(fit1)$params$colsample_bynode, + extract_xgb_param(fit1, "colsample_bynode"), 7 / (ncol(mtcars) - 1) ) @@ -703,11 +784,11 @@ test_that("count/proportion parameters", { set_mode("regression") |> fit(mpg ~ ., data = mtcars) expect_equal( - 
extract_fit_engine(fit2)$params$colsample_bytree, + extract_xgb_param(fit2, "colsample_bytree"), 4 / (ncol(mtcars) - 1) ) expect_equal( - extract_fit_engine(fit2)$params$colsample_bynode, + extract_xgb_param(fit2, "colsample_bynode"), 7 / (ncol(mtcars) - 1) ) @@ -716,17 +797,18 @@ test_that("count/proportion parameters", { set_engine("xgboost") |> set_mode("regression") |> fit(mpg ~ ., data = mtcars) - expect_equal(extract_fit_engine(fit3)$params$colsample_bytree, 1) - expect_equal(extract_fit_engine(fit3)$params$colsample_bynode, 1) + expect_equal(extract_xgb_param(fit3, "colsample_bytree"), 1) + expect_equal(extract_xgb_param(fit3, "colsample_bynode"), 1) fit4 <- boost_tree(mtry = .9, trees = 4) |> set_engine("xgboost", colsample_bytree = .1, counts = FALSE) |> set_mode("regression") |> fit(mpg ~ ., data = mtcars) - expect_equal(extract_fit_engine(fit4)$params$colsample_bytree, .1) - expect_equal(extract_fit_engine(fit4)$params$colsample_bynode, .9) + expect_equal(extract_xgb_param(fit4, "colsample_bytree"), .1) + expect_equal(extract_xgb_param(fit4, "colsample_bynode"), .9) + extract_xgb_param(fit4, "colsample_bynode") expect_snapshot( error = TRUE, boost_tree(mtry = .9, trees = 4) |> @@ -758,7 +840,7 @@ test_that('interface to param arguments', { class = "xgboost_params_warning" ) - expect_equal(extract_fit_engine(fit_1)$params$eval_metric, "mae") + expect_equal(extract_xgb_param(fit_1, "eval_metric"), "mae") # pass params as main argument (good) spec_2 <- @@ -769,7 +851,7 @@ test_that('interface to param arguments', { fit_2 <- spec_2 |> fit(mpg ~ ., data = mtcars) ) - expect_equal(extract_fit_engine(fit_2)$params$eval_metric, "mae") + expect_equal(extract_xgb_param(fit_2, "eval_metric"), "mae") # pass objective to params argument (bad) spec_3 <- @@ -781,10 +863,7 @@ test_that('interface to param arguments', { class = "xgboost_params_warning" ) - expect_equal( - extract_fit_engine(fit_3)$params$objective, - "reg:pseudohubererror" - ) + 
expect_equal(extract_xgb_param(fit_3, "objective"), "reg:pseudohubererror") # pass objective as main argument (good) spec_4 <- @@ -795,10 +874,7 @@ test_that('interface to param arguments', { fit_4 <- spec_4 |> fit(mpg ~ ., data = mtcars) ) - expect_equal( - extract_fit_engine(fit_4)$params$objective, - "reg:pseudohubererror" - ) + expect_equal(extract_xgb_param(fit_4, "objective"), "reg:pseudohubererror") # pass a guarded argument as a main argument (bad) spec_5 <- @@ -810,7 +886,7 @@ test_that('interface to param arguments', { class = "xgboost_guarded_warning" ) - expect_null(extract_fit_engine(fit_5)$params$watchlist) + expect_null(extract_xgb_param(fit_5, "watchlist")) # pass two guarded arguments as main arguments (bad) spec_6 <- @@ -822,7 +898,7 @@ test_that('interface to param arguments', { class = "xgboost_guarded_warning" ) - expect_null(extract_fit_engine(fit_5)$params$watchlist) + expect_null(extract_xgb_param(fit_6, "watchlist")) # pass a guarded argument as params argument (bad) spec_7 <- @@ -834,5 +910,5 @@ test_that('interface to param arguments', { class = "xgboost_params_warning" ) - expect_equal(extract_fit_engine(fit_5)$params$gamma, 0) + expect_equal(extract_xgb_param(fit_7, "gamma"), 0) })