Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fresh implementation aft and cox #260

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

brunocarlin
Copy link

I have fixed the bug with the previous implementation it was a parse vs rlang thing

library(tidymodels)
#> Warning in system("timedatectl", intern = TRUE): running command 'timedatectl'
#> had status 1
library(censored)
#> Loading required package: survival
library(tidyverse)
library(survival)

data(cancer)

lung <- lung %>% drop_na()
lung_train <- lung[-c(1:5), ]
lung_test <- lung[1:5, ]


test_aft <-
  boost_tree()|> set_engine('xgboost') |> set_mode('censored regression')

test_aft |>
  translate()
#> Boosted Tree Model Specification (censored regression)
#> 
#> Computational engine: xgboost 
#> 
#> Model fit template:
#> censored::xgb_train_censored(x = missing_arg(), y = missing_arg(), 
#>     nthread = 1, verbose = 0, objective = "survival:aft")

set.seed(1)
bt_fit <- test_aft %>% fit(Surv(time, status) ~ ., data = lung_train)
bt_fit
#> parsnip model object
#> 
#> ##### xgb.Booster
#> raw: 36.3 Kb 
#> call:
#>   xgboost::xgb.train(params = list(eta = 0.3, max_depth = 6, gamma = 0, 
#>     colsample_bytree = 1, colsample_bynode = 1, min_child_weight = 1, 
#>     subsample = 1, objective = "survival:aft"), data = x$data, 
#>     nrounds = 15, watchlist = x$watchlist, verbose = 0, nthread = 1)
#> params (as set within xgb.train):
#>   eta = "0.3", max_depth = "6", gamma = "0", colsample_bytree = "1", colsample_bynode = "1", min_child_weight = "1", subsample = "1", objective = "survival:aft", nthread = "1", validate_parameters = "TRUE"
#> xgb.attributes:
#>   niter
#> callbacks:
#>   cb.evaluation.log()
#> # of features: 8 
#> niter: 15
#> nfeatures : 8 
#> evaluation_log:
#>     iter training_aft_nloglik
#>        1            14.698676
#>        2             9.883017
#> ---                          
#>       14             4.564666
#>       15             4.553461

predict(
  bt_fit,
  lung_test,
  type = "linear_pred",
)
#> Error in `object$spec$method$pred$linear_pred$pre()`:
#> ! The objective should be survival:cox not survival:aft
#> Backtrace:
#>     ▆
#>  1. ├─stats::predict(bt_fit, lung_test, type = "linear_pred", )
#>  2. └─parsnip::predict.model_fit(...)
#>  3.   ├─parsnip::predict_linear_pred(...)
#>  4.   └─parsnip::predict_linear_pred.model_fit(...)
#>  5.     └─object$spec$method$pred$linear_pred$pre(new_data, object)
#>  6.       └─rlang::abort(glue::glue("The objective should be survival:cox not {object$fit$params$objective}")) at censored/R/boost_tree-data.R:280:8

predict(bt_fit,lung_test,type = 'time')
#> # A tibble: 5 × 1
#>   .pred_time
#>        <dbl>
#> 1      420. 
#> 2      239. 
#> 3      120. 
#> 4       78.7
#> 5      350.

test_cox <-
  boost_tree()|> set_engine('xgboost',objective = 'survival:cox')  |> set_mode('censored regression')

test_cox |>
  translate()
#> Boosted Tree Model Specification (censored regression)
#> 
#> Engine-Specific Arguments:
#>   objective = survival:cox
#> 
#> Computational engine: xgboost 
#> 
#> Model fit template:
#> censored::xgb_train_censored(x = missing_arg(), y = missing_arg(), 
#>     objective = "survival:cox", nthread = 1, verbose = 0)

set.seed(1)
bt_fit <- test_cox %>% fit(Surv(time, status) ~ ., data = lung_train)
bt_fit
#> parsnip model object
#> 
#> ##### xgb.Booster
#> raw: 40.5 Kb 
#> call:
#>   xgboost::xgb.train(params = list(eta = 0.3, max_depth = 6, gamma = 0, 
#>     colsample_bytree = 1, colsample_bynode = 1, min_child_weight = 1, 
#>     subsample = 1, objective = "survival:cox"), data = x$data, 
#>     nrounds = 15, watchlist = x$watchlist, verbose = 0, nthread = 1)
#> params (as set within xgb.train):
#>   eta = "0.3", max_depth = "6", gamma = "0", colsample_bytree = "1", colsample_bynode = "1", min_child_weight = "1", subsample = "1", objective = "survival:cox", nthread = "1", validate_parameters = "TRUE"
#> xgb.attributes:
#>   niter
#> callbacks:
#>   cb.evaluation.log()
#> # of features: 8 
#> niter: 15
#> nfeatures : 8 
#> evaluation_log:
#>     iter training_cox_nloglik
#>        1             3.967019
#>        2             3.840237
#> ---                          
#>       14             3.095573
#>       15             3.054746


