-
-
Notifications
You must be signed in to change notification settings - Fork 25.3k
/
discovery.py
217 lines (187 loc) · 7.17 KB
/
discovery.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
import pkgutil
import inspect
from importlib import import_module
from operator import itemgetter
from pathlib import Path
# Dotted-path components that exclude a module from discovery: a module is
# skipped by the crawlers below when any part of its dotted name is listed here.
_MODULE_TO_IGNORE = {
    "conftest",
    "estimator_checks",
    "experimental",
    "externals",
    "setup",
    "tests",
}
def all_estimators(type_filter=None):
    """Get a list of all estimators from `sklearn`.

    This function crawls the module and gets all classes that inherit
    from BaseEstimator. Classes that are defined in test-modules are not
    included.

    Parameters
    ----------
    type_filter : {"classifier", "regressor", "cluster", "transformer"} \
            or list of such str, default=None
        Which kind of estimators should be returned. If None, no filter is
        applied and all estimators are returned.  Possible values are
        'classifier', 'regressor', 'cluster' and 'transformer' to get
        estimators only of these specific types, or a list of these to
        get the estimators that fit at least one of the types.

    Returns
    -------
    estimators : list of tuples
        List of (name, class), where ``name`` is the class name as string
        and ``class`` is the actual type of the class. The list is
        de-duplicated and sorted by name.

    Raises
    ------
    ValueError
        If `type_filter` contains an entry that is not one of the supported
        filter names.
    """
    # lazy import to avoid circular imports from sklearn.base
    from . import IS_PYPY
    from ._testing import ignore_warnings
    from ..base import (
        BaseEstimator,
        ClassifierMixin,
        RegressorMixin,
        TransformerMixin,
        ClusterMixin,
    )

    def is_abstract(c):
        # Abstract classes expose a non-empty ``__abstractmethods__``; plain
        # classes either lack the attribute or have it empty.
        return bool(getattr(c, "__abstractmethods__", None))

    all_classes = []
    root = str(Path(__file__).parent.parent)  # sklearn package
    # Ignore deprecation warnings triggered at import time and from walking
    # packages
    with ignore_warnings(category=FutureWarning):
        for _, module_name, _ in pkgutil.walk_packages(path=[root], prefix="sklearn."):
            module_parts = module_name.split(".")
            # Skip ignored modules and anything living under a private
            # (underscore-prefixed) module or package.
            if (
                any(part in _MODULE_TO_IGNORE for part in module_parts)
                or "._" in module_name
            ):
                continue
            module = import_module(module_name)
            classes = inspect.getmembers(module, inspect.isclass)
            classes = [
                (name, est_cls) for name, est_cls in classes if not name.startswith("_")
            ]

            # TODO: Remove when FeatureHasher is implemented in PYPY
            # Skips FeatureHasher for PYPY: drop only that class and keep
            # every other class found in the feature_extraction modules.
            if IS_PYPY and "feature_extraction" in module_name:
                classes = [
                    (name, est_cls)
                    for name, est_cls in classes
                    if name != "FeatureHasher"
                ]

            all_classes.extend(classes)

    # The same class is usually re-exported from several modules.
    all_classes = set(all_classes)

    estimators = [
        c
        for c in all_classes
        if (issubclass(c[1], BaseEstimator) and c[0] != "BaseEstimator")
    ]
    # get rid of abstract base classes
    estimators = [c for c in estimators if not is_abstract(c[1])]

    if type_filter is not None:
        if not isinstance(type_filter, list):
            type_filter = [type_filter]
        else:
            type_filter = list(type_filter)  # copy
        filtered_estimators = []
        filters = {
            "classifier": ClassifierMixin,
            "regressor": RegressorMixin,
            "transformer": TransformerMixin,
            "cluster": ClusterMixin,
        }
        for name, mixin in filters.items():
            if name in type_filter:
                # Consume recognised names; whatever is left over afterwards
                # is reported as invalid below.
                type_filter.remove(name)
                filtered_estimators.extend(
                    [est for est in estimators if issubclass(est[1], mixin)]
                )
        estimators = filtered_estimators
        if type_filter:
            raise ValueError(
                "Parameter type_filter must be 'classifier', "
                "'regressor', 'transformer', 'cluster' or "
                "None, got"
                f" {repr(type_filter)}."
            )

    # drop duplicates, sort for reproducibility
    # itemgetter is used to ensure the sort does not extend to the 2nd item of
    # the tuple
    return sorted(set(estimators), key=itemgetter(0))
def all_displays():
    """Get a list of all displays from `sklearn`.

    Public modules of the package are imported one by one and every public
    class whose name ends with ``"Display"`` is collected.

    Returns
    -------
    displays : list of tuples
        List of (name, class), where ``name`` is the display class name as
        string and ``class`` is the actual type of the class. Duplicates are
        removed and the list is sorted by name.
    """
    # lazy import to avoid circular imports from sklearn.base
    from ._testing import ignore_warnings

    package_dir = str(Path(__file__).parent.parent)  # sklearn package
    found = []
    # Deprecation warnings raised while importing / walking are silenced.
    with ignore_warnings(category=FutureWarning):
        walker = pkgutil.walk_packages(path=[package_dir], prefix="sklearn.")
        for _, module_name, _ in walker:
            in_ignored_module = any(
                part in _MODULE_TO_IGNORE for part in module_name.split(".")
            )
            in_private_module = "._" in module_name
            if in_ignored_module or in_private_module:
                continue

            members = inspect.getmembers(import_module(module_name), inspect.isclass)
            for cls_name, cls in members:
                if cls_name.startswith("_"):
                    continue
                if cls_name.endswith("Display"):
                    found.append((cls_name, cls))

    # Sort on the name only so the classes themselves are never compared.
    return sorted(set(found), key=itemgetter(0))
def _is_checked_function(item):
    """Tell whether `item` is a public function defined in an sklearn submodule.

    Returns True only for plain Python functions whose ``__name__`` does not
    start with an underscore and whose ``__module__`` starts with
    ``"sklearn."`` without ending in ``"estimator_checks"``.
    """
    if not inspect.isfunction(item) or item.__name__.startswith("_"):
        return False

    module_name = item.__module__
    defined_in_sklearn = module_name.startswith("sklearn.")
    return defined_in_sklearn and not module_name.endswith("estimator_checks")
def all_functions():
    """Get a list of all functions from `sklearn`.

    Public modules of the package are imported one by one and every member
    accepted by `_is_checked_function` is collected, provided the attribute
    name it is bound to in that module is public.

    Returns
    -------
    functions : list of tuples
        List of (name, function), where ``name`` is the function name as
        string and ``function`` is the actual function. Duplicates are
        removed and the list is sorted by name.
    """
    # lazy import to avoid circular imports from sklearn.base
    from ._testing import ignore_warnings

    package_dir = str(Path(__file__).parent.parent)  # sklearn package
    collected = []
    # Deprecation warnings raised while importing / walking are silenced.
    with ignore_warnings(category=FutureWarning):
        walker = pkgutil.walk_packages(path=[package_dir], prefix="sklearn.")
        for _, module_name, _ in walker:
            in_ignored_module = any(
                part in _MODULE_TO_IGNORE for part in module_name.split(".")
            )
            if in_ignored_module or "._" in module_name:
                continue

            members = inspect.getmembers(
                import_module(module_name), _is_checked_function
            )
            # The attribute name decides visibility, but the reported name is
            # the function's own ``__name__``.
            collected.extend(
                (func.__name__, func)
                for attr_name, func in members
                if not attr_name.startswith("_")
            )

    # drop duplicates, sort for reproducibility
    # itemgetter is used to ensure the sort does not extend to the 2nd item of
    # the tuple
    return sorted(set(collected), key=itemgetter(0))