Skip to content

Commit

Permalink
[ENH] BaseObject reset functionality (#2531)
Browse files Browse the repository at this point in the history
Provides the functionality discussed in #1614, i.e., a `reset` method in `BaseObject` that resets attributes and internal state of an estimator/object to its post-`__init__` state (while keeping parameters), equivalent to overwriting `self` with a `sklearn.clone`.

This is motivated by the bug in `MultiplexForecaster` which @miraep8's test suite uncovered in combination with #2458, and the observation that:
* boilerplate needs to be copied from `__init__` to `_fit` to fix it, and
* that the same issue possibly exists in a number of other composites which are not tested for subsequent `fit` with different parameters.

This PR also adds `reset` at the start of `BaseForecaster.fit`, addressing a family of potential bugs where not initializing in `_fit` causes unexpected behaviour in a sequence of calls `__init__`, `set_params`, `fit` (the bug present in the original #2458).
  • Loading branch information
fkiraly committed Apr 27, 2022
1 parent 2c82c92 commit 7366622
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 3 deletions.
39 changes: 36 additions & 3 deletions sktime/base/_base.py
Expand Up @@ -14,9 +14,9 @@
class name: BaseObject
Hyper-parameter inspection and setter methods:
inspect hyper-parameters - get_params()
setting hyper-parameters - set_params(**params)
Hyper-parameter inspection and setter methods
inspect hyper-parameters - get_params()
setting hyper-parameters - set_params(**params)
Tag inspection and setter methods
inspect tags (all) - get_tags()
Expand All @@ -26,6 +26,9 @@ class name: BaseObject
setting dynamic tags - set_tag(**tag_dict: dict)
set/clone dynamic tags - clone_tags(estimator, tag_names=None)
Re-initialize object to post-init state with same hyper-parameters
reset estimator to post-init - reset()
Testing with default parameters methods
getting default parameters (all sets) - get_test_params()
get one test instance with default parameters - create_test_instance()
Expand Down Expand Up @@ -69,6 +72,36 @@ def __init__(self):
self._tags_dynamic = dict()
super(BaseObject, self).__init__()

def reset(self):
"""Reset the object to a clean post-init state.
Equivalent to sklearn.clone but overwrites self.
After self.reset() call, self is equal in value to
`type(self)(**self.get_params(deep=False))`
Detail behaviour:
removes any object attributes, except:
hyper-parameters = arguments of __init__
object attributes containing double-underscores, i.e., the string "__"
runs __init__ with current values of hyper-parameters (result of get_params)
Not affected by the reset are:
object attributes containing double-underscores
class and object methods, class attributes
"""
# retrieve parameters to copy them later
params = self.get_params(deep=False)

# delete all object attributes in self
attrs = [attr for attr in dir(self) if "__" not in attr]
cls_attrs = [attr for attr in dir(type(self))]
self_attrs = set(attrs).difference(cls_attrs)
for attr in self_attrs:
delattr(self, attr)

# run init with a copy of parameters self had at the start
self.__init__(**params)

@classmethod
def get_class_tags(cls):
"""Get class tags from estimator class and all its parent classes.
Expand Down
43 changes: 43 additions & 0 deletions sktime/base/tests/test_base.py
Expand Up @@ -186,3 +186,46 @@ def test_is_composite():

assert not non_composite.is_composite()
assert composite.is_composite()


class ResetTester(BaseObject):

clsvar = 210

def __init__(self, a, b=42):
self.a = a
self.b = b
self.c = 84

def foo(self):
self.d = 126
self._d = 126
self.d_ = 126
self.f__o__o = 252


def test_reset():
"""Tests reset method for correct behaviour.
Raises
------
AssertionError if logic behind reset is incorrect, logic tested:
reset should remove any object attributes that are not hyper-parameters,
with the exception of attributes containing double-underscore "__"
reset should not remove class attributes or methods
reset should set hyper-parameters as in pre-reset state
"""
x = ResetTester(168)
x.foo()

x.reset()

assert hasattr(x, "a") and x.a == 168
assert hasattr(x, "b") and x.b == 42
assert hasattr(x, "c") and x.c == 84
assert hasattr(x, "clsvar") and x.clsvar == 210
assert not hasattr(x, "d")
assert not hasattr(x, "_d")
assert not hasattr(x, "d_")
assert hasattr(x, "f__o__o") and x.f__o__o == 252
assert hasattr(x, "foo")
2 changes: 2 additions & 0 deletions sktime/forecasting/base/_base.py
Expand Up @@ -220,6 +220,8 @@ def fit(self, y, X=None, fh=None):
# check y is not None
assert y is not None, "y cannot be None, but found None"

# if fit is called, object is reset
self.reset()
# if fit is called, fitted state is re-set
self._is_fitted = False

Expand Down

0 comments on commit 7366622

Please sign in to comment.