Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Time series classifiers refactor/Shape_DTW #1554

Merged
merged 11 commits into from Oct 26, 2021
43 changes: 12 additions & 31 deletions sktime/classification/distance_based/_shape_dtw.py
Expand Up @@ -6,26 +6,22 @@

import numpy as np
import pandas as pd
from sktime.utils.validation.panel import check_X, check_X_y
from sktime.datatypes._panel._convert import from_nested_to_2d_array

# Tuning
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import KFold
from sklearn.model_selection import GridSearchCV, KFold

# Transforms
from sktime.transformations.panel.segment import SlidingWindowSegmenter
# Classifiers
from sktime.classification.base import BaseClassifier
from sktime.classification.distance_based import KNeighborsTimeSeriesClassifier
from sktime.datatypes._panel._convert import from_nested_to_2d_array
from sktime.transformations.panel.dictionary_based._paa import PAA
from sktime.transformations.panel.dwt import DWTTransformer
from sktime.transformations.panel.slope import SlopeTransformer
from sktime.transformations.panel.summarize._extract import (
DerivativeSlopeTransformer,
)
from sktime.transformations.panel.hog1d import HOG1DTransformer

# Classifiers
from sktime.classification.base import BaseClassifier
from sktime.classification.distance_based import KNeighborsTimeSeriesClassifier
# Transforms
from sktime.transformations.panel.segment import SlidingWindowSegmenter
from sktime.transformations.panel.slope import SlopeTransformer
from sktime.transformations.panel.summarize._extract import DerivativeSlopeTransformer

__author__ = ["Vincent Nicholson"]

Expand Down Expand Up @@ -114,15 +110,6 @@ class ShapeDTW(BaseClassifier):

"""

# Capability tags
capabilities = {
"multivariate": False,
"unequal_length": False,
"missing_values": False,
"train_estimate": False,
"contractable": False,
}

def __init__(
self,
n_neighbours=1,
Expand All @@ -138,7 +125,7 @@ def __init__(
self.metric_params = metric_params
super(ShapeDTW, self).__init__()

def fit(self, X, y):
def _fit(self, X, y):
"""Train the classifier.

Parameters
Expand All @@ -159,8 +146,6 @@ def fit(self, X, y):
+ "' instead."
)

X, y = check_X_y(X, y, enforce_univariate=False)

if self.metric_params is None:
self.metric_params = {}

Expand Down Expand Up @@ -261,7 +246,7 @@ def _preprocess(self, X):

return X

def predict_proba(self, X):
def _predict_proba(self, X):
"""Perform predictions on the testing data X.

This function returns the probabilities for each class.
Expand All @@ -275,15 +260,13 @@ def predict_proba(self, X):
output : numpy array of shape =
[n_instances, num_classes] of probabilities
"""
X = check_X(X, enforce_univariate=False)

# Transform the test data in the same way as the training data.
X = self._preprocess(X)

# Classify the test data
return self.knn.predict_proba(X)

def predict(self, X):
def _predict(self, X):
"""Find predictions for all cases in X.

Parameters
Expand All @@ -294,8 +277,6 @@ def predict(self, X):
-------
output : numpy array of shape = [n_instances]
"""
X = check_X(X, enforce_univariate=False)

# Transform the test data in the same way as the training data.
X = self._preprocess(X)

Expand Down