# Comparing various regularizations in scikit-survival

Summary:
- There seem to be 2 different implementation of Ridge between CoxPH and Coxnet


References
- https://scikit-survival.readthedocs.io/en/stable/index.html
- https://scikit-survival.readthedocs.io/en/stable/user_guide/coxnet.html

In [1]:
import pandas as pd
import numpy as np


from sksurv.datasets import load_breast_cancer
from sksurv.linear_model import CoxPHSurvivalAnalysis, CoxnetSurvivalAnalysis
# from sksurv.preprocessing import OneHotEncoder
from sklearn.preprocessing import OneHotEncoder

from sklearn.compose import make_column_transformer
from sklearn.compose import make_column_selector
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

In [59]:
del(X)

In [60]:
X_raw, y = load_breast_cancer()

X_raw.shape, y.shape

((198, 80), (198,))

In [61]:
X_raw.head()

Unnamed: 0,X200726_at,X200965_s_at,X201068_s_at,X201091_s_at,X201288_at,X201368_at,X201663_s_at,X201664_at,X202239_at,X202240_at,...,X221344_at,X221634_at,X221816_s_at,X221882_s_at,X221916_at,X221928_at,age,er,grade,size
0,10.926361,8.962608,11.630078,10.964107,11.518305,12.038527,9.623518,9.814798,10.016732,7.847383,...,6.313081,7.842097,10.132635,10.926365,6.477749,5.991885,57.0,negative,poorly differentiated,3.0
1,12.24209,9.531718,12.626106,11.594716,12.317659,10.776911,10.604577,10.704329,10.161838,8.744875,...,5.126765,8.780328,10.213467,9.555092,4.96805,7.05113,57.0,positive,poorly differentiated,3.0
2,11.661716,10.23868,12.572919,9.166088,11.698658,11.353333,9.384927,10.161654,10.032721,8.125487,...,6.936022,7.855649,10.164514,9.308048,4.283777,6.828986,48.0,negative,poorly differentiated,2.5
3,12.174021,9.819279,12.109888,9.086937,13.132617,11.859394,8.400839,8.670721,10.727427,8.65081,...,6.787297,6.678375,10.660092,10.208241,5.713404,6.927251,42.0,positive,poorly differentiated,1.8
4,11.484011,11.489233,11.779285,8.887616,10.429663,11.401139,7.741092,8.642018,9.556686,8.478862,...,7.312287,7.358556,11.57033,10.931843,5.817265,6.655448,46.0,positive,intermediate,3.0


In [67]:
# sparse category 'unk(n)own'
X_raw.grade.value_counts()

intermediate             83
poorly differentiated    83
well differentiated      30
unkown                    2
Name: grade, dtype: int64

In [66]:
X_raw.er.value_counts()

positive    134
negative     64
Name: er, dtype: int64

# Preproc

In [68]:
# use fewer columns so that non-regularized model still fits
# selected_columns = list(X.columns[:2].append(X.columns[76:]))
# selected_columns = X.columns[76:]

In [92]:
X.loc[:, X_raw.dtypes != 'category']

Unnamed: 0,X200726_at,X200965_s_at,X201068_s_at,X201091_s_at,X201288_at,X201368_at,X201663_s_at,X201664_at,X202239_at,X202240_at,...,X221028_s_at,X221241_s_at,X221344_at,X221634_at,X221816_s_at,X221882_s_at,X221916_at,X221928_at,age,size
0,10.926361,8.962608,11.630078,10.964107,11.518305,12.038527,9.623518,9.814798,10.016732,7.847383,...,7.129340,5.649573,6.313081,7.842097,10.132635,10.926365,6.477749,5.991885,57.0,3.0
1,12.242090,9.531718,12.626106,11.594716,12.317659,10.776911,10.604577,10.704329,10.161838,8.744875,...,7.189642,7.599788,5.126765,8.780328,10.213467,9.555092,4.968050,7.051130,57.0,3.0
2,11.661716,10.238680,12.572919,9.166088,11.698658,11.353333,9.384927,10.161654,10.032721,8.125487,...,7.222765,4.987613,6.936022,7.855649,10.164514,9.308048,4.283777,6.828986,48.0,2.5
3,12.174021,9.819279,12.109888,9.086937,13.132617,11.859394,8.400839,8.670721,10.727427,8.650810,...,6.584748,7.205051,6.787297,6.678375,10.660092,10.208241,5.713404,6.927251,42.0,1.8
4,11.484011,11.489233,11.779285,8.887616,10.429663,11.401139,7.741092,8.642018,9.556686,8.478862,...,8.052990,6.973316,7.312287,7.358556,11.570330,10.931843,5.817265,6.655448,46.0,3.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
193,12.018292,8.323876,11.955274,10.740020,11.150428,10.650873,8.787549,9.747182,10.176306,9.240307,...,8.055945,8.542533,6.354728,9.240358,10.074577,9.604668,6.889179,5.892040,39.0,2.2
194,11.711415,10.428482,12.420877,11.145993,11.084685,11.169750,10.870530,11.128882,9.573702,9.287522,...,8.391159,6.457199,6.481257,7.645481,11.289674,11.067820,5.393927,6.421038,46.0,3.2
195,11.939616,9.615587,11.962812,10.463171,11.514539,11.487394,10.443569,11.104227,9.051649,7.279063,...,7.029631,7.138855,6.750251,7.642915,11.490971,7.743568,5.922925,6.979894,47.0,2.5
196,11.848449,10.528911,11.318453,8.609631,13.719035,12.909814,7.525994,8.255546,9.788903,7.343499,...,7.570143,8.603733,7.134177,7.536181,11.092211,9.298178,7.278383,7.195797,43.0,1.2


In [70]:
X = X_raw.drop('grade', axis=1)

# MODEL

In [93]:
INNER_CV = 5
OUTER_CV = 5
RANDOM_STATE = 99

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=1/OUTER_CV,
                                                    shuffle=True,
                                                    random_state=RANDOM_STATE)
ohe = make_column_transformer(
    (OneHotEncoder(sparse=False, drop='first', handle_unknown='error'),
     make_column_selector(dtype_include='category')),
    remainder='passthrough'
)

In [94]:
X_train_t = ohe.fit_transform(X_train)
X_test_t = ohe.transform(X_test)

In [116]:
cph = CoxPHSurvivalAnalysis(alpha=0) # default
cph.fit(X_train_t, y_train)

  loss -= (numerator - n_events * numpy.log(risk_set)) / n_samples
  z = risk_set_x / risk_set
  a = risk_set_xx / risk_set


ValueError: search direction contains NaN or infinite values

In [138]:
TOL = 1e-7

In [139]:
# (tol is not the same as Coxnet)
cph = CoxPHSurvivalAnalysis(alpha=0.01, tol=TOL)
cph.fit(X_train_t, y_train)

cph.coef_.flatten().round(2)

array([-16.62,   6.25,   1.5 ,   2.57,  -4.97,  -7.3 ,  -8.37,  -7.33,
        11.4 ,   3.16,   4.75,  -4.53,  -9.11, -11.08, -10.01,  -2.27,
         4.44,  -0.77,  -3.71,  -2.9 ,   4.61,  -1.51,   4.48,  -0.55,
         2.3 ,  -0.2 ,  -1.34,   6.54,   5.7 ,   2.8 ,   0.7 ,   5.9 ,
         5.84,  -2.35,   5.17,   2.74,  -1.16, -13.23,  -1.49,  -0.71,
        -3.6 ,  -4.41,   6.21,   3.73,   1.88,  -3.05,  -1.19, -11.49,
        -3.5 ,  -0.72,  -3.22,  -2.93,   1.15,  -4.27,  -1.64,   7.41,
         1.6 ,  -1.95,  -2.34,  -2.03,   5.57,  -5.53,   2.47,  17.86,
        -2.05,   1.35,  -4.4 ,  -2.26,   5.96,   2.34,   2.31,   1.45,
        -5.46,  -5.09,   0.53,  -2.11,  -5.46,   0.21,   1.88])

In [141]:
cph = CoxnetSurvivalAnalysis(alphas=[0.01], l1_ratio=.000001, tol=TOL)
cph.fit(X_train_t, y_train)

cph.coef_.flatten().round(2)

  cph.fit(X_train_t, y_train)


array([ -2.34,  -0.65,  -8.3 ,  -1.53,   1.68,  -9.94,  -9.79,  -0.04,
        -0.86,  -4.02,  11.04,   2.66, -10.92,  -7.64,  -6.41, -14.23,
        -5.91,   2.86,   0.48,  23.24,   0.53,  -3.3 ,  -0.7 ,   1.62,
         4.57,   0.49,  -9.61,  10.8 ,  24.47,  -5.22,  -0.63,   7.09,
         2.42,  -7.25,  -4.97,  -1.52,   7.43,   3.1 ,  -4.69,   7.08,
        -2.35,   1.91,   0.65,   2.09,   4.89,  -3.64,   3.46,  -5.23,
        -5.16, -16.19,   6.79,  -1.95,   6.79,   7.91,   0.89, -10.15,
        -6.29,  -1.16,  -3.07,  -1.93,   5.82,   0.84,   7.88,   1.57,
        -0.04,   2.62,   1.46, -11.14,   5.57,   6.4 , -12.35,   1.62,
         2.58,  -7.36,   8.56,  -3.33,  -4.54,  -6.44,   7.57])