multi_predict() for coxnet models #70

Merged
merged 15 commits into from
Jul 7, 2021

@hfrick hfrick commented Jun 29, 2021

This PR implements multi_predict() for coxnet models, for types "survival" and "linear_pred".

It's the missing part to #33.

library(censored)
#> Loading required package: parsnip
library(survival)

spec_g <- proportional_hazards(penalty = 0.123) %>% 
  set_engine("glmnet")

fit_g <- fit(spec_g, 
             Surv(stop, event) ~ rx + size + number + strata(enum),
             data = bladder)

new_data_3 <- bladder[1:3, ]

mp_s <- multi_predict(fit_g, new_data_3, penalty = c(0.05, 0.1),
                    type = "survival", time = c(5, 10))
mp_s
#> # A tibble: 3 x 1
#>   .pred           
#>   <list>          
#> 1 <tibble [4 × 3]>
#> 2 <tibble [4 × 3]>
#> 3 <tibble [4 × 3]>
tidyr::unnest(mp_s, cols = .pred)
#> # A tibble: 12 x 3
#>    penalty .time .pred_survival
#>      <dbl> <dbl>          <dbl>
#>  1    0.05     5          0.727
#>  2    0.1      5          0.721
#>  3    0.05    10          0.628
#>  4    0.1     10          0.623
#>  5    0.05     5          0.989
#>  6    0.1      5          0.989
#>  7    0.05    10          0.931
#>  8    0.1     10          0.929
#>  9    0.05     5          0.989
#> 10    0.1      5          0.988
#> 11    0.05    10          0.966
#> 12    0.1     10          0.964

mp_lp <- multi_predict(fit_g, new_data_3, penalty = c(0.05, 0.1),
                       type = "linear_pred")
mp_lp
#> # A tibble: 3 x 1
#>   .pred           
#>   <list>          
#> 1 <tibble [2 × 2]>
#> 2 <tibble [2 × 2]>
#> 3 <tibble [2 × 2]>
tidyr::unnest(mp_lp, cols = .pred)
#> # A tibble: 6 x 2
#>   penalty .pred_linear_pred
#>     <dbl>             <dbl>
#> 1    0.05           -0.0701
#> 2    0.1             0.0461
#> 3    0.05           -0.0701
#> 4    0.1             0.0461
#> 5    0.05           -0.0701
#> 6    0.1             0.0461

Created on 2021-06-28 by the reprex package (v2.0.0)
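
A note on the output shape above: each nested tibble in mp_s has 4 rows, presumably because the two penalties are crossed with the two time points. A minimal base R sketch of that grid follows; it is an illustration only, not the package's internal code, and `grid` is a hypothetical name.

```r
# Cross the penalty values with the evaluation times used in the call above.
grid <- expand.grid(penalty = c(0.05, 0.1), time = c(5, 10))

# One row per (penalty, time) pair; penalty varies fastest,
# which matches the row order of the unnested survival output.
nrow(grid)  # 4
```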

Comment on lines 241 to 248
if (type != "linear_pred"){
  pred <- predict(object, new_data = new_data, type = type, ...,
                  penalty = penalty, multi = TRUE)
} else {
  pred <- predict(object, new_data = new_data, type = "raw",
                  opts = dots, penalty = penalty, multi = TRUE)

  # post-processing into nested tibble
Member Author

There are currently predict methods for types "linear_pred" and "survival". For "linear_pred" predictions, this follows what parsnip does for linear_reg() with the glmnet engine. For survival probabilities, we use the survival curves from survfit() and already have the wrapper survival_prob_coxnet(), so I extended it to handle a vector of penalties. This also allows for convenient minimal nesting, where we only group by strata; see also #47 and #63.
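
To make the "post-processing into nested tibble" step concrete, here is a minimal sketch under stated assumptions. It assumes the raw prediction is a matrix with one row per observation and one column per penalty (the shape glmnet returns for several penalty values), and it uses base R data frames as a stand-in for tibbles. The names `raw`, `long`, `pieces` and `nested` are hypothetical, and the numbers are copied from the reprex output above; this is not the code in the PR.

```r
# Stand-in for the raw prediction matrix: rows are observations,
# columns are penalty values (values copied from the reprex output).
penalty <- c(0.05, 0.1)
raw <- matrix(
  c(-0.0701, -0.0701, -0.0701,
     0.0461,  0.0461,  0.0461),
  nrow = 3,
  dimnames = list(NULL, c("s1", "s2"))
)

# Long form: one row per (observation, penalty) pair.
# as.vector() reads the matrix column by column, so penalty repeats per column.
long <- data.frame(
  .row = rep(seq_len(nrow(raw)), times = ncol(raw)),
  penalty = rep(penalty, each = nrow(raw)),
  .pred_linear_pred = as.vector(raw)
)

# Nest: one small data frame per observation, stored in a list-column.
pieces <- split(long[, c("penalty", ".pred_linear_pred")], long$.row)
pieces <- lapply(pieces, function(d) {
  rownames(d) <- NULL
  d
})
nested <- data.frame(.row = seq_len(nrow(raw)))
nested$.pred <- unname(pieces)
```

Each element of `nested$.pred` then holds two rows, one per penalty, mirroring the 2 × 2 tibbles shown for mp_lp.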

Member

@DavisVaughan DavisVaughan left a comment

I've reviewed to the best of my ability (more on the technical R side than the statistical side), so it will definitely be nice to have Emil review as well! But looks great!

+1 for adding PR comments to your own PR, it was really nice to have somewhere to add extra relevant comments

Comment on lines 265 to 270
if (type != "linear_pred"){
  pred <- predict(object, new_data = new_data, type = type, ...,
                  penalty = penalty, multi = TRUE)
} else {
  pred <- predict(object, new_data = new_data, type = "raw",
                  opts = dots, penalty = penalty, multi = TRUE)
Member

Since this else statement is pretty long, you might consider making two helpers for the specific predict types and doing:

switch(
  type,
  linear_pred = multi_predict_coxnet_linear_pred(...),
  survival = multi_predict_coxnet_survival(...),
  abort("Internal error: Unknown `type`.")
)

That would also make it easier to extend if we get more types.
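
A runnable sketch of that dispatch pattern is below, with several assumptions: the two helpers are stubs that only report which branch ran, `dispatch_multi_predict()` is a hypothetical wrapper name, and base `stop()` stands in for `abort()` so the example needs no extra packages.

```r
# Hypothetical stubs: the real helpers would call predict() and post-process.
multi_predict_coxnet_linear_pred <- function(...) "linear_pred helper called"
multi_predict_coxnet_survival <- function(...) "survival helper called"

# Hypothetical dispatcher; the unnamed last argument of switch() is the
# fallback used when `type` matches none of the named branches.
dispatch_multi_predict <- function(type, ...) {
  switch(
    type,
    linear_pred = multi_predict_coxnet_linear_pred(...),
    survival = multi_predict_coxnet_survival(...),
    stop("Internal error: Unknown `type`.", call. = FALSE)
  )
}

dispatch_multi_predict("survival")
```

Supporting another prediction type would then mean adding one named branch and one helper, leaving the existing branches untouched.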

Member Author

I've moved the code for linear_pred into its own helper function 👍 Regarding the switch statement: I think the other types may require a structure similar to the survival probabilities, so hopefully the rest is just predict(type = type).

@hfrick hfrick mentioned this pull request Jul 1, 2021
Member

@EmilHvitfeldt EmilHvitfeldt left a comment

Everything looks correct to my eyes, well done!
I had a small note about possible performance improvements, but that can also be solved later on.

@hfrick hfrick merged commit 253533f into master Jul 7, 2021
@hfrick hfrick deleted the coxnet-multi-predict branch July 7, 2021 09:58
github-actions bot commented Nov 5, 2021

This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

@github-actions github-actions bot locked and limited conversation to collaborators Nov 5, 2021