handle NULL objective in xgb_predict() #875

Merged
simonpcouch merged 3 commits into main from objective-873 on Feb 23, 2023

simonpcouch
Contributor

This PR ought not to be merged before the upcoming 1.0.4 release.

Closes #873!

This PR:

  • Handles NULL objectives in xgb_predict()
  • Adds a test for an additional type of alternative objective, and tests each alternative objective for predicting as well as fitting

We need to handle NULL objectives because the xgb.Booster object does not carry the objective argument along in its params slot if it's not one of the pre-defined options. Using the example fit and custom objective function from the xgb.train() examples:

library(tidymodels)
library(xgboost)

data(agaricus.train, package = "xgboost")
data(agaricus.test, package = "xgboost")

dtrain <- with(agaricus.train, xgb.DMatrix(data, label = label, nthread = 2))
dtest <- with(agaricus.test, xgb.DMatrix(data, label = label, nthread = 2))
watchlist <- list(train = dtrain, eval = dtest)

The objective parameter is retained, usually:

param <- list(max_depth = 2, eta = 1, nthread = 2,
              objective = "binary:logistic", eval_metric = "auc")
bst <- xgb.train(param, dtrain, nrounds = 2, watchlist)
#> [1]  train-auc:0.958228  eval-auc:0.960373 
#> [2]  train-auc:0.981413  eval-auc:0.979930

bst$params$objective
#> [1] "binary:logistic"

# But is not brought along when a custom objective function is used:
logregobj <- function(preds, dtrain) {
  labels <- getinfo(dtrain, "label")
  preds <- 1/(1 + exp(-preds))
  grad <- preds - labels
  hess <- preds * (1 - preds)
  return(list(grad = grad, hess = hess))
}

param$objective <- logregobj

bst <- xgb.train(param, dtrain, nrounds = 2, watchlist)
#> [1]  train-auc:0.958228  eval-auc:0.960373 
#> [2]  train-auc:0.981413  eval-auc:0.979930

bst$params$objective
#> NULL
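
Because bst$params$objective can be NULL, any post-processing that dispatches on the objective string needs a fallback. The following is a minimal base-R sketch of that idea, not the actual xgb_predict() implementation. The helper name postprocess_preds(), the "custom" sentinel, and the specific transformations per objective are assumptions made for illustration.

```r
# Hypothetical helper for illustration; not parsnip's xgb_predict().
postprocess_preds <- function(objective, res, num_class = NULL) {
  # switch() errors when EXPR is NULL, so substitute a sentinel that
  # matches none of the named branches and falls through to the default.
  if (is.null(objective)) {
    objective <- "custom"
  }

  switch(
    objective,
    # Raw log-odds: map to probabilities with the inverse logit.
    "binary:logitraw" = stats::plogis(res),
    # Flat vector of class probabilities (assumed row-major):
    # reshape to one row per observation.
    "multi:softprob" = matrix(res, ncol = num_class, byrow = TRUE),
    # Default branch: return the predictions unchanged.
    res
  )
}

# With a NULL objective the predictions pass through untouched.
postprocess_preds(NULL, c(0.2, 0.8))

# A recognized objective is still transformed as before.
postprocess_preds("binary:logitraw", c(-1, 0, 1))
```

The point of the sentinel is only that switch() always receives a length-one value; any string that is not a named branch would do.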

The "more minimal" reprex from the original issue, with this PR installed:

res <-
  fit_resamples(
    boost_tree("classification") %>% set_engine("xgboost", objective = logregobj),
    Class ~ .,
    bootstraps(two_class_dat, 5)
  )

collect_metrics(res)
#> # A tibble: 2 × 6
#>   .metric  .estimator  mean     n std_err .config             
#>   <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
#> 1 accuracy binary     0.805     5 0.00676 Preprocessor1_Model1
#> 2 roc_auc  binary     0.865     5 0.00684 Preprocessor1_Model1

Created on 2023-02-21 with reprex v2.0.2
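
The custom objective used in both reprexes returns the gradient and Hessian of the binary log loss with respect to the raw margin. As a sanity check, a small base-R sketch can compare those closed forms against finite differences. The names logloss() and logreg_grad_hess() and the test values are hypothetical, and this version takes the labels directly rather than calling getinfo() so that it runs without xgboost.

```r
# Binary log loss as a function of the raw margin (log-odds) `m`.
logloss <- function(m, y) {
  p <- 1 / (1 + exp(-m))
  -(y * log(p) + (1 - y) * log(1 - p))
}

# Same formulas as logregobj(), but with labels passed in directly.
logreg_grad_hess <- function(preds, labels) {
  p <- 1 / (1 + exp(-preds))
  list(grad = p - labels, hess = p * (1 - p))
}

m <- c(-2, -0.5, 0, 1.5)   # illustrative margins
y <- c(0, 1, 1, 0)         # illustrative labels
h <- 1e-4                  # finite-difference step

analytic <- logreg_grad_hess(m, y)

# Central differences for the first and second derivatives.
num_grad <- (logloss(m + h, y) - logloss(m - h, y)) / (2 * h)
num_hess <- (logloss(m + h, y) - 2 * logloss(m, y) + logloss(m - h, y)) / h^2

# Both comparisons should agree within the stated tolerances.
all.equal(analytic$grad, num_grad, tolerance = 1e-6)
all.equal(analytic$hess, num_hess, tolerance = 1e-4)
```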

@simonpcouch
Contributor Author

Does this do the trick for you, @SHo-JANG? You can install and test by running pak::pak("tidymodels/parsnip#875"), restarting R, and running your reprex.

@SHo-JANG

SHo-JANG commented Feb 21, 2023

It works surprisingly well.
Could you modify this method so that it can be applied to other models, such as LightGBM?

Thank you so much!😂

@simonpcouch
Contributor Author

Glad this is helpful for you! xgb_predict() is intended as a light wrapper around xgboost's predict method, so I think the scope of usage for this function will stay the same for now.

@simonpcouch simonpcouch requested a review from hfrick February 21, 2023 18:22
Member

@hfrick hfrick left a comment

nice!

@simonpcouch simonpcouch merged commit 927defa into main Feb 23, 2023
@simonpcouch simonpcouch deleted the objective-873 branch February 23, 2023 20:12
@github-actions

This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

@github-actions github-actions bot locked and limited conversation to collaborators Mar 10, 2023
Development

Successfully merging this pull request may close these issues.

Try Customizing objective function into XGboost boost_tree()