-
Notifications
You must be signed in to change notification settings - Fork 206
/
test_pandas_inputs.py
67 lines (55 loc) · 2.26 KB
/
test_pandas_inputs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from importlib import import_module
import inspect
from os.path import dirname
import pkgutil
import numpy
from numpy.testing import assert_array_equal
import pandas
import pytest
import sksurv
from sksurv.base import SurvivalAnalysisMixin
from sksurv.datasets import load_whas500
def is_survival_mixin(x):
    """Tell whether *x* is a class that derives from ``SurvivalAnalysisMixin``.

    The mixin class itself is rejected; abstract subclasses are accepted here
    and must be filtered out by the caller if needed.

    Parameters
    ----------
    x : object
        Any object, typically a member returned by ``inspect.getmembers``.

    Returns
    -------
    bool
    """
    if not inspect.isclass(x):
        return False
    if x is SurvivalAnalysisMixin:
        return False
    return issubclass(x, SurvivalAnalysisMixin)
def all_survival_estimators():
    """Collect every non-abstract ``SurvivalAnalysisMixin`` subclass in ``sksurv``.

    Every sub-module of the installed ``sksurv`` package is imported, except
    those under ``sksurv.meta``. Its members are then scanned with
    ``is_survival_mixin``. A class that is re-exported by several modules
    appears only once in the result.

    Returns
    -------
    set of type
        The discovered estimator classes (unordered).
    """
    package_dir = dirname(sksurv.__file__)
    found = set()
    for module_info in pkgutil.walk_packages(path=[package_dir], prefix="sksurv."):
        # meta-estimators require base estimators
        if module_info.name.startswith("sksurv.meta"):
            continue
        module = import_module(module_info.name)
        found.update(
            cls
            for _, cls in inspect.getmembers(module, is_survival_mixin)
            if not inspect.isabstract(cls)
        )
    return found
@pytest.mark.parametrize(
    "estimator_cls",
    # all_survival_estimators() returns a set of classes, whose iteration order
    # follows object hashes and can differ between interpreter processes. Sort
    # it so every collection (e.g. each pytest-xdist worker) maps a test id to
    # the same estimator, and use the class name as a readable test id.
    sorted(all_survival_estimators(), key=lambda cls: (cls.__module__, cls.__qualname__)),
    ids=lambda cls: cls.__name__,
)
def test_pandas_inputs(estimator_cls):
    """Check feature-name bookkeeping of a survival estimator.

    The estimator is fitted on a small float DataFrame. The test then checks:

    - ``feature_names_in_`` matches the DataFrame columns.
    - Predicting on reordered columns emits the expected warning.
    - Predicting on a bare ndarray emits the expected warning.
    - After refitting on an ndarray, predicting on a DataFrame warns as well.

    Parameters
    ----------
    estimator_cls : type
        Estimator class, constructed with default arguments.
    """
    X, y = load_whas500()
    # a small subset keeps fitting every estimator fast
    X = X.iloc[:50]
    y = y[:50]
    X_df = X.loc[:, ["age", "bmi", "chf", "gender"]].astype(float)
    X_np = X_df.to_numpy()

    estimator = estimator_cls()
    if "kernel" in estimator.get_params():
        estimator.set_params(kernel="rbf")

    # fitted on a DataFrame: column names must be recorded
    estimator.fit(X_df, y)
    assert hasattr(estimator, "feature_names_in_")
    assert_array_equal(estimator.feature_names_in_, numpy.asarray(X_df.columns, dtype=object))
    estimator.predict(X_df)

    # same values, but column names in reversed order
    # NOTE(review): newer scikit-learn releases raise ValueError here instead of
    # emitting a FutureWarning -- confirm against the pinned scikit-learn version.
    msg = "The feature names should match those that were passed"
    X_bad = pandas.DataFrame(X_np, columns=X_df.columns.tolist()[::-1])
    with pytest.warns(FutureWarning, match=msg):
        estimator.predict(X_bad)

    # warns when fitted on dataframe and predicting on a ndarray
    msg = "X does not have valid feature names, but {} was fitted with feature names".format(
        estimator_cls.__name__
    )
    with pytest.warns(UserWarning, match=msg):
        estimator.predict(X_np)

    # warns when fitted on a ndarray and predicting on a dataframe
    msg = "X has feature names, but {} was fitted without feature names".format(estimator_cls.__name__)
    estimator.fit(X_np, y)
    # refitting without names must not keep the names from the previous fit
    assert getattr(estimator, "feature_names_in_", None) is None
    with pytest.warns(UserWarning, match=msg):
        estimator.predict(X_df)