Explain how to interpret output of .predict() in API doc #15
(I also posted this as a question on Stack Overflow: https://stackoverflow.com/q/47274356/1870832 )
I'm confused about how to interpret the output of `.predict` from a fitted `CoxnetSurvivalAnalysis` model in scikit-survival. I've read through the notebook Intro to Survival Analysis in scikit-survival and the API reference, but can't find an explanation. Below is a minimal example of what leads to my confusion:
So here's the `X` going into the model:
...moving on to fitting the model and generating predictions:
`preds` has the same number of records as `X`, but its values are way different from the values in `data_y`, even when predicting on the same data the model was fit on.
So what exactly are `preds`? Clearly `.predict` means something quite different here than in scikit-learn, but I can't figure out what. The API reference says it returns "The predicted decision function," but what does that mean? And how do I get the predicted estimate in months (`yhat`) for a given `X`? I'm new to survival analysis, so I'm obviously missing something.
It returns a type of risk score: a higher value means a higher risk of the event (event indicator = True).
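To make "risk score" concrete for a Cox-type model, here is a minimal plain-Python sketch, assuming the score is the linear predictor eta = x · beta (my assumption for illustration; the coefficients and samples below are made up). Under proportional hazards the ratio of two samples' hazards is exp(eta_a - eta_b) at every time point, so only differences between scores carry meaning.

```python
import math


def linear_predictor(x, coef):
    """Risk score of a Cox-type model: eta = sum_k x_k * beta_k (assumed form)."""
    return sum(x_k * b_k for x_k, b_k in zip(x, coef))


def hazard_ratio(x_a, x_b, coef):
    """h(t | x_a) / h(t | x_b) under proportional hazards; the same for every t."""
    return math.exp(linear_predictor(x_a, coef) - linear_predictor(x_b, coef))


# Hypothetical fitted coefficients and two hypothetical samples.
coef = [0.8, -0.5]
patient_a = [1.0, 0.0]
patient_b = [0.0, 1.0]

print(linear_predictor(patient_a, coef))  # 0.8
print(linear_predictor(patient_b, coef))  # -0.5
print(round(hazard_ratio(patient_a, patient_b, coef), 3))  # 3.669
```

Here `patient_a` has the higher score, so under the model its hazard is about 3.7 times that of `patient_b` at every time: it is expected to experience the event sooner, but the score alone says nothing about *when*.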
Recall that with censored data, your
You were probably looking for a predicted time. You can get the predicted survival function with `predict_survival_function`.
From that, an alternative prediction could be a median predicted survival time.
EDIT: Actually, I’m trying to extract this but it’s been a bit of a pain to munge. @sebp do you have thoughts on extracting a median predicted survival time?
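One way to extract a median from a predicted step survival curve is to take the first time point at which S(t) drops to 0.5 or below. A minimal sketch, assuming the curve is given as two parallel sequences of time points and survival probabilities; `median_survival_time` is a hypothetical helper, and the attribute names in the commented scikit-survival usage are my assumption, so check them against your installed version.

```python
def median_survival_time(times, surv_probs):
    """Smallest time t with S(t) <= 0.5 for a right-continuous step survival curve.

    times: increasing time points where the curve changes value.
    surv_probs: S(t) from times[i] up to (not including) times[i + 1].
    Returns None if the curve never reaches 0.5 (median not identifiable).
    """
    for t, s in zip(times, surv_probs):
        if s <= 0.5:
            return t
    return None


# Hypothetical predicted curve for one sample (time in months, probabilities).
times = [2, 5, 8, 12, 20]
surv_probs = [0.95, 0.80, 0.55, 0.40, 0.30]
print(median_survival_time(times, surv_probs))  # 12

# A low-risk sample whose curve stays above 0.5 during follow-up.
print(median_survival_time([2, 5, 8], [0.97, 0.90, 0.81]))  # None

# Assumed scikit-survival usage (not run here; the step functions are assumed
# to expose time points as .x and probabilities as .y):
# for fn in model.predict_survival_function(X):
#     print(median_survival_time(fn.x, fn.y))
```

If the predicted curve never drops to 0.5 within the observed follow-up (common for low-risk samples), the median is undefined, and the sketch returns `None` rather than extrapolating.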
It seems the API documentation needs to be improved to explain how to interpret predictions.
The semantics of predictions in survival models are very different from traditional machine learning models, such as those from scikit-learn. As @pavopax explained, predictions are risk scores on an arbitrary scale, which means you can usually only determine the sequence of events, but not their exact time.
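Because only the ordering of the scores matters, such models are usually evaluated with a rank statistic like Harrell's concordance index rather than an error in months. Below is a minimal plain-Python sketch of Harrell's C with made-up data (for real use, scikit-survival provides `concordance_index_censored` in `sksurv.metrics`). It also shows that a strictly increasing transform of the scores leaves the index unchanged, which is what "arbitrary scale" means in practice.

```python
import math


def harrell_c_index(events, times, risk):
    """Harrell's concordance index for right-censored data.

    A pair (i, j) is comparable when sample i had an observed event before
    time j (times[i] < times[j]); it is concordant when risk[i] > risk[j].
    Tied risk scores count as half-concordant. Pairs with tied times are
    skipped in this simplified sketch.
    """
    concordant = 0.0
    comparable = 0
    for t_i, e_i, r_i in zip(times, events, risk):
        if not e_i:
            continue  # a censored sample cannot be the earlier member of a pair
        for t_j, r_j in zip(times, risk):
            if t_i < t_j:
                comparable += 1
                if r_i > r_j:
                    concordant += 1.0
                elif r_i == r_j:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# Hypothetical data: event indicator, observed time (months), predicted risk score.
events = [True, False, True, True, False]
times = [5, 8, 12, 3, 20]
risk = [1.2, 1.5, 0.4, 2.0, -0.1]

print(harrell_c_index(events, times, risk))  # 0.875

# A strictly increasing transform keeps the ranking, and hence C, unchanged.
rescaled = [10 * math.exp(r) + 3 for r in risk]
print(harrell_c_index(events, times, rescaled))  # 0.875
```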
The main reason for this is that you would need to specify a parametric distribution of the survival function when estimating the probability of survival at a given time. This step is not necessary when using Cox's proportional hazards model, because of its proportional hazards assumption. With Cox's model, you can use `predict_survival_function` or `predict_cumulative_hazard_function`.
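Here is a minimal plain-Python sketch of why Cox's model can give a survival curve without a parametric distribution: the baseline cumulative hazard H0(t) is estimated nonparametrically (Breslow's estimator here, one common choice), and proportional hazards turns a sample's linear predictor eta into S(t | x) = exp(-H0(t) * exp(eta)). The function names and data are hypothetical; this is an illustration, not scikit-survival's implementation.

```python
import math


def breslow_baseline_cumhaz(events, times, eta):
    """Breslow estimate of the baseline cumulative hazard H0(t).

    Returns a list of (event_time, H0(event_time)) pairs, one per distinct
    observed event time. eta holds each training sample's linear predictor.
    """
    steps = []
    cumhaz = 0.0
    for t_k in sorted({t for t, e in zip(times, events) if e}):
        d_k = sum(1 for t, e in zip(times, events) if e and t == t_k)  # events at t_k
        # Sum of relative hazards over the risk set (still under observation at t_k).
        at_risk = sum(math.exp(h) for t, h in zip(times, eta) if t >= t_k)
        cumhaz += d_k / at_risk
        steps.append((t_k, cumhaz))
    return steps


def cox_survival(steps, eta_new, t):
    """S(t | x) = exp(-H0(t) * exp(eta_new)) under proportional hazards."""
    h0 = 0.0
    for t_k, h_k in steps:
        if t_k > t:
            break
        h0 = h_k
    return math.exp(-h0 * math.exp(eta_new))


# Hypothetical training data with linear predictors from an already fitted Cox model.
events = [True, True, False, True, False]
times = [2, 4, 5, 7, 9]
eta = [0.5, -0.2, 0.1, 0.0, -0.4]

steps = breslow_baseline_cumhaz(events, times, eta)
for eta_new in (-0.5, 0.0, 0.5):
    curve = [round(cox_survival(steps, eta_new, t), 3) for t in (1, 3, 6, 8)]
    # All curves equal 1.0 before the first event; afterwards a higher eta_new is lower.
    print(eta_new, curve)
```

Every sample shares the same baseline shape H0(t); the risk score only scales it, which is exactly the proportional hazards assumption doing the work a parametric distribution would otherwise do.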
Median survival time has essentially the same problem, because it is defined as the 50th percentile of a distribution, which needs to be specified.
If you are using a model other than Cox's model and still want to get a predicted survival curve, a work-around would be to use the predicted risk score as a feature in Cox's model and get a predicted survival curve this way. Obviously, you can't avoid the proportional hazards assumption this way. Other ways are possible, e.g. Van Belle et al. fit a monotone least squares support vector regression model to estimate the survival function from predicted scores.
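The first work-around can be sketched in plain Python: treat the other model's score s as the only covariate of a Cox model, h(t | s) = h0(t) * exp(beta * s), and estimate beta by maximizing Cox's partial likelihood with Newton-Raphson (Breslow handling of ties). This is my own illustration with made-up data and hypothetical function names; in practice you would pass the score as a single feature column to `CoxPHSurvivalAnalysis`.

```python
import math


def score_and_information(events, times, scores, beta):
    """Gradient and negative second derivative of Cox's log partial likelihood
    (Breslow ties) for the one-covariate model h(t | s) = h0(t) * exp(beta * s)."""
    grad, info = 0.0, 0.0
    for t_i, e_i, s_i in zip(times, events, scores):
        if not e_i:
            continue
        # Risk set at t_i: every sample still under observation.
        risk_set = [(s_j, math.exp(beta * s_j))
                    for t_j, s_j in zip(times, scores) if t_j >= t_i]
        s0 = sum(w for _, w in risk_set)
        s1 = sum(s * w for s, w in risk_set)
        s2 = sum(s * s * w for s, w in risk_set)
        mean = s1 / s0
        grad += s_i - mean
        info += s2 / s0 - mean * mean  # weighted variance of the score in the risk set
    return grad, info


def fit_cox_on_score(events, times, scores, max_iter=50, tol=1e-10):
    """Newton-Raphson estimate of beta, starting from beta = 0."""
    beta = 0.0
    for _ in range(max_iter):
        grad, info = score_and_information(events, times, scores, beta)
        if info <= 0.0:
            raise ValueError("partial likelihood is flat; beta is not identifiable")
        step = grad / info
        beta += step
        if abs(step) < tol:
            break
    return beta


# Hypothetical data: risk scores from some other model, plus observed outcomes.
scores = [2.0, 1.0, 1.5, 0.5, -0.5, -1.0]
times = [3, 5, 4, 9, 6, 10]
events = [True, True, False, True, True, False]

beta = fit_cox_on_score(events, times, scores)
print(round(beta, 3))
```

Once beta is fitted, eta = beta * s is an ordinary Cox linear predictor, so a baseline hazard estimator such as Breslow's yields a survival curve per sample. If every observed event has the highest score in its risk set, the partial likelihood keeps increasing in beta and no finite estimate exists, so this only works when the ranking is imperfect.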