RuntimeError: "Cannot clone object ..." when cloning an estimator that copies parameters in either __init__ or get_params #22857

@wchill

Describe the bug

Context: #15722 (comment)

Calling clone() on a BaseEstimator that copies its parameters raises a RuntimeError, even if the parameters are otherwise equal (estimator.param == clone(estimator).param returns True but estimator.param is clone(estimator).param returns False).

Either the documentation is incomplete, in that this is an unspecified requirement for clone() to work (in which case the BaseEstimator __init__() and get_params() documentation should say that parameters must always be the same object), or the check in clone() is too strict and should be loosened.
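
To make the distinction between identity and equality concrete, here is a minimal pure-Python sketch of the kind of check that seems to be involved. It is not scikit-learn's actual code: SketchBase, CopyingEstimator, PlainEstimator and sketch_clone are hypothetical names made up for illustration, and the sketch assumes that each parameter is deep-copied before it is passed to the constructor.

```python
import copy


class SketchBase:
    """Hypothetical stand-in for BaseEstimator with one parameter."""

    def get_params(self):
        return {"my_dict": self.my_dict}


class CopyingEstimator(SketchBase):
    def __init__(self, my_dict):
        # Stores a copy, so the stored object is not the one passed in.
        self.my_dict = my_dict.copy()


class PlainEstimator(SketchBase):
    def __init__(self, my_dict):
        # Stores the argument unchanged.
        self.my_dict = my_dict


def sketch_clone(estimator):
    """Simplified stand-in for clone(): rebuild, then check identity."""
    # Assumption: parameters are deep-copied before reconstruction.
    new_params = {
        name: copy.deepcopy(value)
        for name, value in estimator.get_params().items()
    }
    new_estimator = type(estimator)(**new_params)
    stored = new_estimator.get_params()
    for name, passed in new_params.items():
        # Identity check ("is"), not equality ("==").
        if stored[name] is not passed:
            raise RuntimeError(
                f"constructor either does not set or modifies parameter {name}"
            )
    return new_estimator


def can_clone(estimator):
    """Return True if sketch_clone() accepts the estimator."""
    try:
        sketch_clone(estimator)
    except RuntimeError:
        return False
    return True
```

Under these assumptions, can_clone(CopyingEstimator({'foo': 'bar'})) is False while can_clone(PlainEstimator({'foo': 'bar'})) is True, even though both store a dict equal to the one passed in. Storing the argument unchanged in __init__ and copying it later (for example in fit) would sidestep the error, at the cost of the defensive copy at construction time.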

Steps/Code to Reproduce

from sklearn.base import BaseEstimator, clone

class TestEstimator(BaseEstimator):
    def __init__(self, my_dict):
        self.my_dict = my_dict.copy()

some_dict = {'foo': 'bar'}
estimator = TestEstimator(some_dict)

clone(estimator)      # raises RuntimeError: Cannot clone object TestEstimator(my_dict={'foo': 'bar'}), as the constructor either does not set or modifies parameter my_dict

Expected Results

Calling clone(estimator) results in a new TestEstimator where the following assertions are true:

from sklearn.base import BaseEstimator, clone

class TestEstimator(BaseEstimator):
    def __init__(self, my_dict):
        self.my_dict = my_dict.copy()

some_dict = {'foo': 'bar'}
estimator = TestEstimator(some_dict)

new_estimator = clone(estimator)

assert estimator is not new_estimator
assert estimator.my_dict == new_estimator.my_dict        # this isn't strictly necessary, but if clone() is going to assert equality then this seems like the right kind of check
assert estimator.my_dict is not new_estimator.my_dict

# no RuntimeError or AssertionError should be raised after running this snippet

Actual Results

>>> clone(estimator)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/wchill/miniconda/envs/test38/lib/python3.8/site-packages/sklearn/base.py", line 95, in clone
    raise RuntimeError(
RuntimeError: Cannot clone object TestEstimator(my_dict={'foo': 'bar'}), as the constructor either does not set or modifies parameter my_dict

Versions

>>> import sklearn; sklearn.show_versions()

System:
    python: 3.8.12 (default, Oct 12 2021, 06:23:56)  [Clang 10.0.0 ]
executable: /Users/wchill/miniconda/envs/test38/bin/python
   machine: macOS-10.16-x86_64-i386-64bit

Python dependencies:
          pip: 21.2.4
   setuptools: 58.0.4
      sklearn: 1.0.2
        numpy: 1.22.3
        scipy: 1.8.0
       Cython: None
       pandas: None
   matplotlib: None
       joblib: 1.1.0
threadpoolctl: 3.1.0

Built with OpenMP: True
