In [1]:
import numpy as np
import xarray as xr

from lifelines.datasets import load_kidney_transplant

from dask.distributed import Client, LocalCluster
import dask.array as da


In [2]:
df = load_kidney_transplant()

# let us consider two covariates
cols = ["afb", "mitype"]

#df = df[cols].rename(columns={cols[0]: "v1", cols[1]: "v2"}).astype(float)
X_np = df[['black_male','white_male','black_female']].to_numpy().astype(np.float64)
time_np = df['time'].to_numpy().astype(np.float64)
event_np = df['death'].to_numpy().astype(np.float64)


In [3]:
cluster = LocalCluster()
client = Client(cluster)

In [4]:
X = da.from_array(X_np,chunks=(100,-1)).persist()
time = da.from_array(time_np,chunks=(100)).persist()
event = da.from_array(event_np,chunks=(100)).persist()

unique_times, time_return_inverse =  da.unique(time,return_inverse=True)
unique_times = unique_times.compute()
n_unique_times = len(unique_times)

In [5]:
weights = np.zeros(X.shape[1])

In [6]:
def reverse_cumsum(a):
    return da.flip(da.cumsum(da.flip(a)))


p = da.dot(X,weights)
p_exp= da.exp(p)
risk_set = reverse_cumsum(da.bincount(time_return_inverse,weights= p_exp,minlength=len(unique_times)))[time_return_inverse]
loss = - np.sum(event * (p - np.log(risk_set)))


In [7]:
loss.compute() #the loss is correct

879.1848811180897

In [8]:
XxXb = np.multiply(X,p_exp[:,np.newaxis])
XxXb_at_Xt_at_time = xr.DataArray(XxXb).groupby(xr.DataArray(time_return_inverse)).sum()
XxXb_at_Xt_at_time_cumsum = da.apply_along_axis(reverse_cumsum,0,XxXb_at_Xt_at_time)
#this bottom line is the issue, as dask is not properlly indexing/tacking as a numpy array would
#XxXb_at_Xt_at_index = XxXb_at_Xt_at_time_cumsum[time_return_inverse]
XxXb_at_Xt_at_index = da.take(XxXb_at_Xt_at_time_cumsum,time_return_inverse,axis=0)
#XxXb_at_Xt_at_index = da.apply_along_axis(lambda a: a.take(time_return_inverse),0,XxXb_at_Xt_at_time_cumsum)


In [9]:
jacobian = -da.sum(event[:,np.newaxis] * (X - XxXb_at_Xt_at_index/risk_set[:,np.newaxis]),axis=0)


In [10]:
correct_jacobian = np.array([ 0.51671724, -3.7450042 , -5.16354462])

In [11]:
jacobian_c = jacobian.compute()

#jacobian is incorrrect :(
print(jacobian_c,correct_jacobian)

[ -0.84453129 -10.58125143  -6.02887283] [ 0.51671724 -3.7450042  -5.16354462]


In [12]:
#here I demonstrate that the issue is with indexing XxXb_at_Xt_at_time_cumsum by time_return_inverse

def reverse_cumsum_np(a):
    return np.flip(np.cumsum(np.flip(a)))

XxXb_np = XxXb.compute()

unique_times_np, time_return_inverse_np =  np.unique(time_np,return_inverse=True)
n_unique_times_np = len(unique_times)


XxXb_at_Xt_at_time_np = np.apply_along_axis(lambda a: np.bincount(time_return_inverse_np,weights=a,minlength=n_unique_times_np),0,XxXb_np)
XxXb_at_Xt_at_time_cumsum_np = np.apply_along_axis(reverse_cumsum_np,0,XxXb_at_Xt_at_time_np)
XxXb_at_Xt_at_index_np = XxXb_at_Xt_at_time_cumsum_np[time_return_inverse_np]


In [13]:
#XxXb_at_Xt_at_time_cumsum matches

np.testing.assert_almost_equal( XxXb_at_Xt_at_time_cumsum_np, XxXb_at_Xt_at_time_cumsum.compute() )

#but XxXb_at_Xt_at_index does not match, the issue is with a bug with dask indexing/take

print( XxXb_at_Xt_at_index_np,XxXb_at_Xt_at_index.compute())


[[ 92. 432.  59.]
 [ 92. 431.  59.]
 [ 92. 430.  59.]
 ...
 [  3.  15.   3.]
 [  3.  11.   2.]
 [  1.   5.   1.]] [[ 92. 432.  59.]
 [ 92. 432.  59.]
 [ 92. 432.  59.]
 ...
 [  1.   3.   0.]
 [  1.   3.   0.]
 [  1.   3.   0.]]


In [14]:
#on to the hessian matrix
def three_dimensional_groupby_sum(array,by):
    result = xr.DataArray(array).groupby(xr.DataArray(by)).sum()
    index_id = list(result.indexes.dims)[0]
    index = result.indexes[index_id]
    return da.from_array(result)[index]

def _rechunk_for_3d_cumsum(a,element_size,rows_per_chunk=100):
    return  a.rechunk(rows_per_chunk,element_size[0],element_size[0])

def dask_3d_cumsum(a,element_size):
    return  da.reductions.cumreduction(np.add.accumulate,lambda a, b: a+b,np.zeros(element_size),axis=0,dtype=np.float64,x=a) 

In [15]:
X2xXb = np.einsum("ij,ik,i->ijk", X, X, p_exp)
#X2xXb_at_time = three_dimensional_groupby_sum(X2xXb,time_return_inverse)
X2xXb_at_time = three_dimensional_groupby_sum(X2xXb,time_return_inverse)

element_size = (X.shape[1],X.shape[1])
X2xXb_at_time = _rechunk_for_3d_cumsum(X2xXb_at_time,element_size)
X2Xb_at_Xt_at_index = da.flip(dask_3d_cumsum(da.flip( X2xXb_at_time),element_size))[time_return_inverse]

a = X2Xb_at_Xt_at_index/risk_set[:,None,None]
b = np.matmul(XxXb_at_Xt_at_index[:,:,None], XxXb_at_Xt_at_index[:,None,:])/(risk_set**2)[:,None,None]
c = a - b

hessian = np.sum(event[:,None,None] * c,axis=0)


In [16]:
#the hessian is also incorrect, :(
#the corrct values would have been
# array([[13.00513315, -7.18493076, -0.92061414],
#        [-7.18493076, 34.95695067, -4.37961393],
#        [-0.92061414, -4.37961393,  8.27053441]])
#T belive this is also due to the dask sorting issue that I had with the jacobian

hessian.compute()

array([[12.78620278, -8.22528016, -1.05192869],
       [-8.22528016, 30.10371835, -4.99783531],
       [-1.05192869, -4.99783531,  8.19314371]])