Skip to content

Commit

Permalink
Merge pull request #374 from sebp/Finesim97-pipeline_scale_fix
Browse files Browse the repository at this point in the history
Expose _predict_risk_score in Pipeline
  • Loading branch information
sebp committed Jun 10, 2023
2 parents e376da4 + 309ef8e commit c65ab0b
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 1 deletion.
8 changes: 8 additions & 0 deletions sksurv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from sklearn.pipeline import Pipeline, _final_estimator_has
from sklearn.utils.metaestimators import available_if

from .util import property_available_if


def _get_version(name):
try:
Expand Down Expand Up @@ -125,9 +127,15 @@ def predict_survival_function(self, X, **kwargs):
return self.steps[-1][-1].predict_survival_function(Xt, **kwargs)


@property_available_if(_final_estimator_has("_predict_risk_score"))
def _predict_risk_score(self):
return self.steps[-1][-1]._predict_risk_score


def patch_pipeline():
Pipeline.predict_survival_function = predict_survival_function
Pipeline.predict_cumulative_hazard_function = predict_cumulative_hazard_function
Pipeline._predict_risk_score = _predict_risk_score


try:
Expand Down
60 changes: 60 additions & 0 deletions sksurv/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,63 @@ def safe_concat(objs, *args, **kwargs):
concatenated[name] = pd.Categorical(concatenated[name], **params)

return concatenated


class _PropertyAvailableIfDescriptor:
"""Implements a conditional property using the descriptor protocol based on the property decorator.
The corresponding class in scikit-learn (`_AvailableIfDescriptor`) only supports callables.
This class adopts the property decorator as described in the descriptor guide in the offical Python documentation.
See also
--------
https://docs.python.org/3/howto/descriptor.html
Descriptor HowTo Guide
:class:`sklearn.utils.available_if._AvailableIfDescriptor`
The original class in scikit-learn.
"""

def __init__(self, check, fget, doc=None):
self.check = check
self.fget = fget
if doc is None and fget is not None:
doc = fget.__doc__
self.__doc__ = doc
self._name = ""

def __set_name__(self, owner, name):
self._name = name

def __get__(self, obj, objtype=None):
if obj is None:
return self

attr_err = AttributeError(f"This {obj!r} has no attribute {self._name!r}")
if not self.check(obj):
raise attr_err

if self.fget is None:
raise AttributeError(f"property '{self._name}' has no getter")
return self.fget(obj)


def property_available_if(check):
"""A property attribute that is available only if check returns a truthy value.
Only supports getting an attribute value, setting or deleting an attribute value are not supported.
Parameters
----------
check : callable
When passed the object of the decorated method, this should return
`True` if the property attribute is available, and either return `False`
or raise an `AttributeError` if not available.
Returns
-------
callable
Callable makes the decorated property available if `check` returns
`True`, otherwise the decorated property is unavailable.
"""
return lambda fn: _PropertyAvailableIfDescriptor(check=check, fget=fn)
20 changes: 20 additions & 0 deletions tests/test_aft.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import numpy as np
from numpy.testing import assert_array_almost_equal
import pytest
from sklearn.pipeline import make_pipeline

from sksurv.base import SurvivalAnalysisMixin
from sksurv.linear_model import IPCRidge
from sksurv.testing import assert_cindex_almost_equal

Expand Down Expand Up @@ -51,3 +53,21 @@ def test_predict(make_whas500):
)

assert model.score(x_test, y_test) == 0.66925817946226107

@staticmethod
def test_pipeline_score(make_whas500):
whas500 = make_whas500()
pipe = make_pipeline(IPCRidge())
pipe.fit(whas500.x[:400], whas500.y[:400])

x_test = whas500.x[400:]
y_test = whas500.y[400:]
p = pipe.predict(x_test)
assert_cindex_almost_equal(
y_test["fstat"],
y_test["lenfol"],
-p,
(0.66925817946226107, 2066, 1021, 0, 1),
)

assert SurvivalAnalysisMixin.score(pipe, x_test, y_test) == 0.66925817946226107
36 changes: 35 additions & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest

from sksurv.testing import FixtureParameterFactory
from sksurv.util import Surv, safe_concat
from sksurv.util import Surv, _PropertyAvailableIfDescriptor, property_available_if, safe_concat


class ConcatCasesFactory(FixtureParameterFactory):
Expand Down Expand Up @@ -369,3 +369,37 @@ def test_from_dataframe(args, expected, expected_error):

if expected is not None:
assert_array_equal(y, expected)


def test_cond_avail_property():
class WithCondProp:
def __init__(self, val):
self.avail = False
self._prop = val

@property_available_if(lambda self: self.avail)
def prop(self):
return self._prop

no_prop = _PropertyAvailableIfDescriptor(lambda self: self.avail, fget=None)

testval = 43
msg = "has no attribute 'prop'"

assert WithCondProp.prop is not None

test_obj = WithCondProp(testval)
with pytest.raises(AttributeError, match=msg):
_ = test_obj.prop
assert test_obj.avail is False

test_obj.avail = True
assert test_obj.prop == testval

test_obj.avail = False
with pytest.raises(AttributeError, match=msg):
_ = test_obj.prop

test_obj.avail = True
with pytest.raises(AttributeError, match="has no getter"):
_ = test_obj.no_prop

0 comments on commit c65ab0b

Please sign in to comment.