In [1]:
import numpy as np

from lifelines.datasets import load_kidney_transplant

from scipy.optimize import minimize

from sksurv.datasets import load_whas500
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.util import Surv

import numba as nb
from numba import njit


In [2]:
# Load the kidney-transplant data and pull out the model inputs as float64 arrays:
# the covariate matrix, the follow-up time and the death indicator.
df = load_kidney_transplant()

X = df[['age','black_male','white_male','black_female']].to_numpy(dtype=np.float64)
time = df['time'].to_numpy(dtype=np.float64)
event = df['death'].to_numpy(dtype=np.float64)

def normalize(X):
    """Standardise ``X`` along axis 0 (subtract the mean, divide by the std).

    Parameters
    ----------
    X : array-like
        Values to standardise; statistics are computed along axis 0
        (per column for a 2-D input).

    Returns
    -------
    numpy.ndarray
        Array of the same shape as ``X`` with zero mean along axis 0 and,
        for every non-constant column, unit (population) standard deviation.

    Notes
    -----
    A constant column has a standard deviation of 0. Dividing by it would
    produce NaN (0/0) that silently propagates into everything downstream,
    so such columns are scaled by 1 instead and come out as all zeros.
    """
    X = np.asarray(X)
    centered = X - X.mean(axis=0)
    scale = X.std(axis=0)
    # Guard against division by zero for zero-variance columns.
    safe_scale = np.where(scale == 0, 1.0, scale)
    return centered / safe_scale

# Standardise the covariates, then map every sample to the position of its
# time among the sorted distinct times.
X = normalize(X)

unique_times, time_return_inverse = np.unique(time, return_inverse=True)
n_unique_times = unique_times.shape[0]

In [3]:
def reverse_cumsum(a):
    """Return the suffix sums of ``a``: element ``i`` is ``a[i] + a[i+1] + ...``.

    The input is reversed, accumulated front to back, and the running totals
    are reversed again so they line up with the original positions.
    """
    reversed_values = np.flip(a)
    running_totals = np.cumsum(reversed_values)
    return running_totals[::-1]

# import xarray as xr

# def three_dimensional_groupby_sum(array,by):
#     result = xr.DataArray(array).groupby(xr.DataArray(by)).sum()
#     return result.values[result.group.values]

@njit(
    nb.types.Array(nb.float64, 3, 'A', False, aligned=True)(
        nb.types.Array(nb.float64, 3, 'A', False, aligned=True),
        nb.types.Array(nb.int64, 1, 'C', False, aligned=True),
        nb.int64,
    )
)
def three_dimensional_groupby_sum(array,by,n_unique_times):
    """Group-wise sum of the leading-axis slices of a 3-D array.

    Slice ``array[i]`` is added into ``result[by[i]]``. The result has
    ``n_unique_times`` slices along its first axis and the same trailing
    shape as ``array``; groups that receive no slice stay at zero.

    The explicit numba signature compiles the function eagerly for a float64
    3-D ``array``, an int64 C-contiguous 1-D ``by`` and an int64 group count.
    """
    n_rows = by.shape[0]
    totals = np.zeros((n_unique_times, array.shape[1], array.shape[2]))

    # Accumulate in input order so the floating-point summation order is fixed.
    for row in range(n_rows):
        totals[by[row]] += array[row]

    return totals

In [4]:
def get_n_log_likelihood_loss_jacobian_hessian(weights,X,event,time_return_inverse,n_unique_times):
    """Negative Breslow partial log-likelihood of a Cox model, with its gradient and Hessian.

    Parameters
    ----------
    weights : ndarray, shape (n_features,)
        Current coefficient vector.
    X : ndarray, shape (n_samples, n_features)
        Covariate matrix.
    event : ndarray, shape (n_samples,)
        Event indicator used as a per-sample multiplier (1 = event, 0 = censored).
    time_return_inverse : ndarray of int, shape (n_samples,)
        For every sample, the index of its time among the distinct times sorted
        in ascending order (the ``return_inverse`` output of ``np.unique``).
    n_unique_times : int
        Number of distinct times.

    Returns
    -------
    loss : float
        Negative partial log-likelihood.
    jacobian : ndarray, shape (n_features,)
        Gradient of ``loss`` with respect to ``weights``.
    hessian : ndarray, shape (n_features, n_features)
        Second derivative of ``loss`` with respect to ``weights``.
    """

    def at_risk_total(per_time_sum):
        # Suffix sum along the time axis ONLY (axis 0): entry k becomes the
        # total over all distinct times >= time k, i.e. over the risk set.
        # The result is then expanded back to one row per sample.
        suffix = np.flip(np.cumsum(np.flip(per_time_sum, axis=0), axis=0), axis=0)
        return suffix[time_return_inverse]

    p = np.dot(X, weights)
    p_exp = np.exp(p)

    # Zeroth moment: sum of exp(x.w) over each sample's risk set.
    risk_set = at_risk_total(
        np.bincount(time_return_inverse, weights=p_exp, minlength=n_unique_times)
    )

    loss = -np.sum(event * (p - np.log(risk_set)))

    # First moment: sum of x * exp(x.w) over each sample's risk set.
    XxXb = X * p_exp[:, np.newaxis]
    XxXb_at_time = np.stack(
        [
            np.bincount(time_return_inverse, weights=column, minlength=n_unique_times)
            for column in XxXb.T
        ],
        axis=1,
    )
    XxXb_at_Xt_at_index = at_risk_total(XxXb_at_time)

    jacobian = -np.sum(
        event[:, np.newaxis] * (X - XxXb_at_Xt_at_index / risk_set[:, np.newaxis]),
        axis=0,
    )

    # Second moment: sum of x x^T * exp(x.w) over each sample's risk set.
    X2xXb = np.einsum("ij,ik,i->ijk", X, X, p_exp)
    X2xXb_at_time = three_dimensional_groupby_sum(X2xXb, time_return_inverse, n_unique_times)
    X2xXb_at_Xt_at_index = at_risk_total(X2xXb_at_time)

    # Per-sample weighted covariance of the covariates over the risk set:
    # E[x x^T] - E[x] E[x]^T.
    second_moment = X2xXb_at_Xt_at_index / risk_set[:, None, None]
    mean_outer = (
        XxXb_at_Xt_at_index[:, :, None] * XxXb_at_Xt_at_index[:, None, :]
    ) / (risk_set ** 2)[:, None, None]

    hessian = np.sum(event[:, None, None] * (second_moment - mean_outer), axis=0)

    return loss, jacobian, hessian
    

