-
-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
fixed.py
89 lines (69 loc) · 3.03 KB
/
fixed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# copyright: sktime developers, BSD-3-Clause License (see LICENSE file)
"""Parameter estimator with fixed parameters.

Defines ``FixedParams``, a parameter estimator that does not estimate anything
but writes user-supplied name/value pairs to itself when fitted.
"""
__author__ = ["fkiraly"]
__all__ = ["FixedParams"]
from sktime.datatypes import ALL_TIME_SERIES_MTYPES
from sktime.param_est.base import BaseParamFitter
class FixedParams(BaseParamFitter):
    """Parameter estimator that stores user-supplied, fixed values in ``fit``.

    Nothing is estimated from the data: ``fit`` copies the entries of
    ``param_dict`` onto ``self``. This can serve as a dummy/mock, as a pipeline
    element that sets parameters to certain values, or as the "fixed" option
    in model selection.

    For each key-value pair in ``param_dict``, ``fit`` writes ``value`` to the
    attribute named ``str(key) + "_"`` of ``self``. If ``param_dict`` is not a
    ``dict``, no such attributes are written.

    Parameters
    ----------
    param_dict : dict
        name/value pairs of the fixed parameter values written to ``self``
    """

    _tags = {
        "authors": "fkiraly",
        # types that _fit supports for X
        "X_inner_mtype": ALL_TIME_SERIES_MTYPES,
        # X scitypes that are supported natively
        "scitype:X": ["Series", "Panel", "Hierarchical"],
        # the estimator can handle missing data
        "capability:missing_values": True,
        # the estimator can handle multivariate data
        "capability:multivariate": True,
    }

    def __init__(self, param_dict):
        # store the constructor argument unchanged before initialising the base
        self.param_dict = param_dict
        super().__init__()

    def _fit(self, X):
        """Write the fixed parameter values to ``self``.

        Private ``_fit`` containing the core logic, called from ``fit``.

        Writes to self:
            One attribute ``str(key) + "_"`` per entry of ``self.param_dict``.

        Parameters
        ----------
        X : guaranteed to be of a type in self.get_tag("X_inner_mtype")
            Time series passed to ``fit``. Its values are not read here.

        Returns
        -------
        self : reference to self
        """
        fixed_values = self.param_dict

        # anything that is not a dict is skipped without raising
        if not isinstance(fixed_values, dict):
            return self

        for name, fixed_value in fixed_values.items():
            # keys need not be strings, hence the explicit str conversion
            setattr(self, str(name) + "_", fixed_value)

        return self

    @classmethod
    def get_test_params(cls, parameter_set="default"):
        """Return testing parameter settings for the estimator.

        Parameters
        ----------
        parameter_set : str, default="default"
            Name of the set of test parameters to return, for use in tests.
            This estimator defines no special sets; the same parameter sets
            are returned for every value.

        Returns
        -------
        params : list of dict
            Parameters to create testing instances of the class.
            Each dict holds the arguments of one test instance, i.e.,
            ``MyClass(**params[i])`` creates a valid test instance.
            ``create_test_instance`` uses the first dictionary in ``params``.
        """
        return [
            # non-string key: results in the attribute name "1_"
            {"param_dict": {1: 2}},
            # several string keys
            {"param_dict": {"foo": "bar", "bar": "foo"}},
        ]