Skip to content

Commit

Permalink
[ENH] sklearn 1.2.0 compatibility - remove private _check_weights
Browse files Browse the repository at this point in the history
… import in `KNeighborsTimeSeriesClassifier` and -`Regressor` (#3918)

This PR fixes a compatibility issue with `sklearn 1.2` and removes the private import `_check_weights` from `KNeighborsTimeSeriesClassifier` and `KNeighborsTimeSeriesRegressor`.

This can be done without deprecation or change in functionality, because `_check_weights` was just an erroneous leftover from an earlier version that used inheritance and not composition.

Right now, the `sklearn` classifier is wrapped as a component, so `_check_weights` (or its 1.2 equivalent) is called again inside the component - therefore, naive removal simply removes an unnecessary duplication.
  • Loading branch information
fkiraly committed Dec 19, 2022
1 parent 3ea2e35 commit 5d17b0b
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 4 deletions.
Expand Up @@ -24,7 +24,6 @@

import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neighbors._base import _check_weights

from sktime.classification.base import BaseClassifier
from sktime.datatypes import check_is_mtype
Expand Down Expand Up @@ -139,7 +138,7 @@ def __init__(
n_jobs=None,
):
self.n_neighbors = n_neighbors
self.weights = _check_weights(weights)
self.weights = weights
self.algorithm = algorithm
self.distance = distance
self.distance_params = distance_params
Expand Down
3 changes: 1 addition & 2 deletions sktime/regression/distance_based/_time_series_neighbors.py
Expand Up @@ -13,7 +13,6 @@
__all__ = ["KNeighborsTimeSeriesRegressor"]

from sklearn.neighbors import KNeighborsRegressor
from sklearn.neighbors._base import _check_weights

from sktime.distances import pairwise_distance
from sktime.regression.base import BaseRegressor
Expand Down Expand Up @@ -132,7 +131,7 @@ def __init__(
leaf_size=leaf_size,
n_jobs=n_jobs,
)
self.weights = _check_weights(weights)
self.weights = weights

super(KNeighborsTimeSeriesRegressor, self).__init__()

Expand Down

0 comments on commit 5d17b0b

Please sign in to comment.