-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
multi_predict()
for coxnet models
#70
Conversation
so that we don't need to loop over penalty values in `multi_predict()`
plus a bit of streamlining
R/proportional_hazards.R
Outdated
if (type != "linear_pred"){ | ||
pred <- predict(object, new_data = new_data, type = type, ..., | ||
penalty = penalty, multi = TRUE) | ||
} else { | ||
pred <- predict(object, new_data = new_data, type = "raw", | ||
opts = dots, penalty = penalty, multi = TRUE) | ||
|
||
# post-processing into nested tibble |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are currently predict methods for types "linear_pred"
and "survival"
. For the "linear_pred"
-type predictions, this follows what parsnip does for linear_reg() with a glmnet engine. For the survival probabilities, we use the survival curves from survfit()
and have the wrapper survival_prob_coxnet()
already so I extended that one to be able to deal with a vector of penalties. This also allows for convenient minimal nesting where we only group according to strata, see also #47 and #63.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've reviewed to the best of my ability (more on the technical R side than the statistical side), so it will definitely be nice to have Emil review as well! But looks great!
+1 for adding PR comments to your own PR, it was really nice to have somewhere to add extra relevant comments
R/proportional_hazards.R
Outdated
if (type != "linear_pred"){ | ||
pred <- predict(object, new_data = new_data, type = type, ..., | ||
penalty = penalty, multi = TRUE) | ||
} else { | ||
pred <- predict(object, new_data = new_data, type = "raw", | ||
opts = dots, penalty = penalty, multi = TRUE) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since this else statement is pretty long, you might consider making two helpers for the specific predict types and doing:
switch(
type,
linear_pred = multi_predict_coxnet_linear_pred(...),
survival = multi_predict_coxnet_survival(...),
abort("Internal error: Unknown `type`.")
)
That would also make it easier to extend if we get more types
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've moved the code for linear_pred
into its helper function 👍 Regarding the switch statement: I think the other types may require a similar structure as the survival probabilities so hopefully the rest is just predict(type = type)
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Everything looks correct to my eyes, well done!
I had a small note about possible performance improvements, but that can also be solved later on.
This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue. |
This PR implements
multi_predict()
for coxnet models, for types"survival"
and"linear_pred"
.It's the missing part to #33.
Created on 2021-06-28 by the reprex package (v2.0.0)