Explain how to interpret output of .predict() in API doc #15

Closed
MaxPowerWasTaken opened this Issue Nov 14, 2017 · 4 comments

MaxPowerWasTaken commented Nov 14, 2017

(I also posted this as a question on Stack Overflow: https://stackoverflow.com/q/47274356/1870832 )

I'm confused how to interpret the output of .predict from a fitted CoxnetSurvivalAnalysis model in scikit-survival. I've read through the notebook Intro to Survival Analysis in scikit-survival and the API reference, but can't find an explanation. Below is a minimal example of what leads to my confusion:

import pandas as pd
from sksurv.datasets import load_veterans_lung_cancer
from sksurv.linear_model import CoxnetSurvivalAnalysis

# load data
data_X, data_y = load_veterans_lung_cancer()

# one-hot-encode categorical columns in X
categorical_cols = ['Celltype', 'Prior_therapy', 'Treatment']

X = data_X.copy()
for c in categorical_cols:
    dummy_matrix = pd.get_dummies(X[c], prefix=c, drop_first=False)
    X = pd.concat([X, dummy_matrix], axis=1).drop(c, axis=1)

# display final X to fit Cox Elastic Net model on
del data_X
print(X.head(3))

so here's the X going into the model:

   Age_in_years  Celltype  Karnofsky_score  Months_from_Diagnosis  \
0          69.0  squamous             60.0                    7.0   
1          64.0  squamous             70.0                    5.0   
2          38.0  squamous             60.0                    3.0   

  Prior_therapy Treatment  
0            no  standard  
1           yes  standard  
2            no  standard  

...moving on to fitting model and generating predictions:

# Fit model (the variable name must match the one used in the calls below)
coxnet = CoxnetSurvivalAnalysis()
coxnet.fit(X, data_y)

# What are these predictions?
preds = coxnet.predict(X)

preds has the same number of records as X, but its values are wayyy different from the values in data_y, even when predicted on the same data the model was fit on.

print(preds.mean()) 
print(data_y['Survival_in_days'].mean())

output:

-0.044114643249153422
121.62773722627738

So what exactly are preds? Clearly .predict means something pretty different here than in scikit-learn, but I can't figure out what. The API Reference says it returns "The predicted decision function," but what does that mean? And how do I get to the predicted estimate in days yhat for a given X? I'm new to survival analysis so I'm obviously missing something.

Contributor

pavopax commented Nov 14, 2017

AFAIK:

It returns a type of risk score. Higher value means higher risk of your event (class value = True).

https://scikit-survival.readthedocs.io/en/latest/generated/sksurv.linear_model.CoxPHSurvivalAnalysis.html#sksurv.linear_model.CoxPHSurvivalAnalysis

Recall that with censored data, your y is not just a continuous value (time) but also an indicator (censor) and risk is a summary of these two components.

You were probably looking for a predicted time. You can get the predicted survival function with estimator.predict_survival_function as in the example 00 notebook.

From that, an alternative prediction could be a median predicted survival time.
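One way to read a median off a predicted survival curve, as a minimal sketch: treat the curve as a step function and take the first time point at which the survival probability is at or below 0.5. The helper `median_survival_time` and the toy numbers are hypothetical and not part of scikit-survival. I believe the step functions returned by `predict_survival_function` expose their time points and values as `.x` and `.y`, which is what you would pass in, but treat that as an assumption to check against your installed version.

```python
def median_survival_time(times, surv_probs):
    """Return the first time at which S(t) <= 0.5, or None if never reached.

    `times` must be sorted in ascending order; `surv_probs[i]` is S(times[i]).
    """
    for t, s in zip(times, surv_probs):
        if s <= 0.5:
            return t
    return None  # the curve stays above 0.5, so no median can be read off


# Toy step function (made-up numbers, not model output)
times = [10, 30, 60, 120, 250]
surv = [0.95, 0.80, 0.55, 0.40, 0.10]

median = median_survival_time(times, surv)

# A curve that never drops to 0.5, e.g. under heavy censoring
flat = median_survival_time(times, [0.90, 0.80, 0.70, 0.65, 0.60])
```

Returning `None` in the second case is deliberate: when the predicted curve never reaches 0.5 within the observed time range, a median survival time cannot be reported.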

EDIT: Actually, I’m trying to extract this but it’s been a bit of a pain to munge. @sebp do you have thoughts on extracting a median predicted survival time?

Owner

sebp commented Nov 15, 2017

It seems the API documentation needs to be improved to explain how to interpret predictions.

The semantics of predictions in survival models are very different from traditional machine learning models, such as those from scikit-learn. As @pavopax explained, predictions are risk scores on an arbitrary scale, which means you can usually only determine the sequence of events, but not their exact time.
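To illustrate why only the ordering of risk scores carries information, here is a minimal pure-Python sketch of Harrell's concordance index on made-up data. The function `concordance_index` is a hypothetical, simplified helper that skips pairs with tied times; scikit-survival ships its own `concordance_index_censored` in `sksurv.metrics`. Rescaling the scores with any increasing transformation leaves the result unchanged.

```python
def concordance_index(event, time, risk):
    """Fraction of comparable pairs in which the higher risk score
    belongs to the sample with the shorter observed time."""
    concordant = 0.0
    comparable = 0
    n = len(time)
    for i in range(n):
        if not event[i]:
            continue  # the earlier sample of a pair must have an observed event
        for j in range(n):
            if time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5  # tied scores count as half
    return concordant / comparable


# Made-up data: event indicator, observed time, predicted risk score
event = [True, True, False, True]
time = [5, 10, 12, 20]
risk = [2.0, 0.5, 0.7, -1.0]

c = concordance_index(event, time, risk)

# Same ordering on a completely different scale
rescaled = [100.0 * r + 7.0 for r in risk]
c_rescaled = concordance_index(event, time, rescaled)
```

Both calls give the same value, which is why the absolute magnitude of `preds` (for instance a mean close to zero) says nothing about survival time in days.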

The main reason is that estimating the probability of survival at a given time generally requires specifying a parametric distribution for the survival function. This step is not necessary for Cox's proportional hazards model, because of its proportional hazards assumption: with that model, you can use predict_survival_function or predict_cumulative_hazard_function.
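As a sketch of how proportional hazards turns a risk score into a survival curve: an individual's survival function is the baseline survival function raised to the power of the exponentiated risk score, S(t | x) = S0(t) ** exp(risk score). The baseline values below are made up for illustration and `individual_survival` is a hypothetical helper; in a fitted model the baseline would be estimated from the training data, for example with Breslow's estimator.

```python
import math


def individual_survival(baseline_surv, risk_score):
    """S(t | x) = S0(t) ** exp(risk_score) under proportional hazards."""
    hazard_ratio = math.exp(risk_score)
    return [s0 ** hazard_ratio for s0 in baseline_surv]


# Made-up baseline survival probabilities at increasing time points
baseline = [0.9, 0.7, 0.5, 0.2]

low_risk = individual_survival(baseline, -1.0)   # curve lies above the baseline
reference = individual_survival(baseline, 0.0)   # identical to the baseline
high_risk = individual_survival(baseline, 1.0)   # curve lies below the baseline
```

The risk score only shifts the whole curve up or down relative to the baseline, so without the baseline the score alone cannot be converted into a time.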

Median survival time has essentially the same problem, because it is defined as the 50th percentile of a distribution, which needs to be specified.

If you are using a model other than Cox's model and still want to get a predicted survival curve, a work-around would be to use the predicted risk score as a feature in Cox's model and get a predicted survival curve this way. Obviously, you can't avoid the proportional hazards assumption this way. Other ways are possible, e.g. Van Belle et al. fit a monotone least squares support vector regression model to estimate the survival function from predicted scores.

@sebp sebp changed the title from How to interpret output of .predict() from fitted scikit-survival model in python? to Explain how to interpret output of .predict() in API doc Nov 15, 2017

Owner

sebp commented Nov 19, 2017

I added more explanations to the documentation, which I hope answer your questions. Please have a look and let me know if parts are unclear or important information is missing.

MaxPowerWasTaken commented Nov 20, 2017

Hey, thanks a lot sebp. Looking over this now. (Just a heads up, it looks like some of the MathJax didn't render as intended.)

And thanks a lot as well to pavopax for your earlier explanation.
