Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Towards #5850, this PR adds a general interface towards `pyts`, and an interface to `pyts` `ROCKET` (as implicitly suggested in #5849). `pyts` uses the `numpyflat` mtype internally, so the `tslearn` adapter should mostly translate (only diff is mtype, `tslearn` uses `numpy3D`). Mid-term, one could think about refactoring both adapters to incrase DRY-ness, if it should work for `pyts`. `pyts` is added to the `all_extras` dependency set. Notably, `pyts` depends on `numba`, which means 3.11 or lower. For now, I have avoided adding it to the other dependency sets, as it might cause restrictions.
- Loading branch information
Showing
5 changed files
with
322 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
"""Module containing adapters other framework packages covering multiple tasks.""" | ||
# copyright: sktime developers, BSD-3-Clause License (see LICENSE file) | ||
|
||
__all__ = ["_TslearnAdapter"] | ||
__all__ = ["_PytsAdapter", "_TslearnAdapter"] | ||
|
||
from sktime.base.adapters._pyts import _PytsAdapter | ||
from sktime.base.adapters._tslearn import _TslearnAdapter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
# copyright: sktime developers, BSD-3-Clause License (see LICENSE file) | ||
"""Implements adapter for pyts models.""" | ||
|
||
__all__ = ["_PytsAdapter"] | ||
__author__ = ["fkiraly"] | ||
|
||
from inspect import signature | ||
|
||
|
||
class _PytsAdapter: | ||
"""Mixin adapter class for pyts models.""" | ||
|
||
_tags = { | ||
"X_inner_mtype": "numpyflat", | ||
"python_dependencies": ["pyts"], | ||
} | ||
|
||
# defines the name of the attribute containing the pyts estimator | ||
_estimator_attr = "_estimator" | ||
|
||
def _get_pyts_class(self): | ||
"""Abstract method to get pyts class. | ||
should import and return pyts class | ||
""" | ||
# from pyts import PytsClass | ||
# | ||
# return Pyts | ||
raise NotImplementedError("abstract method") | ||
|
||
def _get_pyts_object(self): | ||
"""Abstract method to initialize pyts object. | ||
The default initializes result of _get_pyts_class | ||
with self.get_params. | ||
""" | ||
cls = self._get_pyts_class() | ||
return cls(**self.get_params()) | ||
|
||
def _init_pyts_object(self): | ||
"""Abstract method to initialize pyts object and set to _estimator_attr. | ||
The default writes the return of _get_pyts_object to | ||
the attribute of self with name _estimator_attr | ||
""" | ||
cls = self._get_pyts_object() | ||
setattr(self, self._estimator_attr, cls) | ||
return getattr(self, self._estimator_attr) | ||
|
||
def _fit(self, X, y=None): | ||
"""Fit estimator training data. | ||
Parameters | ||
---------- | ||
X : 3D np.ndarray of shape (n_instances, n_dimensions, series_length) | ||
Training features, passed only for classifiers or regressors | ||
y: None or 1D np.ndarray of shape (n_instances,) | ||
Training labels, passed only for classifiers or regressors | ||
Returns | ||
------- | ||
self: sktime estimator | ||
Fitted estimator. | ||
""" | ||
pyts_est = self._init_pyts_object() | ||
|
||
# check if pyts_est fit has y parameter | ||
# if yes, call with y, otherwise without | ||
pyts_has_y = "y" in signature(pyts_est.fit).parameters | ||
|
||
if pyts_has_y: | ||
pyts_est.fit(X, y) | ||
else: | ||
pyts_est.fit(X) | ||
|
||
# write fitted params to self | ||
pyts_fitted_params = self._get_fitted_params_default(pyts_est) | ||
for k, v in pyts_fitted_params.items(): | ||
setattr(self, f"{k}_", v) | ||
|
||
return self | ||
|
||
def _transform(self, X, y=None): | ||
"""Transform method adapter. | ||
Parameters | ||
---------- | ||
X : np.ndarray (2d or 3d array of shape (n_instances, series_length) or shape | ||
(n_instances, n_dimensions, series_length)) | ||
y: ignored, exists for API consistency reasons. | ||
Returns | ||
------- | ||
np.ndarray (1d array of shape (n_instances,)) | ||
Index of the cluster each time series in X belongs to. | ||
""" | ||
pyts_est = getattr(self, self._estimator_attr) | ||
|
||
# check if pyts_est fit has y parameter | ||
# if yes, call with y, otherwise without | ||
pyts_has_y = "y" in signature(pyts_est.transform).parameters | ||
|
||
if pyts_has_y: | ||
return pyts_est.transform(X, y) | ||
else: | ||
return pyts_est.transform(X) | ||
|
||
def _predict(self, X, y=None): | ||
"""Predict method adapter. | ||
Parameters | ||
---------- | ||
X : np.ndarray (2d or 3d array of shape (n_instances, series_length) or shape | ||
(n_instances, n_dimensions, series_length)) | ||
y: passed to pyts predict method if it has y parameter | ||
Returns | ||
------- | ||
np.ndarray (1d array of shape (n_instances,)) | ||
Index of the cluster each time series in X belongs to. | ||
""" | ||
pyts_est = getattr(self, self._estimator_attr) | ||
|
||
# check if pyts_est fit has y parameter | ||
# if yes, call with y, otherwise without | ||
pyts_has_y = "y" in signature(pyts_est.predict).parameters | ||
|
||
if pyts_has_y: | ||
return pyts_est.predict(X, y) | ||
else: | ||
return pyts_est.predict(X) | ||
|
||
def _predict_proba(self, X, y=None): | ||
"""Predict_proba method adapter. | ||
Parameters | ||
---------- | ||
X : np.ndarray (2d or 3d array of shape (n_instances, series_length) or shape | ||
(n_instances, n_dimensions, series_length)) | ||
Time series instances to predict their cluster indexes. | ||
y: passed to pyts predict method if it has y parameter | ||
Returns | ||
------- | ||
np.ndarray (1d array of shape (n_instances,)) | ||
Index of the cluster each time series in X belongs to. | ||
""" | ||
pyts_est = getattr(self, self._estimator_attr) | ||
|
||
# check if pyts_est fit has y parameter | ||
# if yes, call with y, otherwise without | ||
pyts_has_y = "y" in signature(pyts_est.predict_proba).parameters | ||
|
||
if pyts_has_y: | ||
return pyts_est.predict_proba(X, y) | ||
else: | ||
return pyts_est.predict_proba(X) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,24 @@ | ||
"""Rocket transformers.""" | ||
__all__ = [ | ||
"Rocket", | ||
"RocketPyts", | ||
"MiniRocket", | ||
"MiniRocketMultivariate", | ||
"MiniRocketMultivariateVariable", | ||
"MultiRocket", | ||
"MultiRocketMultivariate", | ||
] | ||
|
||
from ._minirocket import MiniRocket | ||
from ._minirocket_multivariate import MiniRocketMultivariate | ||
from ._minirocket_multivariate_variable import MiniRocketMultivariateVariable | ||
from ._multirocket import MultiRocket | ||
from ._multirocket_multivariate import MultiRocketMultivariate | ||
from ._rocket import Rocket | ||
from sktime.transformations.panel.rocket._minirocket import MiniRocket | ||
from sktime.transformations.panel.rocket._minirocket_multivariate import ( | ||
MiniRocketMultivariate, | ||
) | ||
from sktime.transformations.panel.rocket._minirocket_multivariate_variable import ( | ||
MiniRocketMultivariateVariable, | ||
) | ||
from sktime.transformations.panel.rocket._multirocket import MultiRocket | ||
from sktime.transformations.panel.rocket._multirocket_multivariate import ( | ||
MultiRocketMultivariate, | ||
) | ||
from sktime.transformations.panel.rocket._rocket import Rocket | ||
from sktime.transformations.panel.rocket._rocket_pyts import RocketPyts |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
"""Rocket transformer, from pyts.""" | ||
# copyright: sktime developers, BSD-3-Clause License (see LICENSE file) | ||
|
||
__author__ = ["fkiraly"] | ||
__all__ = ["RocketPyts"] | ||
|
||
from sktime.base.adapters._pyts import _PytsAdapter | ||
from sktime.transformations.base import BaseTransformer | ||
|
||
|
||
class RocketPyts(_PytsAdapter, BaseTransformer): | ||
"""RandOm Convolutional KErnel Transform (ROCKET), from ``pyts``. | ||
Direct interface to ``pyts.transformation.rocket``. | ||
ROCKET [1]_ generates random convolutional kernels, including random length and | ||
dilation. It transforms the time series with two features per kernel. The first | ||
feature is global max pooling and the second is proportion of positive values. | ||
This transformer fits one set of paramereters per individual series, | ||
and applies the transform with fitted parameter i to the i-th series in transform. | ||
Vanilla use requires same number of series in fit and transform. | ||
To fit and transform series at the same time, | ||
without an identification of fit/transform instances, | ||
wrap this transformer in ``FitInTransform``, | ||
from ``sktime.transformations.compose``. | ||
Parameters | ||
---------- | ||
n_kernels : int (default = 10000) | ||
Number of kernels. | ||
kernel_sizes : array-like (default = (7, 9, 11)) | ||
The possible sizes of the kernels. | ||
random_state : None, int or RandomState instance (default = None) | ||
The seed of the pseudo random number generator to use when shuffling | ||
the data. If int, random_state is the seed used by the random number | ||
generator. If RandomState instance, random_state is the random number | ||
generator. If None, the random number generator is the RandomState | ||
instance used by `np.random`. | ||
Attributes | ||
---------- | ||
weights_ : array, shape = (n_kernels, max(kernel_sizes)) | ||
Weights of the kernels. Zero padding values are added. | ||
length_ : array, shape = (n_kernels,) | ||
Length of each kernel. | ||
bias_ : array, shape = (n_kernels,) | ||
Bias of each kernel. | ||
dilation_ : array, shape = (n_kernels,) | ||
Dilation of each kernel. | ||
padding_ : array, shape = (n_kernels,) | ||
Padding of each kernel. | ||
See Also | ||
-------- | ||
MultiRocketMultivariate, MiniRocket, MiniRocketMultivariate, Rocket | ||
References | ||
---------- | ||
.. [1] Tan, Chang Wei and Dempster, Angus and Bergmeir, Christoph | ||
and Webb, Geoffrey I, | ||
"ROCKET: Exceptionally fast and accurate time series | ||
classification using random convolutional kernels",2020, | ||
https://link.springer.com/article/10.1007/s10618-020-00701-z, | ||
https://arxiv.org/abs/1910.13051 | ||
Examples | ||
-------- | ||
>>> from sktime.transformations.panel.rocket import RocketPyts | ||
>>> from sktime.datasets import load_unit_test | ||
>>> X_train, y_train = load_unit_test(split="train") # doctest: +SKIP | ||
>>> X_test, y_test = load_unit_test(split="test") # doctest: +SKIP | ||
>>> trf = RocketPyts(num_kernels=512) # doctest: +SKIP | ||
>>> trf.fit(X_train) # doctest: +SKIP | ||
Rocket(...) | ||
>>> X_train = trf.transform(X_train) # doctest: +SKIP | ||
>>> X_test = trf.transform(X_test) # doctest: +SKIP | ||
""" | ||
|
||
_tags = { | ||
# packaging info | ||
# -------------- | ||
"authors": "fkiraly", | ||
"python_dependencies": "pyts", | ||
# estimator type | ||
# -------------- | ||
"univariate-only": True, | ||
"fit_is_empty": False, | ||
"scitype:transform-input": "Series", | ||
# what is the scitype of X: Series, or Panel | ||
"scitype:transform-output": "Primitives", | ||
# what is the scitype of y: None (not needed), Primitives, Series, Panel | ||
"scitype:instancewise": False, # is this an instance-wise transform? | ||
} | ||
|
||
# defines the name of the attribute containing the pyts estimator | ||
_estimator_attr = "_pyts_rocket" | ||
|
||
def _get_pyts_class(self): | ||
"""Get pyts class. | ||
should import and return pyts class | ||
""" | ||
from pyts.transformation.rocket import ROCKET | ||
|
||
return ROCKET | ||
|
||
def __init__( | ||
self, | ||
n_kernels=10_000, | ||
kernel_sizes=(7, 9, 11), | ||
random_state=None, | ||
): | ||
self.n_kernels = n_kernels | ||
self.kernel_sizes = kernel_sizes | ||
self.random_state = random_state | ||
|
||
super().__init__() | ||
|
||
@classmethod | ||
def get_test_params(cls, parameter_set="default"): | ||
"""Return testing parameter settings for the estimator. | ||
Parameters | ||
---------- | ||
parameter_set : str, default="default" | ||
Name of the set of test parameters to return, for use in tests. If no | ||
special parameters are defined for a value, will return `"default"` set. | ||
Returns | ||
------- | ||
params : dict or list of dict, default = {} | ||
Parameters to create testing instances of the class | ||
Each dict are parameters to construct an "interesting" test instance, i.e., | ||
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. | ||
`create_test_instance` uses the first (or only) dictionary in `params` | ||
""" | ||
params1 = {"n_kernels": 234, "kernel_sizes": (5, 4)} | ||
params2 = {"n_kernels": 512, "kernel_sizes": (6, 7, 8)} | ||
return [params1, params2] |