diff --git a/sksurv/__init__.py b/sksurv/__init__.py index 96a68138..26366fbf 100644 --- a/sksurv/__init__.py +++ b/sksurv/__init__.py @@ -6,6 +6,8 @@ from sklearn.pipeline import Pipeline, _final_estimator_has from sklearn.utils.metaestimators import available_if +from .util import property_available_if + def _get_version(name): try: @@ -125,9 +127,15 @@ def predict_survival_function(self, X, **kwargs): return self.steps[-1][-1].predict_survival_function(Xt, **kwargs) +@property_available_if(_final_estimator_has("_predict_risk_score")) +def _predict_risk_score(self): + return self.steps[-1][-1]._predict_risk_score + + def patch_pipeline(): Pipeline.predict_survival_function = predict_survival_function Pipeline.predict_cumulative_hazard_function = predict_cumulative_hazard_function + Pipeline._predict_risk_score = _predict_risk_score try: diff --git a/sksurv/util.py b/sksurv/util.py index b8778143..92f32852 100644 --- a/sksurv/util.py +++ b/sksurv/util.py @@ -259,3 +259,63 @@ def safe_concat(objs, *args, **kwargs): concatenated[name] = pd.Categorical(concatenated[name], **params) return concatenated + + +class _PropertyAvailableIfDescriptor: + """Implements a conditional property using the descriptor protocol based on the property decorator. + + The corresponding class in scikit-learn (`_AvailableIfDescriptor`) only supports callables. + This class adopts the property decorator as described in the descriptor guide in the offical Python documentation. + + See also + -------- + https://docs.python.org/3/howto/descriptor.html + Descriptor HowTo Guide + + :class:`sklearn.utils.available_if._AvailableIfDescriptor` + The original class in scikit-learn. + """ + + def __init__(self, check, fget, doc=None): + self.check = check + self.fget = fget + if doc is None and fget is not None: + doc = fget.__doc__ + self.__doc__ = doc + self._name = "" + + def __set_name__(self, owner, name): + self._name = name + + def __get__(self, obj, objtype=None): + if obj is None: + return self + + attr_err = AttributeError(f"This {obj!r} has no attribute {self._name!r}") + if not self.check(obj): + raise attr_err + + if self.fget is None: + raise AttributeError(f"property '{self._name}' has no getter") + return self.fget(obj) + + +def property_available_if(check): + """A property attribute that is available only if check returns a truthy value. + + Only supports getting an attribute value, setting or deleting an attribute value are not supported. + + Parameters + ---------- + check : callable + When passed the object of the decorated method, this should return + `True` if the property attribute is available, and either return `False` + or raise an `AttributeError` if not available. + + Returns + ------- + callable + Callable makes the decorated property available if `check` returns + `True`, otherwise the decorated property is unavailable. + """ + return lambda fn: _PropertyAvailableIfDescriptor(check=check, fget=fn) diff --git a/tests/test_aft.py b/tests/test_aft.py index 2c584744..bd36e991 100644 --- a/tests/test_aft.py +++ b/tests/test_aft.py @@ -1,7 +1,9 @@ import numpy as np from numpy.testing import assert_array_almost_equal import pytest +from sklearn.pipeline import make_pipeline +from sksurv.base import SurvivalAnalysisMixin from sksurv.linear_model import IPCRidge from sksurv.testing import assert_cindex_almost_equal @@ -51,3 +53,21 @@ def test_predict(make_whas500): ) assert model.score(x_test, y_test) == 0.66925817946226107 + + @staticmethod + def test_pipeline_score(make_whas500): + whas500 = make_whas500() + pipe = make_pipeline(IPCRidge()) + pipe.fit(whas500.x[:400], whas500.y[:400]) + + x_test = whas500.x[400:] + y_test = whas500.y[400:] + p = pipe.predict(x_test) + assert_cindex_almost_equal( + y_test["fstat"], + y_test["lenfol"], + -p, + (0.66925817946226107, 2066, 1021, 0, 1), + ) + + assert SurvivalAnalysisMixin.score(pipe, x_test, y_test) == 0.66925817946226107 diff --git a/tests/test_util.py b/tests/test_util.py index 590eb49f..fc3fd71d 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -8,7 +8,7 @@ import pytest from sksurv.testing import FixtureParameterFactory -from sksurv.util import Surv, safe_concat +from sksurv.util import Surv, _PropertyAvailableIfDescriptor, property_available_if, safe_concat class ConcatCasesFactory(FixtureParameterFactory): @@ -369,3 +369,37 @@ def test_from_dataframe(args, expected, expected_error): if expected is not None: assert_array_equal(y, expected) + + +def test_cond_avail_property(): + class WithCondProp: + def __init__(self, val): + self.avail = False + self._prop = val + + @property_available_if(lambda self: self.avail) + def prop(self): + return self._prop + + no_prop = _PropertyAvailableIfDescriptor(lambda self: self.avail, fget=None) + + testval = 43 + msg = "has no attribute 'prop'" + + assert WithCondProp.prop is not None + + test_obj = WithCondProp(testval) + with pytest.raises(AttributeError, match=msg): + _ = test_obj.prop + assert test_obj.avail is False + + test_obj.avail = True + assert test_obj.prop == testval + + test_obj.avail = False + with pytest.raises(AttributeError, match=msg): + _ = test_obj.prop + + test_obj.avail = True + with pytest.raises(AttributeError, match="has no getter"): + _ = test_obj.no_prop