Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# parsnip (development version)

* An RStudio add-in is availble that makes writing multiple `parsnip` model specifications to the source window. It can be accessed via the IDE addin menus or by calling `parsnip_addin()`.

* An RStudio add-in is availble that makes writing multiple `parsnip` model specifications to the source window. It can be accessed via the IDE addin menus or by calling `parsnip_addin()`.

* For `xgboost` models, users can now pass `objective` to `set_engine("xgboost")`.

# parsnip 0.1.4

* `show_engines()` will provide information on the current set for a model.
Expand Down
34 changes: 20 additions & 14 deletions R/boost_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,9 @@ check_args.boost_tree <- function(object) {
#' training iterations without improvement before stopping. If `validation` is
#' used, performance is base on the validation set; otherwise, the training set
#' is used.
#' @param objective A single string (or NULL) that defines the loss function that
#' `xgboost` uses to create trees. See [xgboost::xgb.train()] for options. If left
#' NULL, an appropriate loss function is chosen.
#' @param ... Other options to pass to `xgb.train`.
#' @return A fitted `xgboost` object.
#' @keywords internal
Expand All @@ -310,7 +313,9 @@ xgb_train <- function(
x, y,
max_depth = 6, nrounds = 15, eta = 0.3, colsample_bytree = 1,
min_child_weight = 1, gamma = 0, subsample = 1, validation = 0,
early_stop = NULL, ...) {
early_stop = NULL, objective = NULL, ...) {

others <- list(...)

num_class <- length(levels(y))

Expand All @@ -327,13 +332,15 @@ xgb_train <- function(
}


if (is.numeric(y)) {
loss <- "reg:squarederror"
} else {
if (num_class == 2) {
loss <- "binary:logistic"
if (is.null(objective)) {
if (is.numeric(y)) {
objective <- "reg:squarederror"
} else {
loss <- "multi:softprob"
if (num_class == 2) {
objective <- "binary:logistic"
} else {
objective <- "multi:softprob"
}
}
}

Expand Down Expand Up @@ -370,15 +377,15 @@ xgb_train <- function(
gamma = gamma,
colsample_bytree = colsample_bytree,
min_child_weight = min(min_child_weight, n),
subsample = subsample
subsample = subsample,
objective = objective
)

main_args <- list(
data = quote(x$data),
watchlist = quote(x$watchlist),
params = arg_list,
nrounds = nrounds,
objective = loss,
early_stopping_rounds = early_stop
)
if (!is.null(num_class) && num_class > 2) {
Expand All @@ -388,7 +395,7 @@ xgb_train <- function(
call <- make_call(fun = "xgb.train", ns = "xgboost", main_args)

# override or add some other args
others <- list(...)

others <-
others[!(names(others) %in% c("data", "weights", "nrounds", "num_class", names(arg_list)))]
if (!(any(names(others) == "verbose"))) {
Expand All @@ -410,13 +417,12 @@ xgb_pred <- function(object, newdata, ...) {

res <- predict(object, newdata, ...)

x = switch(
x <- switch(
object$params$objective,
"reg:squarederror" = , "reg:logistic" = , "binary:logistic" = res,
"binary:logitraw" = stats::binomial()$linkinv(res),
"multi:softprob" = matrix(res, ncol = object$params$num_class, byrow = TRUE),
res
)
res)

x
}

Expand Down
5 changes: 5 additions & 0 deletions man/xgb_train.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions tests/testthat/test_boost_tree_xgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,26 @@ test_that('xgboost regression prediction', {

form_pred <- predict(form_fit$fit, newdata = xgb.DMatrix(data = as.matrix(mtcars[1:8, -1])))
expect_equal(form_pred, predict(form_fit, new_data = mtcars[1:8, -1])$.pred)

expect_equal(form_fit$fit$params$objective, "reg:squarederror")

})



test_that('xgboost alternate objective', {
skip_if_not_installed("xgboost")

spec <-
boost_tree() %>%
set_engine("xgboost", objective = "reg:pseudohubererror") %>%
set_mode("regression")

xgb_fit <- spec %>% fit(mpg ~ ., data = mtcars)
expect_equal(xgb_fit$fit$params$objective, "reg:pseudohubererror")
})


test_that('submodel prediction', {

skip_if_not_installed("xgboost")
Expand Down