Closed
Description
Describe the bug
Context: #15722 (comment)
Calling clone() on a BaseEstimator whose __init__() copies its parameters raises a RuntimeError, even though the parameters are otherwise equal (estimator.param == clone(estimator).param returns True but estimator.param is clone(estimator).param returns False).
Either the documentation is incomplete, because object identity is an unspecified requirement for clone() to work (in which case the BaseEstimator __init__() and get_params() documentation should say that a parameter must always be the same object that was passed in), or the check in clone() is too strict and should be loosened to an equality check.
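To make the failure mode concrete, here is a minimal sketch of an identity check of the kind described above. It does not import scikit-learn and is not its actual source: sketch_clone, CopyingEstimator and PlainEstimator are hypothetical names, and the logic is only an assumption about how such a check behaves (deep-copy each parameter, rebuild the object, then compare what was passed in with what the constructor stored using "is").

```python
import copy


class CopyingEstimator:
    """Hypothetical stand-in: __init__ stores a copy of its argument."""

    def __init__(self, my_dict):
        self.my_dict = my_dict.copy()

    def get_params(self):
        return {"my_dict": self.my_dict}


class PlainEstimator:
    """Hypothetical stand-in: __init__ stores its argument unchanged."""

    def __init__(self, my_dict):
        self.my_dict = my_dict

    def get_params(self):
        return {"my_dict": self.my_dict}


def sketch_clone(estimator):
    # Deep-copy every parameter, then rebuild the object from the copies.
    passed = {k: copy.deepcopy(v) for k, v in estimator.get_params().items()}
    new_obj = type(estimator)(**passed)
    stored = new_obj.get_params()
    for name, value in passed.items():
        # Identity, not equality: an equal but distinct object fails here.
        if stored[name] is not value:
            raise RuntimeError(
                f"constructor either does not set or modifies parameter {name}"
            )
    return new_obj


try:
    sketch_clone(CopyingEstimator({"foo": "bar"}))
except RuntimeError as exc:
    print("copying estimator rejected:", exc)

cloned = sketch_clone(PlainEstimator({"foo": "bar"}))
print("plain estimator cloned:", cloned.my_dict)
```

Under this assumption, the copy made in __init__() is equal to the argument but is a different object, which is enough to trip the check.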
Steps/Code to Reproduce
from sklearn.base import BaseEstimator, clone

class TestEstimator(BaseEstimator):
    def __init__(self, my_dict):
        self.my_dict = my_dict.copy()

some_dict = {'foo': 'bar'}
estimator = TestEstimator(some_dict)
clone(estimator)  # raises RuntimeError: Cannot clone object TestEstimator(my_dict={'foo': 'bar'}), as the constructor either does not set or modifies parameter my_dict
Expected Results
Calling clone(estimator) results in a new TestEstimator where the following assertions are true:
from sklearn.base import BaseEstimator, clone

class TestEstimator(BaseEstimator):
    def __init__(self, my_dict):
        self.my_dict = my_dict.copy()

some_dict = {'foo': 'bar'}
estimator = TestEstimator(some_dict)
new_estimator = clone(estimator)
assert estimator is not new_estimator
# The attribute set in __init__ is my_dict (some_dict is only the local variable above).
assert estimator.my_dict == new_estimator.my_dict  # this isn't strictly necessary, but if clone() is going to assert equality then this seems like the right kind of check
assert estimator.my_dict is not new_estimator.my_dict
# no RuntimeError or AssertionError should be raised after running this snippet
Actual Results
>>> clone(estimator)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/wchill/miniconda/envs/test38/lib/python3.8/site-packages/sklearn/base.py", line 95, in clone
    raise RuntimeError(
RuntimeError: Cannot clone object TestEstimator(my_dict={'foo': 'bar'}), as the constructor either does not set or modifies parameter my_dict
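One possible way to avoid the error, assuming the identity requirement stays in place, is to store the argument unchanged in __init__() and make the defensive copy later, for example when fitting. The sketch below uses a plain Python class rather than BaseEstimator so that it runs without scikit-learn; DeferredCopyEstimator, its fit() signature, and the my_dict_ attribute are hypothetical names chosen for illustration.

```python
class DeferredCopyEstimator:
    """Hypothetical stand-in that defers the copy until fit()."""

    def __init__(self, my_dict):
        # Store exactly the object that was passed in, with no copy and
        # no modification, so an identity check on the parameter passes.
        self.my_dict = my_dict

    def get_params(self):
        return {"my_dict": self.my_dict}

    def fit(self, X=None, y=None):
        # Take the private working copy here instead of in __init__.
        self.my_dict_ = dict(self.my_dict)
        return self


some_dict = {"foo": "bar"}
estimator = DeferredCopyEstimator(some_dict).fit()

# The constructor parameter is the caller's object; the fitted copy is not.
print(estimator.my_dict is some_dict)
print(estimator.my_dict_ is some_dict)
```

The trade-off is that the caller can still mutate some_dict between construction and fit(), so the copy only protects state from fit() onwards.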
Versions
>>> import sklearn; sklearn.show_versions()
System:
python: 3.8.12 (default, Oct 12 2021, 06:23:56) [Clang 10.0.0 ]
executable: /Users/wchill/miniconda/envs/test38/bin/python
machine: macOS-10.16-x86_64-i386-64bit
Python dependencies:
pip: 21.2.4
setuptools: 58.0.4
sklearn: 1.0.2
numpy: 1.22.3
scipy: 1.8.0
Cython: None
pandas: None
matplotlib: None
joblib: 1.1.0
threadpoolctl: 3.1.0
Built with OpenMP: True