
ENH - Implement Cox with Efron estimate #159

Merged
merged 52 commits, Jun 8, 2023
Changes from 47 commits

Commits
f2b2ae7
init commit
Badr-MOUFAD May 17, 2023
5e4fa49
implem datafit
Badr-MOUFAD May 23, 2023
2d72326
fix numba errors && unittest datafit
Badr-MOUFAD May 23, 2023
1e8ea39
normalize df with ``n_samples``
Badr-MOUFAD May 23, 2023
85d3766
unittest Cox Estimator against lifeline
Badr-MOUFAD May 23, 2023
6a37aaa
debug script
Badr-MOUFAD May 23, 2023
7a95dca
avoid ties
Badr-MOUFAD May 23, 2023
1e7083e
finding 0 or very small solution even for reg < 1
mathurinm May 24, 2023
98999d1
fix unittest: agree with ``lifeline``
Badr-MOUFAD May 24, 2023
2541596
require ``lifelines`` in CI tests
Badr-MOUFAD May 24, 2023
8e44a48
Cox docs
Badr-MOUFAD May 24, 2023
24c20e2
more on docs cox
Badr-MOUFAD May 24, 2023
22f8461
normalize as external param
Badr-MOUFAD May 24, 2023
e8b1c1f
dummy survival data docs
Badr-MOUFAD May 24, 2023
676e3ab
fix pydoctest
Badr-MOUFAD May 24, 2023
8fb0063
faster matmul && fix lifelines install in CI
Badr-MOUFAD May 24, 2023
77dfa65
preserve support of ``numba v0.56``
Badr-MOUFAD May 24, 2023
ce0631f
make script debug reproducible
mathurinm May 25, 2023
088818b
illustrate convergence failure of lifelines
mathurinm May 25, 2023
30e47e3
fix
mathurinm May 25, 2023
4fe40fa
clean up
Badr-MOUFAD May 27, 2023
1fc5e4f
add support of sparse data
Badr-MOUFAD May 27, 2023
8c3b657
Merge branch 'main' of https://github.com/scikit-learn-contrib/skglm …
Badr-MOUFAD May 27, 2023
308790b
use Weibull for ``tm``
Badr-MOUFAD May 27, 2023
2aac108
unittest Cox sparse data
Badr-MOUFAD May 27, 2023
9d47207
clean ups
Badr-MOUFAD May 27, 2023
329ef37
setups efron
Badr-MOUFAD May 30, 2023
547c566
compute val
Badr-MOUFAD May 30, 2023
ca12fb4
implement grad and Hessian
Badr-MOUFAD May 30, 2023
0670baa
implement A and A.T dot ops
Badr-MOUFAD May 30, 2023
1e68ec0
filter out censored data
Badr-MOUFAD May 30, 2023
842dd66
fix shapes bugs
Badr-MOUFAD May 30, 2023
5db6552
fix dtype & fraction
Badr-MOUFAD May 30, 2023
f0139c6
unittest Efron datafit
Badr-MOUFAD May 30, 2023
a899855
unittest efron estimator
Badr-MOUFAD May 30, 2023
b9786c5
Merge branch 'main' of https://github.com/scikit-learn-contrib/skglm …
Badr-MOUFAD May 31, 2023
c44eb66
rm cox in api
Badr-MOUFAD May 31, 2023
1122c21
typo docs Breslow
Badr-MOUFAD May 31, 2023
233ec7f
Efron for sparse data
Badr-MOUFAD May 31, 2023
7d22b44
add argument ``with_ties`` in dummy data
Badr-MOUFAD May 31, 2023
5fb9486
sample data from weibull
Badr-MOUFAD May 31, 2023
e97b724
update docs
Badr-MOUFAD Jun 7, 2023
9e0a0cc
Merge branch 'main' of https://github.com/scikit-learn-contrib/skglm …
Badr-MOUFAD Jun 7, 2023
2efe6df
Merge branch 'main' of https://github.com/scikit-learn-contrib/skglm …
Badr-MOUFAD Jun 8, 2023
7149ac6
fix links to docs
Badr-MOUFAD Jun 8, 2023
3bd41d7
typos
Badr-MOUFAD Jun 8, 2023
70bb89c
example lifelines: data and compare sols
Badr-MOUFAD Jun 8, 2023
8a8eaab
example lifelines: speed up ratio
Badr-MOUFAD Jun 8, 2023
e2147d0
example lifelines: check ties
Badr-MOUFAD Jun 8, 2023
57d46ce
example lifelines: typos and reformulations
Badr-MOUFAD Jun 8, 2023
5bb4313
example lifelines: fix heading
Badr-MOUFAD Jun 8, 2023
8e095c6
fix format
Badr-MOUFAD Jun 8, 2023
67 changes: 65 additions & 2 deletions examples/plot_survival_analysis.py
@@ -195,11 +195,74 @@
ax.set_ylabel("objective suboptimality")
_ = ax.set_xlabel("time in seconds")



# %%
# According to the printed ratio, ``skglm`` reaches the same results as ``lifelines``
# more than 100 times faster!
speed_up = records["lifelines"]["times"][-1] / records["skglm"]["times"][-1]
print(f"speed up ratio: {speed_up:.0f}")

# %%
# Efron estimate
# ==============
#
# The previous results, namely the closeness of the solutions and the timings,
# carry over to handling tied observations with the Efron estimate.
#
# Let's start by generating data with tied observations. This is achieved
# by passing ``with_ties=True`` to the ``make_dummy_survival_data`` function.
tm, s, X = make_dummy_survival_data(
n_samples, n_features,
normalize=True,
with_ties=True,
random_state=0
)

# check data has tied observations
print("Number of unique times", len(np.unique(tm)))
print("Number of samples", n_samples)

# %%
# It is straightforward to fit an :math:`\ell_1` Cox estimator with the Efron estimate.
# We only need to pass in ``use_efron=True`` to the ``Cox`` datafit.

# ensure using Efron estimate
datafit.use_efron = True

# re-initialize the datafit to account for the Efron estimate and the new dataset
datafit.initialize(X, (tm, s))

# solve the problem
w_sk = solver.solve(X, (tm, s), datafit, penalty)[0]

# %%
# Again a relatively sparse solution is found:
print(
"Number of nonzero coefficients in solution: "
f"{(w_sk != 0).sum()} out of {len(w_sk)}."
)

# %%
# Let's do the same with ``lifelines`` and compare results

# format data
stacked_tm_s_X = np.hstack((tm[:, None], s[:, None], X))
df = pd.DataFrame(stacked_tm_s_X)

# fit lifelines estimator on the new data
lifelines_estimator = CoxPHFitter(penalizer=alpha, l1_ratio=1.).fit(
df,
duration_col=0,
event_col=1
)
w_ll = lifelines_estimator.params_.values

# Check that both solvers find solutions with the same objective value
obj_sk = datafit.value((tm, s), w_sk, X @ w_sk) + penalty.value(w_sk)
obj_ll = datafit.value((tm, s), w_ll, X @ w_ll) + penalty.value(w_ll)

print(f"Objective skglm: {obj_sk:.6f}")
print(f"Objective lifelines: {obj_ll:.6f}")
print(f"Difference: {(obj_sk - obj_ll):.2e}")

# Check that both solutions are close
print(f"Euclidean distance between solutions: {np.linalg.norm(w_sk - w_ll):.3e}")
139 changes: 99 additions & 40 deletions skglm/datafits/single_task.py
@@ -1,7 +1,7 @@
import numpy as np
from numpy.linalg import norm
from numba import njit
from numba import float64, int64, bool_

from skglm.datafits.base import BaseDatafit
from skglm.utils.sparse_ops import spectral_norm
@@ -547,90 +547,100 @@ def intercept_update_self(self, y, Xw):


class Cox(BaseDatafit):
r"""Cox datafit for survival analysis with Breslow estimate.
r"""Cox datafit for survival analysis.

The datafit reads [1]

.. math::

1 / n_"samples" \sum_(i=1)^(n_"samples") -s_i \langle x_i, w \rangle
+ \log (\sum_(j | y_j \geq y_i) e^{\langle x_i, w \rangle})

where :math:`s_i` indicates the sample censorship and :math:`tm`
is the vector recording the time of event occurrences.

Defining the matrix :math:`B` with
:math:`B_{i,j} = 1` if :math:`tm_j \geq tm_i` and :math:`0` otherwise,
the datafit can be rewritten in the following compact form

.. math::

1 / n_"samples" \langle s, Xw \rangle
+ 1 / n_"samples" \langle s, \log B e^{Xw} \rangle
Refer to :ref:`Mathematics behind Cox datafit <maths_cox_datafit>` for details.

Parameters
----------
use_efron : bool, default=False
If ``True``, use the Efron estimate to handle tied observations.

Attributes
----------
B : array-like, shape (n_samples, n_samples)
Matrix where every ``(i, j)`` entry (row, column) equals ``1``
if ``tm[j] >= tm[i]`` and ``0`` otherwise. This matrix is initialized
using the ``.initialize`` method.

H_indices : array-like, shape (n_samples,)
Indices of observations with the same occurrence times, stacked horizontally
as ``[group_1, group_2, ...]``. This array is initialized
when calling the ``.initialize`` method with ``use_efron=True``.

H_indptr : array-like, shape (np.unique(tm).size + 1,)
Array where two consecutive elements delimit a group of observations
sharing the same occurrence time.
"""

def __init__(self, use_efron=False):
self.use_efron = use_efron

def get_spec(self):
return (
('use_efron', bool_),
('B', float64[:, ::1]),
('H_indptr', int64[:]),
('H_indices', int64[:]),
)

def params_to_dict(self):
return dict(use_efron=self.use_efron)

def value(self, y, w, Xw):
"""Compute the value of the datafit."""
tm, s = y
n_samples = Xw.shape[0]

# compute inside log term
exp_Xw = np.exp(Xw)
B_exp_Xw = self.B @ exp_Xw
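# with the Efron estimate, each tied group discounts a fraction of its
# own risk from the risk set (see _A_dot_vec below)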
if self.use_efron:
B_exp_Xw -= self._A_dot_vec(exp_Xw)

out = -(s @ Xw) + s @ np.log(B_exp_Xw)
return out / n_samples

def raw_grad(self, y, Xw):
r"""Compute gradient of datafit w.r.t. ``Xw``.

Refer to :ref:`Mathematics behind Cox datafit <maths_cox_datafit>`
equation 4 for details.
"""
tm, s = y
n_samples = Xw.shape[0]

exp_Xw = np.exp(Xw)
B_exp_Xw = self.B @ exp_Xw
if self.use_efron:
B_exp_Xw -= self._A_dot_vec(exp_Xw)

s_over_B_exp_Xw = s / B_exp_Xw
out = -s + exp_Xw * (self.B.T @ (s_over_B_exp_Xw))
if self.use_efron:
out -= exp_Xw * self._AT_dot_vec(s_over_B_exp_Xw)

return out / n_samples

def raw_hessian(self, y, Xw):
"""Compute a diagonal upper bound of the datafit's Hessian w.r.t. ``Xw``.

Refer to :ref:`Mathematics behind Cox datafit <maths_cox_datafit>`
equation 6 for details.
"""
tm, s = y
n_samples = Xw.shape[0]

exp_Xw = np.exp(Xw)
B_exp_Xw = self.B @ exp_Xw
if self.use_efron:
B_exp_Xw -= self._A_dot_vec(exp_Xw)

s_over_B_exp_Xw = s / B_exp_Xw
out = exp_Xw * (self.B.T @ (s_over_B_exp_Xw))
if self.use_efron:
out -= exp_Xw * self._AT_dot_vec(s_over_B_exp_Xw)

return out / n_samples

def initialize(self, X, y):
@@ -640,9 +650,58 @@ def initialize(self, X, y):
tm_as_col = tm.reshape((-1, 1))
self.B = (tm >= tm_as_col).astype(X.dtype)

if self.use_efron:
H_indices = np.argsort(tm)
# filter out censored data
H_indices = H_indices[s[H_indices] != 0]
n_uncensored_samples = H_indices.shape[0]

# build H_indptr
H_indptr = [0]
count = 1
for i in range(1, n_uncensored_samples):
if tm[H_indices[i-1]] == tm[H_indices[i]]:
count += 1
else:
H_indptr.append(count + H_indptr[-1])
count = 1
H_indptr.append(n_uncensored_samples)
H_indptr = np.asarray(H_indptr, dtype=np.int64)

# save in instance
self.H_indptr = H_indptr
self.H_indices = H_indices

def initialize_sparse(self, X_data, X_indptr, X_indices, y):
"""Initialize the datafit attributes in sparse dataset case."""
# initialize_sparse and initialize have the same implementation
# small hack to avoid repetitive code: pass in X_data as only its dtype is used
self.initialize(X_data, y)

def _A_dot_vec(self, vec):
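# computes A @ vec: within a tied group H of size m, the l-th entry
# receives (l / m) * sum(vec[H]), the Efron fractional discount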
out = np.zeros_like(vec)
n_H = self.H_indptr.shape[0] - 1

for idx in range(n_H):
current_H_idx = self.H_indices[self.H_indptr[idx]: self.H_indptr[idx+1]]
size_current_H = current_H_idx.shape[0]
frac_range = np.arange(size_current_H, dtype=vec.dtype) / size_current_H

sum_vec_H = np.sum(vec[current_H_idx])
out[current_H_idx] = sum_vec_H * frac_range

return out

def _AT_dot_vec(self, vec):
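# computes A.T @ vec: every entry of a tied group H of size m receives
# the fraction-weighted group sum, sum over l of (l / m) * vec[H][l]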
out = np.zeros_like(vec)
n_H = self.H_indptr.shape[0] - 1

for idx in range(n_H):
current_H_idx = self.H_indices[self.H_indptr[idx]: self.H_indptr[idx+1]]
size_current_H = current_H_idx.shape[0]
frac_range = np.arange(size_current_H, dtype=vec.dtype) / size_current_H

weighted_sum_vec_H = vec[current_H_idx] @ frac_range
out[current_H_idx] = weighted_sum_vec_H * np.ones(size_current_H)

return out
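
To make the ``H_indptr`` / ``H_indices`` bookkeeping above concrete, here is a
minimal standalone sketch (an illustration added here, not code from the PR) that
mimics what ``initialize`` builds when ``use_efron=True``; the toy ``tm`` and ``s``
arrays are made up for the example.

import numpy as np

# toy data: occurrence times with ties, and censorship indicators (0 = censored)
tm = np.array([2., 1., 2., 3., 1., 2.])
s = np.array([1., 1., 1., 0., 1., 1.])

# sort observations by time and keep only the uncensored ones
H_indices = np.argsort(tm)
H_indices = H_indices[s[H_indices] != 0]

# H_indptr[k]:H_indptr[k+1] slices the k-th group of tied times in H_indices
H_indptr = [0]
count = 1
for i in range(1, len(H_indices)):
    if tm[H_indices[i - 1]] == tm[H_indices[i]]:
        count += 1
    else:
        H_indptr.append(H_indptr[-1] + count)
        count = 1
H_indptr.append(len(H_indices))

print(H_indices)  # [1 4 0 2 5] -> times [1., 1., 2., 2., 2.]
print(H_indptr)   # [0, 2, 5] -> one tied group of size 2 and one of size 3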
5 changes: 3 additions & 2 deletions skglm/tests/test_datafits.py
@@ -116,7 +116,8 @@ def test_gamma():
np.testing.assert_allclose(clf.coef_, gamma_results.params, rtol=1e-6)


@pytest.mark.parametrize("use_efron", [True, False])
def test_cox(use_efron):
rng = np.random.RandomState(1265)
n_samples, n_features = 10, 30

@@ -131,7 +132,7 @@ def test_cox():
Xw = X @ w

# check datafit
cox_df = compiled_clone(Cox(use_efron))

cox_df.initialize(X, (tm, s))
cox_df.value(y, w, Xw)
16 changes: 9 additions & 7 deletions skglm/tests/test_estimators.py
@@ -168,7 +168,8 @@ def test_mtl_path():
np.testing.assert_allclose(coef_ours, coef_sk, rtol=1e-5)


@pytest.mark.parametrize("use_efron", [True, False])
def test_CoxEstimator(use_efron):
try:
from lifelines import CoxPHFitter
except ModuleNotFoundError:
@@ -182,8 +183,8 @@ def test_CoxEstimator():
n_samples, n_features = 100, 30
random_state = 1265

tm, s, X = make_dummy_survival_data(n_samples, n_features, normalize=True,
with_ties=use_efron, random_state=random_state)

# compute alpha_max
B = (tm >= tm[:, None]).astype(X.dtype)
@@ -193,7 +194,7 @@ def test_CoxEstimator():
alpha = reg * alpha_max

# fit Cox using ProxNewton solver
datafit = compiled_clone(Cox(use_efron))
penalty = compiled_clone(L1(alpha))

datafit.initialize(X, (tm, s))
@@ -222,13 +223,14 @@
np.testing.assert_allclose(p_obj_skglm, p_obj_ll, atol=1e-6)


@pytest.mark.parametrize("use_efron", [True, False])
def test_CoxEstimator_sparse(use_efron):
reg = 1e-2
n_samples, n_features = 100, 30
X_density, random_state = 0.5, 1265

tm, s, X = make_dummy_survival_data(n_samples, n_features, X_density=X_density,
with_ties=use_efron, random_state=random_state)

# compute alpha_max
B = (tm >= tm[:, None]).astype(X.dtype)
@@ -238,7 +240,7 @@ def test_CoxEstimator_sparse():
alpha = reg * alpha_max

# fit Cox using ProxNewton solver
datafit = compiled_clone(Cox(use_efron))
penalty = compiled_clone(L1(alpha))

datafit.initialize_sparse(X.data, X.indptr, X.indices, (tm, s))
13 changes: 11 additions & 2 deletions skglm/utils/data.py
@@ -125,7 +125,7 @@ def make_correlated_data(


def make_dummy_survival_data(n_samples, n_features, normalize=False,
X_density=1., with_ties=False, random_state=None):
"""Generate a random dataset for survival analysis.

The design matrix ``X`` is generated according to standard normal, the vector of
@@ -148,6 +148,10 @@ def make_dummy_survival_data(n_samples, n_features, normalize=False,
The density, proportion of non zero elements, of the design matrix ``X``.
X_density must be in ``(0, 1]``.

with_ties : bool, default=False
If ``True``, the data contains tied observations: observations sharing
the same occurrence time ``tm``.

random_state : int, default=None
Determines random number generation for data generation.

@@ -170,7 +174,12 @@ def make_dummy_survival_data(n_samples, n_features, normalize=False,
X = scipy.sparse.rand(
n_samples, n_features, density=X_density, format="csc", dtype=float)

if not with_ties:
tm = rng.weibull(a=1, size=n_samples)
else:
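# draw roughly n_samples / 10 distinct times, then sample from them with
# replacement so that many observations share the same occurrence time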
unique_tm = rng.weibull(a=1, size=n_samples // 10 + 1)
tm = rng.choice(unique_tm, size=n_samples)

s = rng.choice(2, size=n_samples).astype(float)

if normalize and X_density == 1.: