Skip to content

Commit

Permalink
[ENH] adapter to pyts KNeighborsClassifier (#5939)
Browse files Browse the repository at this point in the history
Adapter for `pyts.classification.KNeighborsClassifier`, using the
generic adapter introduced in #5851.
Serves as a test case for classifiers, and possibly fixes
#5914.

Depends on #5851 for the adapter.
  • Loading branch information
fkiraly committed Feb 18, 2024
1 parent 2442e5a commit b5027a4
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/api_reference/classification.rst
Expand Up @@ -96,6 +96,7 @@ Distance-based

ElasticEnsemble
KNeighborsTimeSeriesClassifier
KNeighborsTimeSeriesClassifierPyts
ProximityForest
ProximityStump
ProximityTree
Expand Down
4 changes: 4 additions & 0 deletions sktime/classification/distance_based/__init__.py
Expand Up @@ -5,6 +5,7 @@
"ProximityForest",
"ProximityStump",
"KNeighborsTimeSeriesClassifier",
"KNeighborsTimeSeriesClassifierPyts",
"ShapeDTW",
]

Expand All @@ -18,3 +19,6 @@
from sktime.classification.distance_based._time_series_neighbors import (
KNeighborsTimeSeriesClassifier,
)
from sktime.classification.distance_based._time_series_neighbors_pyts import (
KNeighborsTimeSeriesClassifierPyts,
)
149 changes: 149 additions & 0 deletions sktime/classification/distance_based/_time_series_neighbors_pyts.py
@@ -0,0 +1,149 @@
"""K-nearest neighbors time series classifier, from pyts."""
# copyright: sktime developers, BSD-3-Clause License (see LICENSE file)

__author__ = ["fkiraly"]
__all__ = ["KNeighborsTimeSeriesClassifierPyts"]

from sktime.base.adapters._pyts import _PytsAdapter
from sktime.classification.base import BaseClassifier


class KNeighborsTimeSeriesClassifierPyts(_PytsAdapter, BaseClassifier):
    """K-nearest neighbors time series classifier, from ``pyts``.

    Direct interface to ``pyts.classification.KNeighborsClassifier``,
    author of the interfaced class is ``johannfaouzi``.

    Parameters
    ----------
    n_neighbors : int, optional (default = 1)
        Number of neighbors to use.
    weights : str or callable, optional (default = 'uniform')
        Weight function used in prediction. Possible values:

        - 'uniform' : uniform weights. All points in each neighborhood
          are weighted equally.
        - 'distance' : weight points by the inverse of their distance.
          In this case, closer neighbors of a query point will have a
          greater influence than neighbors which are further away.
        - [callable] : a user-defined function which accepts an
          array of distances, and returns an array of the same shape
          containing the weights.

    algorithm : {'auto', 'ball_tree', 'kd_tree', 'brute'}, optional (default = 'auto')
        Algorithm used to compute the nearest neighbors. Ignored if ``metric``
        is either 'dtw', 'dtw_sakoechiba', 'dtw_itakura', 'dtw_multiscale',
        'dtw_fast' or 'boss' ('brute' will be used).
        Note: fitting on sparse input will override the setting of
        this parameter, using brute force.
    leaf_size : int, optional (default = 30)
        Leaf size passed to BallTree or KDTree. This can affect the
        speed of the construction and query, as well as the memory
        required to store the tree. The optimal value depends on the
        nature of the problem.
    p : integer, optional (default = 2)
        Power parameter for the Minkowski metric. When p = 1, this is
        equivalent to using manhattan_distance (l1), and euclidean_distance
        (l2) for p = 2. For arbitrary p, minkowski_distance (l_p) is used.
    metric : string or DistanceMetric object (default = 'minkowski')
        The distance metric to use for the tree. The default metric is
        minkowski, and with p=2 is equivalent to the standard Euclidean
        metric. See the documentation of the DistanceMetric class from
        scikit-learn for a list of available metrics.
        For Dynamic Time Warping, the available metrics are 'dtw',
        'dtw_sakoechiba', 'dtw_itakura', 'dtw_multiscale' and 'dtw_fast'.
        For BOSS metric, one can use 'boss'.
    metric_params : dict, optional (default = None)
        Additional keyword arguments for the metric function.
    n_jobs : int, optional (default = 1)
        The number of parallel jobs to run for neighbors search.
        If ``n_jobs=-1``, then the number of jobs is set to the number of CPU
        cores. Doesn't affect :meth:`fit` method.

    Attributes
    ----------
    classes_ : array, shape = (n_classes,)
        An array of class labels known to the classifier.
    """

    _tags = {
        # packaging info
        # --------------
        "authors": "fkiraly",
        "python_dependencies": "pyts",
        # estimator type
        # --------------
        "capability:multivariate": False,
        "capability:unequal_length": False,
        "capability:missing_values": True,
        "capability:predict_proba": True,
        "classifier_type": "distance",
    }

    # Name of the attribute containing the pyts estimator.
    # Was "_pyts_rocket", a leftover from a ROCKET adapter; this class wraps
    # a k-nearest-neighbors estimator, so the attribute is named accordingly.
    _estimator_attr = "_pyts_knn"

    def _get_pyts_class(self):
        """Return the ``pyts`` class interfaced by this estimator.

        The import happens inside the method, so ``pyts`` is only imported
        when this method is called, not when this module is imported.

        Returns
        -------
        class
            ``pyts.classification.KNeighborsClassifier``
        """
        from pyts.classification import KNeighborsClassifier

        return KNeighborsClassifier

    def __init__(
        self,
        n_neighbors=1,
        weights="uniform",
        algorithm="auto",
        leaf_size=30,
        p=2,
        metric="minkowski",
        metric_params=None,
        n_jobs=1,
    ):
        # hyper-parameters are stored unmodified under their own names
        self.n_neighbors = n_neighbors
        self.weights = weights
        self.algorithm = algorithm
        self.leaf_size = leaf_size
        self.p = p
        self.metric = metric
        self.metric_params = metric_params
        self.n_jobs = n_jobs

        super().__init__()

    @classmethod
    def get_test_params(cls, parameter_set="default"):
        """Return testing parameter settings for the estimator.

        Parameters
        ----------
        parameter_set : str, default="default"
            Name of the set of test parameters to return, for use in tests. If no
            special parameters are defined for a value, will return `"default"` set.
            This estimator defines no special sets, the argument is not read.

        Returns
        -------
        params : dict or list of dict, default = {}
            Parameters to create testing instances of the class
            Each dict are parameters to construct an "interesting" test instance, i.e.,
            `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
            `create_test_instance` uses the first (or only) dictionary in `params`
        """
        # first set: all defaults
        params1 = {}
        # second set: non-default neighbor count, weighting and a DTW metric
        params2 = {
            "n_neighbors": 3,
            "weights": "distance",
            "metric": "dtw_fast",
        }
        return [params1, params2]

0 comments on commit b5027a4

Please sign in to comment.