predict(bt_fit, lung_test, type = 'time')
#> Error in `object$spec$method$pred$time$pre()`:
#> ! The objective should be survival:aft not survival:cox
#> Backtrace:
#>     ▆
#>  1. ├─stats::predict(bt_fit, lung_test, type = "time")
#>  2. └─parsnip::predict.model_fit(bt_fit, lung_test, type = "time")
#>  3.   ├─parsnip::predict_time(object = object, new_data = new_data, ...)
#>  4.   └─parsnip::predict_time.model_fit(...)
#>  5.     └─object$spec$method$pred$time$pre(new_data, object)
#>  6.       └─rlang::abort(glue::glue("The objective should be survival:aft not {object$fit$params$objective}")) at censored/R/boost_tree-data.R:256:8

predict(bt_fit,
        lung_test,
        type = "linear_pred")
#> # A tibble: 5 × 1
#>   .pred_linear_pred
#>               <dbl>
#> 1             0.351
#> 2             4.41 
#> 3             2.23 
#> 4             4.50 
#> 5             2.36

Created on 2023-04-13 with reprex v2.0.2

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.3 (2023-03-15)
#>  os       Ubuntu 22.04.2 LTS
#>  system   x86_64, linux-gnu
#>  ui       X11
#>  language (EN)
#>  collate  C.UTF-8
#>  ctype    C.UTF-8
#>  tz       America/Sao_Paulo
#>  date     2023-04-13
#>  pandoc   2.19.2 @ /usr/lib/rstudio-server/bin/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version    date (UTC) lib source
#>  backports      1.4.1      2021-12-13 [1] CRAN (R 4.2.3)
#>  broom        * 1.0.4      2023-03-11 [1] CRAN (R 4.2.3)
#>  censored     * 0.1.1.9003 2023-04-14 [1] local
#>  class          7.3-21     2023-01-23 [4] CRAN (R 4.2.2)
#>  cli            3.6.1      2023-03-23 [1] CRAN (R 4.2.3)
#>  codetools      0.2-19     2023-02-01 [4] CRAN (R 4.2.2)
#>  colorspace     2.1-0      2023-01-23 [1] CRAN (R 4.2.3)
#>  data.table     1.14.8     2023-02-17 [1] CRAN (R 4.2.3)
#>  dials        * 1.2.0      2023-04-03 [1] CRAN (R 4.2.3)
#>  DiceDesign     1.9        2021-02-13 [1] CRAN (R 4.2.3)
#>  digest         0.6.31     2022-12-11 [1] CRAN (R 4.2.3)
#>  dplyr        * 1.1.1      2023-03-22 [1] CRAN (R 4.2.3)
#>  evaluate       0.20       2023-01-17 [1] CRAN (R 4.2.3)
#>  fansi          1.0.4      2023-01-22 [1] CRAN (R 4.2.3)
#>  fastmap        1.1.1      2023-02-24 [1] CRAN (R 4.2.3)
#>  forcats      * 1.0.0      2023-01-29 [1] CRAN (R 4.2.3)
#>  foreach        1.5.2      2022-02-02 [1] CRAN (R 4.2.3)
#>  fs             1.6.1      2023-02-06 [1] CRAN (R 4.2.3)
#>  furrr          0.3.1      2022-08-15 [1] CRAN (R 4.2.3)
#>  future         1.32.0     2023-03-07 [1] CRAN (R 4.2.3)
#>  future.apply   1.10.0     2022-11-05 [1] CRAN (R 4.2.3)
#>  generics       0.1.3      2022-07-05 [1] CRAN (R 4.2.3)
#>  ggplot2      * 3.4.2      2023-04-03 [1] CRAN (R 4.2.3)
#>  globals        0.16.2     2022-11-21 [1] CRAN (R 4.2.3)
#>  glue           1.6.2      2022-02-24 [1] CRAN (R 4.2.3)
#>  gower          1.0.1      2022-12-22 [1] CRAN (R 4.2.3)
#>  GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.2.3)
#>  gtable         0.3.3      2023-03-21 [1] CRAN (R 4.2.3)
#>  hardhat        1.3.0      2023-03-30 [1] CRAN (R 4.2.3)
#>  hms            1.1.3      2023-03-21 [1] CRAN (R 4.2.3)
#>  htmltools      0.5.5      2023-03-23 [1] CRAN (R 4.2.3)
#>  infer        * 1.0.4      2022-12-02 [1] CRAN (R 4.2.3)
#>  ipred          0.9-14     2023-03-09 [1] CRAN (R 4.2.3)
#>  iterators      1.0.14     2022-02-05 [1] CRAN (R 4.2.3)
#>  jsonlite       1.8.4      2022-12-06 [1] CRAN (R 4.2.3)
#>  knitr          1.42       2023-01-25 [1] CRAN (R 4.2.3)
#>  lattice        0.20-45    2021-09-22 [4] CRAN (R 4.2.0)
#>  lava           1.7.2.1    2023-02-27 [1] CRAN (R 4.2.3)
#>  lhs            1.1.6      2022-12-17 [1] CRAN (R 4.2.3)
#>  lifecycle      1.0.3      2022-10-07 [1] CRAN (R 4.2.3)
#>  listenv        0.9.0      2022-12-16 [1] CRAN (R 4.2.3)
#>  lubridate    * 1.9.2      2023-02-10 [1] CRAN (R 4.2.3)
#>  magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.2.3)
#>  MASS           7.3-58.3   2023-03-07 [4] CRAN (R 4.2.3)
#>  Matrix         1.5-1      2022-09-13 [4] CRAN (R 4.2.1)
#>  modeldata    * 1.1.0      2023-01-25 [1] CRAN (R 4.2.3)
#>  munsell        0.5.0      2018-06-12 [1] CRAN (R 4.2.3)
#>  nnet           7.3-18     2022-09-28 [4] CRAN (R 4.2.1)
#>  parallelly     1.35.0     2023-03-23 [1] CRAN (R 4.2.3)
#>  parsnip      * 1.1.0      2023-04-12 [1] CRAN (R 4.2.3)
#>  pillar         1.9.0      2023-03-22 [1] CRAN (R 4.2.3)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.2.3)
#>  prodlim        2023.03.31 2023-04-02 [1] CRAN (R 4.2.3)
#>  purrr        * 1.0.1      2023-01-10 [1] CRAN (R 4.2.3)
#>  R6             2.5.1      2021-08-19 [1] CRAN (R 4.2.3)
#>  Rcpp           1.0.10     2023-01-22 [1] CRAN (R 4.2.3)
#>  readr        * 2.1.4      2023-02-10 [1] CRAN (R 4.2.3)
#>  recipes      * 1.0.5      2023-02-20 [1] CRAN (R 4.2.3)
#>  reprex         2.0.2      2022-08-17 [1] CRAN (R 4.2.3)
#>  rlang          1.1.0      2023-03-14 [1] CRAN (R 4.2.3)
#>  rmarkdown      2.21       2023-03-26 [1] CRAN (R 4.2.3)
#>  rpart          4.1.19     2022-10-21 [4] CRAN (R 4.2.1)
#>  rsample      * 1.1.1      2022-12-07 [1] CRAN (R 4.2.3)
#>  rstudioapi     0.14       2022-08-22 [1] CRAN (R 4.2.3)
#>  scales       * 1.2.1      2022-08-20 [1] CRAN (R 4.2.3)
#>  sessioninfo    1.2.2      2021-12-06 [1] CRAN (R 4.2.3)
#>  stringi        1.7.12     2023-01-11 [1] CRAN (R 4.2.3)
#>  stringr      * 1.5.0      2022-12-02 [1] CRAN (R 4.2.3)
#>  survival     * 3.5-3      2023-02-12 [4] CRAN (R 4.2.2)
#>  tibble       * 3.2.1      2023-03-20 [1] CRAN (R 4.2.3)
#>  tidymodels   * 1.0.0      2022-07-13 [1] CRAN (R 4.2.3)
#>  tidyr        * 1.3.0      2023-01-24 [1] CRAN (R 4.2.3)
#>  tidyselect     1.2.0      2022-10-10 [1] CRAN (R 4.2.3)
#>  tidyverse    * 2.0.0      2023-02-22 [1] CRAN (R 4.2.3)
#>  timechange     0.2.0      2023-01-11 [1] CRAN (R 4.2.3)
#>  timeDate       4022.108   2023-01-07 [1] CRAN (R 4.2.3)
#>  tune         * 1.1.1      2023-04-11 [1] CRAN (R 4.2.3)
#>  tzdb           0.3.0      2022-03-28 [1] CRAN (R 4.2.3)
#>  utf8           1.2.3      2023-01-31 [1] CRAN (R 4.2.3)
#>  vctrs          0.6.1      2023-03-22 [1] CRAN (R 4.2.3)
#>  withr          2.5.0      2022-03-03 [1] CRAN (R 4.2.3)
#>  workflows    * 1.1.3      2023-02-22 [1] CRAN (R 4.2.3)
#>  workflowsets * 1.0.1      2023-04-06 [1] CRAN (R 4.2.3)
#>  xfun           0.38       2023-03-24 [1] CRAN (R 4.2.3)
#>  xgboost        1.7.5.1    2023-03-30 [1] CRAN (R 4.2.3)
#>  yaml           2.3.7      2023-01-23 [1] CRAN (R 4.2.3)
#>  yardstick    * 1.1.0      2022-09-07 [1] CRAN (R 4.2.3)
#> 
#>  [1] /home/bcarlin/R/x86_64-pc-linux-gnu-library/4.2
#>  [2] /usr/local/lib/R/site-library
#>  [3] /usr/lib/R/site-library
#>  [4] /usr/lib/R/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant