-
-
Notifications
You must be signed in to change notification settings - Fork 25.3k
/
metaestimators.py
293 lines (237 loc) · 9.76 KB
/
metaestimators.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
"""Utilities for meta-estimators"""
# Author: Joel Nothman
# Andreas Mueller
# License: BSD
from typing import List, Any
from abc import ABCMeta, abstractmethod
from operator import attrgetter
from functools import update_wrapper
import numpy as np
from ..utils import _safe_indexing
from ..base import BaseEstimator
from ..base import _is_pairwise
# Names exported by ``from ... import *``; the other definitions visible in
# this file are underscore-prefixed private helpers.
__all__ = ["available_if", "if_delegate_has_method"]
class _BaseComposition(BaseEstimator, metaclass=ABCMeta):
    """Parameter management for estimators composed of named sub-estimators.

    The sub-estimators are stored on the instance, under an attribute whose
    name is passed to the helpers below as ``attr``, as a sequence of
    ``(name, estimator)`` pairs.
    """

    steps: List[Any]

    @abstractmethod
    def __init__(self):
        pass

    def _get_params(self, attr, deep=True):
        """Return the parameters, optionally expanded with nested ones.

        Parameters
        ----------
        attr : str
            Name of the attribute holding the ``(name, estimator)`` pairs.
        deep : bool, default=True
            If True, also include every named estimator under its name and
            each of its own parameters under ``<name>__<parameter>``.

        Returns
        -------
        params : dict
            Mapping of parameter names to values.
        """
        params = super().get_params(deep=deep)
        if deep:
            named_estimators = getattr(self, attr)
            # The (name, estimator) pairs are added as top-level entries.
            params.update(named_estimators)
            for est_name, est in named_estimators:
                if not hasattr(est, "get_params"):
                    continue
                nested = est.get_params(deep=True)
                for key, value in nested.items():
                    params["%s__%s" % (est_name, key)] = value
        return params

    def _set_params(self, attr, **params):
        """Set parameters, replacing whole estimators before nested values.

        Parameters
        ----------
        attr : str
            Name of the attribute holding the ``(name, estimator)`` pairs.
        **params : dict
            Parameters to set. The key ``attr`` replaces the whole sequence,
            a key equal to an estimator name replaces that estimator, and
            everything else is forwarded to ``super().set_params``.

        Returns
        -------
        self : object
            This instance.
        """
        # The order below is deliberate:
        # 1. Replace the complete sequence of estimators.
        if attr in params:
            setattr(self, attr, params.pop(attr))

        # 2. Replace individual estimators addressed by their name.
        current = getattr(self, attr)
        known_names = []
        if current:
            known_names, _ = zip(*current)
        replaced = [
            key for key in params if "__" not in key and key in known_names
        ]
        for key in replaced:
            self._replace_estimator(attr, key, params.pop(key))

        # 3. Nested estimator parameters and remaining constructor arguments.
        super().set_params(**params)
        return self

    def _replace_estimator(self, attr, name, new_val):
        """Swap the first estimator called ``name`` for ``new_val``.

        ``name`` is assumed to be a valid estimator name. The attribute is
        re-assigned with a new list, so the previous sequence is not mutated.
        """
        updated = list(getattr(self, attr))
        for position, (existing_name, _) in enumerate(updated):
            if existing_name == name:
                updated[position] = (name, new_val)
                break
        setattr(self, attr, updated)

    def _validate_names(self, names):
        """Check that estimator names are usable as parameter prefixes.

        Raises
        ------
        ValueError
            If the names are not unique, if one of them equals a constructor
            argument of this estimator, or if one of them contains ``__``.
        """
        if len(names) != len(set(names)):
            raise ValueError("Names provided are not unique: {0!r}".format(list(names)))

        clashing = set(names).intersection(self.get_params(deep=False))
        if clashing:
            raise ValueError(
                "Estimator names conflict with constructor arguments: {0!r}".format(
                    sorted(clashing)
                )
            )

        with_separator = [name for name in names if "__" in name]
        if with_separator:
            raise ValueError(
                "Estimator names must not contain __: got {0!r}".format(with_separator)
            )
class _AvailableIfDescriptor:
    """Implements a conditional property using the descriptor protocol.

    Using this class to create a decorator will raise an ``AttributeError``
    if check(self) returns a falsey value. Note that if check raises an
    ``AttributeError`` itself, this will also result in hasattr returning
    false; any other exception raised by check propagates to the caller.

    See https://docs.python.org/3/howto/descriptor.html for an explanation of
    descriptors.

    Parameters
    ----------
    fn : callable
        The decorated function; it is called with the instance as its first
        argument.
    check : callable
        Called with the instance; the attribute is available only if the
        result is truthy.
    attribute_name : str
        Name reported in the ``AttributeError`` message.
    """

    def __init__(self, fn, check, attribute_name):
        self.fn = fn
        self.check = check
        self.attribute_name = attribute_name

        # update the docstring of the descriptor
        update_wrapper(self, fn)

    def _ensure_available(self, obj, owner):
        """Raise ``AttributeError`` unless ``self.check(obj)`` is truthy.

        The error is only built when it is actually needed. ``owner`` may be
        None (``__get__`` allows omitting it); the class of ``obj`` is then
        used for the message.
        """
        if self.check(obj):
            return
        if owner is None:
            owner = type(obj)
        raise AttributeError(
            f"This {repr(owner.__name__)} has no attribute {repr(self.attribute_name)}"
        )

    def __get__(self, obj, owner=None):
        """Return the method bound to ``obj``, or an unbound wrapper.

        Parameters
        ----------
        obj : object or None
            Instance the attribute is looked up on, or None for access
            through the class.
        owner : type or None, default=None
            Class the attribute is looked up on.

        Returns
        -------
        out : callable
            Wrapper around ``fn`` carrying its name and docstring.

        Raises
        ------
        AttributeError
            On instance access when the check fails. For class access the
            check is deferred until the returned wrapper is called.
        """
        if obj is not None:
            # delegate only on instances, not the classes.
            # this is to allow access to the docstrings.
            self._ensure_available(obj, owner)
            # lambda, but not partial, allows help() to work with update_wrapper
            out = lambda *args, **kwargs: self.fn(obj, *args, **kwargs)  # noqa
        else:

            def fn(*args, **kwargs):
                # The instance is the first positional argument here.
                self._ensure_available(args[0], owner)
                return self.fn(*args, **kwargs)

            # This makes it possible to use the decorated method as an unbound
            # method, for instance when monkeypatching.
            out = lambda *args, **kwargs: fn(*args, **kwargs)  # noqa
        # update the docstring of the returned function
        update_wrapper(out, self.fn)
        return out
def available_if(check):
    """Make a method available only when ``check`` returns a truthy value.

    Parameters
    ----------
    check : callable
        When passed the object with the decorated method, this should return
        a truthy value if the attribute is available, and either return False
        or raise an AttributeError if not available.

    Returns
    -------
    decorator : callable
        Decorator that wraps a function in an ``_AvailableIfDescriptor``
        registered under the function's own ``__name__``.

    Examples
    --------
    >>> from sklearn.utils.metaestimators import available_if
    >>> class HelloIfEven:
    ...    def __init__(self, x):
    ...        self.x = x
    ...
    ...    def _x_is_even(self):
    ...        return self.x % 2 == 0
    ...
    ...    @available_if(_x_is_even)
    ...    def say_hello(self):
    ...        print("Hello")
    ...
    >>> obj = HelloIfEven(1)
    >>> hasattr(obj, "say_hello")
    False
    >>> obj.x = 2
    >>> hasattr(obj, "say_hello")
    True
    >>> obj.say_hello()
    Hello
    """

    def decorator(fn):
        return _AvailableIfDescriptor(fn, check, attribute_name=fn.__name__)

    return decorator
class _IffHasAttrDescriptor(_AvailableIfDescriptor):
    """Conditional property whose availability follows a delegate object.

    Using this class to create a decorator will raise an ``AttributeError``
    if none of the delegates (specified in ``delegate_names``) is an attribute
    of the base object or the first found delegate does not have an attribute
    ``attribute_name``.

    This allows ducktyping of the decorated method based on
    ``delegate.attribute_name``. Here ``delegate`` is the first item in
    ``delegate_names`` for which ``hasattr(object, delegate) is True``.

    See https://docs.python.org/3/howto/descriptor.html for an explanation of
    descriptors.
    """

    def __init__(self, fn, delegate_names, attribute_name):
        # The availability check of the base descriptor is this class's
        # own ``_check``.
        super().__init__(fn, self._check, attribute_name)
        self.delegate_names = delegate_names

    def _check(self, obj):
        """Tell whether the first delegate found on ``obj`` has the attribute.

        Returns False when no delegate can be resolved, or when the resolved
        delegate is None. Otherwise the attribute is fetched from the
        delegate and a truthy value is returned; a missing attribute raises
        the delegate's own ``AttributeError``.
        """
        found = None
        for candidate in self.delegate_names:
            try:
                found = attrgetter(candidate)(obj)
            except AttributeError:
                # This delegate is absent; try the next name.
                continue
            break

        if found is None:
            return False
        # Deliberately unguarded: raise the original AttributeError.
        return getattr(found, self.attribute_name) or True
def if_delegate_has_method(delegate):
    """Create a decorator for methods that are delegated to a sub-estimator.

    This enables ducktyping by hasattr returning True according to the
    sub-estimator.

    Parameters
    ----------
    delegate : string, list of strings or tuple of strings
        Name of the sub-estimator that can be accessed as an attribute of the
        base object. If a list or a tuple of names are provided, the first
        sub-estimator that is an attribute of the base object will be used.

    Returns
    -------
    decorator : callable
        Decorator that wraps a function in an ``_IffHasAttrDescriptor``
        registered under the function's own ``__name__``.
    """
    # Normalise the argument to a tuple of delegate names.
    if isinstance(delegate, tuple):
        delegate_names = delegate
    elif isinstance(delegate, list):
        delegate_names = tuple(delegate)
    else:
        delegate_names = (delegate,)

    def decorator(fn):
        return _IffHasAttrDescriptor(fn, delegate_names, attribute_name=fn.__name__)

    return decorator
def _safe_split(estimator, X, y, indices, train_indices=None):
    """Extract a subset of X and y, taking precomputed kernels into account.

    X and y are sliced according to ``indices`` for cross-validation. When
    the estimator is pairwise (as reported by ``_is_pairwise``), X is a
    precomputed kernel / affinity / distance matrix: it must be square, and
    both rows and columns are sliced. Rows always use ``indices`` (assumed
    to be the test set when ``train_indices`` is given); columns use
    ``train_indices`` if provided and ``indices`` otherwise.

    .. deprecated:: 0.24

        The _pairwise attribute is deprecated in 0.24. From 1.1
        (renaming of 0.26) and onward, this function will check for the
        pairwise estimator tag.

    Labels y will always be indexed only along the first axis.

    Parameters
    ----------
    estimator : object
        Estimator to determine whether we should slice only rows or rows and
        columns.

    X : array-like, sparse matrix or iterable
        Data to be indexed. If ``estimator._pairwise is True``,
        this needs to be a square array-like or sparse matrix.

    y : array-like, sparse matrix or iterable
        Targets to be indexed.

    indices : array of int
        Rows to select from X and y.
        If ``estimator._pairwise is True`` and ``train_indices is None``
        then ``indices`` will also be used to slice columns.

    train_indices : array of int or None, default=None
        If ``estimator._pairwise is True`` and ``train_indices is not None``,
        then ``train_indices`` will be used to slice the columns of X.

    Returns
    -------
    X_subset : array-like, sparse matrix or list
        Indexed data.

    y_subset : array-like, sparse matrix or list
        Indexed targets, or None when ``y`` is None.

    Raises
    ------
    ValueError
        For a pairwise estimator, if X has no ``shape`` attribute or is not
        square.
    """
    if not _is_pairwise(estimator):
        # Ordinary data: only rows are selected.
        X_subset = _safe_indexing(X, indices)
    else:
        if not hasattr(X, "shape"):
            raise ValueError(
                "Precomputed kernels or affinity matrices have "
                "to be passed as arrays or sparse matrices."
            )
        # X is a precomputed square kernel matrix
        n_rows, n_cols = X.shape[0], X.shape[1]
        if n_rows != n_cols:
            raise ValueError("X should be a square kernel matrix")
        column_indices = indices if train_indices is None else train_indices
        X_subset = X[np.ix_(indices, column_indices)]

    y_subset = None if y is None else _safe_indexing(y, indices)
    return X_subset, y_subset