Skip to content

Commit

Permalink
ENH More efficient B.dot and B.T.dot in Cox datafit (#168)
Browse files Browse the repository at this point in the history
  • Loading branch information
Badr-MOUFAD committed Jun 20, 2023
1 parent e7048b6 commit 189d21e
Showing 1 changed file with 79 additions and 41 deletions.
120 changes: 79 additions & 41 deletions skglm/datafits/single_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,19 +561,24 @@ class Cox(BaseDatafit):
Attributes
----------
B : array-like, shape (n_samples, n_samples)
Matrix where every ``(i, j)`` entry (row, column) equals ``1``
if ``tm[j] >= tm[i]`` and ``0`` otherwise. This matrix is initialized
using the ``.initialize`` method.
T_indices : array-like, shape (n_samples,)
Indices of observations with the same occurrence times stacked horizontally as
``[group_1, group_2, ...]`` in ascending order. It is initialized
with the ``.initialize`` method (or ``initialize_sparse`` for sparse ``X``).
H_indices : array-like, shape (n_samples,)
Indices of observations with the same occurrence times stacked horizontally
as ``[group_1, group_2, ...]``. This array is initialized
when calling ``.initialize`` method when ``use_efron=True``.
T_indptr : array-like, shape (len(np.unique(tm)) + 1,)
Array where two consecutive elements delimit a group of
observations having the same occurrence time.
H_indptr : array-like, (np.unique(tm) + 1,)
Array where two consecutive elements delimits a group of observations
having the same occurrence times.
H_indices : array-like, shape (n_samples,)
Indices of uncensored observations with the same occurrence times stacked
horizontally as ``[group_1, group_2, ...]`` in ascending order.
It is initialized when calling the ``.initialize`` method
(or ``initialize_sparse`` for sparse ``X``) when ``use_efron=True``.
H_indptr : array-like, shape (len(np.unique(tm[s != 0])) + 1,)
Array where two consecutive elements delimit a group of uncensored
observations having the same occurrence time.
"""

def __init__(self, use_efron=False):
Expand All @@ -582,9 +587,8 @@ def __init__(self, use_efron=False):
def get_spec(self):
return (
('use_efron', bool_),
('B', float64[:, ::1]),
('H_indptr', int64[:]),
('H_indices', int64[:]),
('T_indptr', int64[:]), ('T_indices', int64[:]),
('H_indptr', int64[:]), ('H_indices', int64[:]),
)

def params_to_dict(self):
Expand All @@ -597,7 +601,7 @@ def value(self, y, w, Xw):

# compute inside log term
exp_Xw = np.exp(Xw)
B_exp_Xw = self.B @ exp_Xw
B_exp_Xw = self._B_dot_vec(exp_Xw)
if self.use_efron:
B_exp_Xw -= self._A_dot_vec(exp_Xw)

Expand All @@ -614,12 +618,12 @@ def raw_grad(self, y, Xw):
n_samples = Xw.shape[0]

exp_Xw = np.exp(Xw)
B_exp_Xw = self.B @ exp_Xw
B_exp_Xw = self._B_dot_vec(exp_Xw)
if self.use_efron:
B_exp_Xw -= self._A_dot_vec(exp_Xw)

s_over_B_exp_Xw = s / B_exp_Xw
out = -s + exp_Xw * (self.B.T @ (s_over_B_exp_Xw))
out = -s + exp_Xw * self._B_T_dot_vec(s_over_B_exp_Xw)
if self.use_efron:
out -= exp_Xw * self._AT_dot_vec(s_over_B_exp_Xw)

Expand All @@ -635,12 +639,12 @@ def raw_hessian(self, y, Xw):
n_samples = Xw.shape[0]

exp_Xw = np.exp(Xw)
B_exp_Xw = self.B @ exp_Xw
B_exp_Xw = self._B_dot_vec(exp_Xw)
if self.use_efron:
B_exp_Xw -= self._A_dot_vec(exp_Xw)

s_over_B_exp_Xw = s / B_exp_Xw
out = exp_Xw * (self.B.T @ (s_over_B_exp_Xw))
out = exp_Xw * self._B_T_dot_vec(s_over_B_exp_Xw)
if self.use_efron:
out -= exp_Xw * self._AT_dot_vec(s_over_B_exp_Xw)

Expand All @@ -654,38 +658,53 @@ def initialize(self, X, y):
"""Initialize the datafit attributes."""
tm, s = y

tm_as_col = tm.reshape((-1, 1))
self.B = (tm >= tm_as_col).astype(X.dtype)
self.T_indices = np.argsort(tm)
self.T_indptr = self._get_indptr(tm, self.T_indices)

if self.use_efron:
H_indices = np.argsort(tm)
# filter out censored data
H_indices = H_indices[s[H_indices] != 0]
n_uncensored_samples = H_indices.shape[0]

# build H_indptr
H_indptr = [0]
count = 1
for i in range(1, n_uncensored_samples):
if tm[H_indices[i-1]] == tm[H_indices[i]]:
count += 1
else:
H_indptr.append(count + H_indptr[-1])
count = 1
H_indptr.append(n_uncensored_samples)
H_indptr = np.asarray(H_indptr, dtype=np.int64)

# save in instance
self.H_indptr = H_indptr
self.H_indices = H_indices
self.H_indices = self.T_indices[s[self.T_indices] != 0]
self.H_indptr = self._get_indptr(tm, self.H_indices)

def initialize_sparse(self, X_data, X_indptr, X_indices, y):
    """Initialize the datafit attributes in sparse dataset case.

    Delegates to ``initialize``, forwarding ``X_data`` in place of ``X``
    together with ``y``. ``X_indptr`` and ``X_indices`` are accepted but
    not used by this method.
    """
    # `initialize_sparse` and `initialize` have the same implementation
    # small hack to avoid repetitive code: pass in X_data as only its dtype is used
    self.initialize(X_data, y)

def _B_dot_vec(self, vec):
    """Compute ``B @ vec`` in O(n) instead of O(n^2), without forming ``B``.

    Relies on ``self.T_indices`` / ``self.T_indptr``: consecutive entries of
    ``T_indptr`` delimit, inside ``T_indices``, a group of observations
    sharing the same occurrence time, groups being stored in ascending order
    of time. Every member of a group receives the sum of ``vec`` over its own
    group and all later groups.

    Parameters
    ----------
    vec : array, shape (n_samples,)
        Vector to multiply.

    Returns
    -------
    array, same shape and dtype as ``vec``
        The product ``B @ vec``.
    """
    result = np.zeros_like(vec)
    n_groups = self.T_indptr.shape[0] - 1
    running_total = 0.

    # Walk the groups from the latest time to the earliest and accumulate.
    # The alternative (start from the full sum and subtract coordinates of
    # vec) subtracts big numbers, which results in 'cancellation errors' and
    # hence erroneous results. Ref. J Nocedal, "Numerical optimization",
    # page 615
    for group in range(n_groups - 1, -1, -1):
        start = self.T_indptr[group]
        stop = self.T_indptr[group + 1]
        members = self.T_indices[start:stop]

        running_total += vec[members].sum()
        result[members] = running_total

    return result

def _B_T_dot_vec(self, vec):
    """Compute ``B.T @ vec`` in O(n) instead of O(n^2), without forming ``B``.

    Relies on ``self.T_indices`` / ``self.T_indptr``: consecutive entries of
    ``T_indptr`` delimit, inside ``T_indices``, a group of observations
    sharing the same occurrence time, groups being stored in ascending order
    of time. Every member of a group receives the sum of ``vec`` over its own
    group and all earlier groups.

    Parameters
    ----------
    vec : array, shape (n_samples,)
        Vector to multiply.

    Returns
    -------
    array, same shape and dtype as ``vec``
        The product ``B.T @ vec``.
    """
    result = np.zeros_like(vec)
    n_groups = self.T_indptr.shape[0] - 1
    running_total = 0.

    # forward pass: accumulate from the earliest time to the latest
    for group in range(n_groups):
        start = self.T_indptr[group]
        stop = self.T_indptr[group + 1]
        members = self.T_indices[start:stop]

        running_total += vec[members].sum()
        result[members] = running_total

    return result

def _A_dot_vec(self, vec):
# compute `A @ vec` in O(n) instead of O(n^2)
out = np.zeros_like(vec)
n_H = self.H_indptr.shape[0] - 1

Expand All @@ -700,6 +719,7 @@ def _A_dot_vec(self, vec):
return out

def _AT_dot_vec(self, vec):
# compute `A.T @ vec` in O(n) instead of O(n^2)
out = np.zeros_like(vec)
n_H = self.H_indptr.shape[0] - 1

Expand All @@ -712,3 +732,21 @@ def _AT_dot_vec(self, vec):
out[current_H_idx] = weighted_sum_vec_H * np.ones(size_current_H)

return out

def _get_indptr(self, vals, indices):
    """Build the group delimiters of ``indices`` according to ``vals``.

    Given ``indices`` that order ``vals`` (e.g. ``indices = argsort(vals)``,
    possibly filtered), build an array ``indptr`` where two consecutive
    elements delimit a run of ``indices`` pointing to the same value.

    Parameters
    ----------
    vals : array
        Values the groups are formed on.

    indices : array of int, shape (n_indices,)
        Indices into ``vals`` such that equal values are contiguous.

    Returns
    -------
    indptr : array of int64, shape (n_groups + 1,)
        Group ``k`` is ``indices[indptr[k]:indptr[k+1]]``. For empty
        ``indices`` there is no group and ``indptr`` is ``[0]``.
    """
    n_indices = indices.shape[0]

    indptr = [0]
    # a new group starts at every position whose value differs from
    # the value at the previous position
    for i in range(1, n_indices):
        if vals[indices[i - 1]] != vals[indices[i]]:
            indptr.append(i)

    # close the last group; skipped when there are no indices so that
    # no phantom empty group is reported
    if n_indices > 0:
        indptr.append(n_indices)

    return np.asarray(indptr, dtype=np.int64)

0 comments on commit 189d21e

Please sign in to comment.