In [5]:
def train_weights_for_cox_ph_breslow(X,event,n_unique_times,time_return_inverse, max_itterations = 100, loss_jacobian_hessian_function=None, tol=1e-9):
    """Fit Cox proportional-hazards coefficients by Newton-Raphson with step-halving.

    Following Dr. Breheny's notes
    (https://myweb.uiowa.edu/pbreheny/7210/f15/notes/10-27.pdf), a Newton step
    that makes the loss worse is not accepted: the candidate is moved halfway
    back towards the best point found so far and re-evaluated, repeatedly if
    necessary. R's survival package reportedly does the same.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
        Covariate matrix; only ``X.shape[1]`` is used here, the array itself is
        forwarded to the loss function.
    event, n_unique_times, time_return_inverse
        Forwarded unchanged to the loss function.
    max_itterations : int, default 100
        Maximum number of loss evaluations (the spelling is kept for
        backward compatibility with keyword callers).
    loss_jacobian_hessian_function : callable, optional
        Called as ``f(weights, X, event, time_return_inverse, n_unique_times)``
        and must return ``(loss, jacobian, hessian)``. Defaults to
        ``get_n_log_likelihood_loss_jacobian_hessian``, looked up at call time.
    tol : float, default 1e-9
        Convergence tolerance on the change of the loss between the best point
        and a full Newton step from it, relative to ``max(|loss|, 1)``.

    Returns
    -------
    ndarray, shape (n_features,)
        The evaluated weights with the lowest loss. Weights that were never
        evaluated, or that produced a worse loss, are never returned.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the Hessian is singular.
    """
    if loss_jacobian_hessian_function is None:
        # Late binding: picks up the current definition of the default loss
        # instead of whatever it was when this function was defined.
        loss_jacobian_hessian_function = get_n_log_likelihood_loss_jacobian_hessian

    best_weights = np.zeros(X.shape[1])
    best_loss = np.inf
    candidate = best_weights
    halving = False

    for _ in range(max_itterations):
        loss, jacobian, hessian = loss_jacobian_hessian_function(candidate,X,event,time_return_inverse,n_unique_times)

        # Convergence is only meaningful after a full Newton step: while
        # halving, the candidate approaches best_weights by construction, so a
        # small change in loss says nothing about optimality.
        if not halving and abs(best_loss - loss) <= tol * max(abs(loss), 1.0):
            if loss <= best_loss:
                best_weights = candidate
            break

        if loss < best_loss:
            # Accept the candidate and propose a full Newton step from it.
            # solve() avoids forming the explicit inverse of the Hessian.
            best_weights, best_loss = candidate, loss
            halving = False
            candidate = best_weights - np.linalg.solve(hessian, jacobian)
        else:
            # Worse (or NaN) loss: discard the candidate and retry halfway
            # between it and the best point found so far.
            halving = True
            candidate = 0.5 * (candidate + best_weights)

    return best_weights


In [6]:
# Fit the Cox model with the hand-written Newton-Raphson solver and show the coefficients.
hand_trained_cox = train_weights_for_cox_ph_breslow(
    X=X,
    event=event,
    n_unique_times=n_unique_times,
    time_return_inverse=time_return_inverse,
)
print(hand_trained_cox)

[ 0.68441213 -0.01099881  0.04867061  0.1145396 ]


In [9]:
# Reference fit with scikit-survival, using the same Breslow handling of ties.
# from_arrays is a static method, so it is called on the class directly; the
# keywords make the (event, time) argument order explicit.
y_sur = Surv.from_arrays(event=event, time=time)
sscox_coef = CoxPHSurvivalAnalysis(ties='breslow').fit(X, y_sur).coef_
print(sscox_coef)

[ 0.68441213 -0.01099881  0.04867061  0.1145396 ]


In [10]:
# The hand-trained weights should agree with scikit-survival's to 7 decimals;
# once they match, let's do a speed test.
np.testing.assert_almost_equal(hand_trained_cox, sscox_coef, decimal=7)

In [14]:
# Time one full fit of the hand-written solver on the prepared arrays.
%timeit train_weights_for_cox_ph_breslow(X,event,n_unique_times,time_return_inverse)

3.84 ms ± 293 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [15]:
# Time the scikit-survival fit; the timed statement includes constructing the estimator.
%timeit  CoxPHSurvivalAnalysis(ties='breslow').fit(X ,y_sur)

58.3 ms ± 317 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
