Fresh implementation aft and cox #260

brunocarlin wants to merge 2 commits


I have fixed the bug in the previous implementation; it was a parse vs. rlang issue.
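
The description does not say what exactly went wrong, so the following is only a general illustration of how building a call from text with `parse()` can differ from building it from objects with rlang. The variable names are made up for this sketch, and it is not the code from this PR.

```r
library(rlang)

objective <- "survival:aft"

# Route 1: paste code together as text and parse it.
# The string value has to be quoted by hand inside the text.
call_from_text <- parse(
  text = paste0("list(objective = \"", objective, "\")")
)[[1]]

# Route 2: build the call from objects; the string is inserted as a value.
call_from_rlang <- call2("list", objective = objective)

# Both calls evaluate to the same list.
identical(eval(call_from_text), eval(call_from_rlang))

# If the manual quoting is forgotten, the text parses as `survival:aft`,
# i.e. the `:` operator applied to two symbols, and evaluation fails
# because `survival` is looked up as a variable.
unquoted_result <- tryCatch(
  eval(parse(text = paste0("list(objective = ", objective, ")"))[[1]], baseenv()),
  error = function(e) conditionMessage(e)
)
unquoted_result
```

With the object-based route there is no quoting step to get wrong, which is one plausible reason to prefer it when assembling engine arguments.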

#> Warning in system("timedatectl", intern = TRUE): running command 'timedatectl'
#> had status 1
#> Loading required package: survival


lung <- lung %>% drop_na()
lung_train <- lung[-c(1:5), ]
lung_test <- lung[1:5, ]

test_aft <-
  boost_tree()|> set_engine('xgboost') |> set_mode('censored regression')

test_aft |> translate()
#> Boosted Tree Model Specification (censored regression)
#> Computational engine: xgboost 
#> Model fit template:
#> censored::xgb_train_censored(x = missing_arg(), y = missing_arg(), 
#>     nthread = 1, verbose = 0, objective = "survival:aft")

bt_fit <- test_aft %>% fit(Surv(time, status) ~ ., data = lung_train)
#> parsnip model object
#> ##### xgb.Booster
#> raw: 36.3 Kb 
#> call:
#>   xgboost::xgb.train(params = list(eta = 0.3, max_depth = 6, gamma = 0, 
#>     colsample_bytree = 1, colsample_bynode = 1, min_child_weight = 1, 
#>     subsample = 1, objective = "survival:aft"), data = x$data, 
#>     nrounds = 15, watchlist = x$watchlist, verbose = 0, nthread = 1)
#> params (as set within xgb.train):
#>   eta = "0.3", max_depth = "6", gamma = "0", colsample_bytree = "1", colsample_bynode = "1", min_child_weight = "1", subsample = "1", objective = "survival:aft", nthread = "1", validate_parameters = "TRUE"
#> xgb.attributes:
#>   niter
#> callbacks:
#>   cb.evaluation.log()
#> # of features: 8 
#> niter: 15
#> nfeatures : 8 
#> evaluation_log:
#>     iter training_aft_nloglik
#>        1            14.698676
#>        2             9.883017
#> ---                          
#>       14             4.564666
#>       15             4.553461

predict(bt_fit, lung_test, type = "linear_pred")
#> Error in `object$spec$method$pred$linear_pred$pre()`:
#> ! The objective should be survival:cox not survival:aft
#> Backtrace:
#>     ▆
#>  1. ├─stats::predict(bt_fit, lung_test, type = "linear_pred", )
#>  2. └─parsnip::predict.model_fit(...)
#>  3.   ├─parsnip::predict_linear_pred(...)
#>  4.   └─parsnip::predict_linear_pred.model_fit(...)
#>  5.     └─object$spec$method$pred$linear_pred$pre(new_data, object)
#>  6.       └─rlang::abort(glue::glue("The objective should be survival:cox not {object$fit$params$objective}")) at censored/R/boost_tree-data.R:280:8
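
The backtrace shows that the error comes from a `pre` hook that compares the fitted objective with the one a prediction type requires. A minimal sketch of such a guard is below. The helper name `check_objective()` and the stand-in object `fake_aft_fit` are hypothetical and are not taken from censored/R/boost_tree-data.R; only the `object$fit$params$objective` field and the message wording come from the output above.

```r
library(rlang)
library(glue)

# Hypothetical helper, not the function used in the PR.
# Aborts when the fitted objective differs from the expected one.
check_objective <- function(object, expected) {
  actual <- object$fit$params$objective
  if (!identical(actual, expected)) {
    abort(glue("The objective should be {expected} not {actual}"))
  }
  invisible(object)
}

# Stand-in for a fitted model; only the field the check reads is filled in.
fake_aft_fit <- list(fit = list(params = list(objective = "survival:aft")))

# Matching objective: returns the object invisibly, no error.
check_objective(fake_aft_fit, "survival:aft")

# Mismatched objective: the error is caught and its message kept.
msg <- tryCatch(
  check_objective(fake_aft_fit, "survival:cox"),
  error = function(e) conditionMessage(e)
)
msg
```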

predict(bt_fit,lung_test,type = 'time')
#> # A tibble: 5 × 1
#>   .pred_time
#>        <dbl>
#> 1      420. 
#> 2      239. 
#> 3      120. 
#> 4       78.7
#> 5      350.

test_cox <-
  boost_tree()|> set_engine('xgboost',objective = 'survival:cox')  |> set_mode('censored regression')

test_cox |> translate()
#> Boosted Tree Model Specification (censored regression)
#> Engine-Specific Arguments:
#>   objective = survival:cox
#> Computational engine: xgboost 
#> Model fit template:
#> censored::xgb_train_censored(x = missing_arg(), y = missing_arg(), 
#>     objective = "survival:cox", nthread = 1, verbose = 0)

bt_fit <- test_cox %>% fit(Surv(time, status) ~ ., data = lung_train)
#> parsnip model object
#> ##### xgb.Booster
#> raw: 40.5 Kb 
#> call:
#>   xgboost::xgb.train(params = list(eta = 0.3, max_depth = 6, gamma = 0, 
#>     colsample_bytree = 1, colsample_bynode = 1, min_child_weight = 1, 
#>     subsample = 1, objective = "survival:cox"), data = x$data, 
#>     nrounds = 15, watchlist = x$watchlist, verbose = 0, nthread = 1)
#> params (as set within xgb.train):
#>   eta = "0.3", max_depth = "6", gamma = "0", colsample_bytree = "1", colsample_bynode = "1", min_child_weight = "1", subsample = "1", objective = "survival:cox", nthread = "1", validate_parameters = "TRUE"
#> xgb.attributes:
#>   niter
#> callbacks:
#>   cb.evaluation.log()
#> # of features: 8 
#> niter: 15
#> nfeatures : 8 
#> evaluation_log:
#>     iter training_cox_nloglik
#>        1             3.967019
#>        2             3.840237
#> ---                          
#>       14             3.095573
#>       15             3.054746

predict(bt_fit, lung_test, type = 'time')
#> Error in `object$spec$method$pred$time$pre()`:
#> ! The objective should be survival:aft not survival:cox
#> Backtrace:
#>     ▆
#>  1. ├─stats::predict(bt_fit, lung_test, type = "time")
#>  2. └─parsnip::predict.model_fit(bt_fit, lung_test, type = "time")
#>  3.   ├─parsnip::predict_time(object = object, new_data = new_data, ...)
#>  4.   └─parsnip::predict_time.model_fit(...)
#>  5.     └─object$spec$method$pred$time$pre(new_data, object)
#>  6.       └─rlang::abort(glue::glue("The objective should be survival:aft not {object$fit$params$objective}")) at censored/R/boost_tree-data.R:256:8

predict(bt_fit, lung_test, type = "linear_pred")
#> # A tibble: 5 × 1
#>   .pred_linear_pred
#>               <dbl>
#> 1             0.351
#> 2             4.41 
#> 3             2.23 
#> 4             4.50 
#> 5             2.36
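
In the runs above, the AFT fit only accepts `type = 'time'` and the Cox fit only accepts `type = "linear_pred"`. One way to avoid hitting the error is to look up the prediction type from the fitted objective first. The helper `pred_type_for()` below is hypothetical and is not part of censored or parsnip; it just encodes the pairing seen in this reprex.

```r
# Hypothetical helper: maps an xgboost survival objective to the
# prediction type that worked for it in the reprex above.
pred_type_for <- function(objective) {
  switch(objective,
    "survival:aft" = "time",
    "survival:cox" = "linear_pred",
    stop("Unsupported objective: ", objective)
  )
}

pred_type_for("survival:aft")
pred_type_for("survival:cox")

# Possible use with a fitted model (not run here):
# predict(bt_fit, lung_test, type = pred_type_for(bt_fit$fit$params$objective))
```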

Created on 2023-04-13 with reprex v2.0.2

Session info
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.3 (2023-03-15)
#>  os       Ubuntu 22.04.2 LTS
#>  system   x86_64, linux-gnu
#>  ui       X11
#>  language (EN)
#>  collate  C.UTF-8
#>  ctype    C.UTF-8
#>  tz       America/Sao_Paulo
#>  date     2023-04-13
#>  pandoc   2.19.2 @ /usr/lib/rstudio-server/bin/quarto/bin/tools/ (via rmarkdown)
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version    date (UTC) lib source
#>  backports      1.4.1      2021-12-13 [1] CRAN (R 4.2.3)
#>  broom        * 1.0.4      2023-03-11 [1] CRAN (R 4.2.3)
#>  censored     * 2023-04-14 [1] local
#>  class          7.3-21     2023-01-23 [4] CRAN (R 4.2.2)
#>  cli            3.6.1      2023-03-23 [1] CRAN (R 4.2.3)
#>  codetools      0.2-19     2023-02-01 [4] CRAN (R 4.2.2)
#>  colorspace     2.1-0      2023-01-23 [1] CRAN (R 4.2.3)
#>  data.table     1.14.8     2023-02-17 [1] CRAN (R 4.2.3)
#>  dials        * 1.2.0      2023-04-03 [1] CRAN (R 4.2.3)
#>  DiceDesign     1.9        2021-02-13 [1] CRAN (R 4.2.3)
#>  digest         0.6.31     2022-12-11 [1] CRAN (R 4.2.3)
#>  dplyr        * 1.1.1      2023-03-22 [1] CRAN (R 4.2.3)
#>  evaluate       0.20       2023-01-17 [1] CRAN (R 4.2.3)
#>  fansi          1.0.4      2023-01-22 [1] CRAN (R 4.2.3)
#>  fastmap        1.1.1      2023-02-24 [1] CRAN (R 4.2.3)
#>  forcats      * 1.0.0      2023-01-29 [1] CRAN (R 4.2.3)
#>  foreach        1.5.2      2022-02-02 [1] CRAN (R 4.2.3)
#>  fs             1.6.1      2023-02-06 [1] CRAN (R 4.2.3)
#>  furrr          0.3.1      2022-08-15 [1] CRAN (R 4.2.3)
#>  future         1.32.0     2023-03-07 [1] CRAN (R 4.2.3)
#>  future.apply   1.10.0     2022-11-05 [1] CRAN (R 4.2.3)
#>  generics       0.1.3      2022-07-05 [1] CRAN (R 4.2.3)
#>  ggplot2      * 3.4.2      2023-04-03 [1] CRAN (R 4.2.3)
#>  globals        0.16.2     2022-11-21 [1] CRAN (R 4.2.3)
#>  glue           1.6.2      2022-02-24 [1] CRAN (R 4.2.3)
#>  gower          1.0.1      2022-12-22 [1] CRAN (R 4.2.3)
#>  GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.2.3)
#>  gtable         0.3.3      2023-03-21 [1] CRAN (R 4.2.3)
#>  hardhat        1.3.0      2023-03-30 [1] CRAN (R 4.2.3)
#>  hms            1.1.3      2023-03-21 [1] CRAN (R 4.2.3)
#>  htmltools      0.5.5      2023-03-23 [1] CRAN (R 4.2.3)
#>  infer        * 1.0.4      2022-12-02 [1] CRAN (R 4.2.3)
#>  ipred          0.9-14     2023-03-09 [1] CRAN (R 4.2.3)
#>  iterators      1.0.14     2022-02-05 [1] CRAN (R 4.2.3)
#>  jsonlite       1.8.4      2022-12-06 [1] CRAN (R 4.2.3)
#>  knitr          1.42       2023-01-25 [1] CRAN (R 4.2.3)
#>  lattice        0.20-45    2021-09-22 [4] CRAN (R 4.2.0)
#>  lava     2023-02-27 [1] CRAN (R 4.2.3)
#>  lhs            1.1.6      2022-12-17 [1] CRAN (R 4.2.3)
#>  lifecycle      1.0.3      2022-10-07 [1] CRAN (R 4.2.3)
#>  listenv        0.9.0      2022-12-16 [1] CRAN (R 4.2.3)
#>  lubridate    * 1.9.2      2023-02-10 [1] CRAN (R 4.2.3)
#>  magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.2.3)
#>  MASS           7.3-58.3   2023-03-07 [4] CRAN (R 4.2.3)
#>  Matrix         1.5-1      2022-09-13 [4] CRAN (R 4.2.1)
#>  modeldata    * 1.1.0      2023-01-25 [1] CRAN (R 4.2.3)
#>  munsell        0.5.0      2018-06-12 [1] CRAN (R 4.2.3)
#>  nnet           7.3-18     2022-09-28 [4] CRAN (R 4.2.1)
#>  parallelly     1.35.0     2023-03-23 [1] CRAN (R 4.2.3)
#>  parsnip      * 1.1.0      2023-04-12 [1] CRAN (R 4.2.3)
#>  pillar         1.9.0      2023-03-22 [1] CRAN (R 4.2.3)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.2.3)
#>  prodlim        2023.03.31 2023-04-02 [1] CRAN (R 4.2.3)
#>  purrr        * 1.0.1      2023-01-10 [1] CRAN (R 4.2.3)
#>  R6             2.5.1      2021-08-19 [1] CRAN (R 4.2.3)
#>  Rcpp           1.0.10     2023-01-22 [1] CRAN (R 4.2.3)
#>  readr        * 2.1.4      2023-02-10 [1] CRAN (R 4.2.3)
#>  recipes      * 1.0.5      2023-02-20 [1] CRAN (R 4.2.3)
#>  reprex         2.0.2      2022-08-17 [1] CRAN (R 4.2.3)
#>  rlang          1.1.0      2023-03-14 [1] CRAN (R 4.2.3)
#>  rmarkdown      2.21       2023-03-26 [1] CRAN (R 4.2.3)
#>  rpart          4.1.19     2022-10-21 [4] CRAN (R 4.2.1)
#>  rsample      * 1.1.1      2022-12-07 [1] CRAN (R 4.2.3)
#>  rstudioapi     0.14       2022-08-22 [1] CRAN (R 4.2.3)
#>  scales       * 1.2.1      2022-08-20 [1] CRAN (R 4.2.3)
#>  sessioninfo    1.2.2      2021-12-06 [1] CRAN (R 4.2.3)
#>  stringi        1.7.12     2023-01-11 [1] CRAN (R 4.2.3)
#>  stringr      * 1.5.0      2022-12-02 [1] CRAN (R 4.2.3)
#>  survival     * 3.5-3      2023-02-12 [4] CRAN (R 4.2.2)
#>  tibble       * 3.2.1      2023-03-20 [1] CRAN (R 4.2.3)
#>  tidymodels   * 1.0.0      2022-07-13 [1] CRAN (R 4.2.3)
#>  tidyr        * 1.3.0      2023-01-24 [1] CRAN (R 4.2.3)
#>  tidyselect     1.2.0      2022-10-10 [1] CRAN (R 4.2.3)
#>  tidyverse    * 2.0.0      2023-02-22 [1] CRAN (R 4.2.3)
#>  timechange     0.2.0      2023-01-11 [1] CRAN (R 4.2.3)
#>  timeDate       4022.108   2023-01-07 [1] CRAN (R 4.2.3)
#>  tune         * 1.1.1      2023-04-11 [1] CRAN (R 4.2.3)
#>  tzdb           0.3.0      2022-03-28 [1] CRAN (R 4.2.3)
#>  utf8           1.2.3      2023-01-31 [1] CRAN (R 4.2.3)
#>  vctrs          0.6.1      2023-03-22 [1] CRAN (R 4.2.3)
#>  withr          2.5.0      2022-03-03 [1] CRAN (R 4.2.3)
#>  workflows    * 1.1.3      2023-02-22 [1] CRAN (R 4.2.3)
#>  workflowsets * 1.0.1      2023-04-06 [1] CRAN (R 4.2.3)
#>  xfun           0.38       2023-03-24 [1] CRAN (R 4.2.3)
#>  xgboost    2023-03-30 [1] CRAN (R 4.2.3)
#>  yaml           2.3.7      2023-01-23 [1] CRAN (R 4.2.3)
#>  yardstick    * 1.1.0      2022-09-07 [1] CRAN (R 4.2.3)
#>  [1] /home/bcarlin/R/x86_64-pc-linux-gnu-library/4.2
#>  [2] /usr/local/lib/R/site-library
#>  [3] /usr/lib/R/site-library
#>  [4] /usr/lib/R/library
#> ──────────────────────────────────────────────────────────────────────────────

brunocarlin closed this on Jul 4, 2024