Properly process random_state when fitting Time Series Forest ensemble in parallel (#819)

* Properly process random_state when fitting ensemble in parallel

* Fix test for a new random_state initialization

* Suggested changes

Co-authored-by: Oleksii Kachaiev <okachaiev@riotgames.com>
Co-authored-by: Markus Löning <markus.loning@gmail.com>
Co-authored-by: mloning <markus.loning.17@ucl.ac.uk>
Co-authored-by: Matthew Middlehurst <pfm15hbu@uea.ac.uk>
5 people committed Jun 18, 2021
1 parent ad82a27 commit b64c3cb
Showing 3 changed files with 52 additions and 38 deletions.
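
In short: the previous implementation cloned the base estimator inside each parallel worker and set the forest-level random_state on every clone, so a fixed seed produced identical ensemble members; the fix clones and seeds each member from one shared RNG in the main process before dispatching work to joblib. A minimal sketch of the before/after seeding behaviour (the DecisionTreeClassifier base and the literal seed 42 are illustrative assumptions, not part of the diff):

from sklearn.base import clone
from sklearn.ensemble._base import _set_random_states
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils.validation import check_random_state

base = DecisionTreeClassifier()

# Old pattern (simplified): the same forest-level seed reached every member,
# so a fixed random_state configured all of them identically.
old_members = [clone(base).set_params(random_state=42) for _ in range(3)]
print({est.random_state for est in old_members})   # {42}

# New pattern: one RNG derived from the forest-level seed hands each clone a
# distinct but reproducible seed before the clone is shipped to a worker.
rng = check_random_state(42)
new_members = []
for _ in range(3):
    est = clone(base)
    _set_random_states(est, rng)
    new_members.append(est)
print({est.random_state for est in new_members})   # three distinct integers
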
20 changes: 16 additions & 4 deletions sktime/base/_base.py
@@ -7,22 +7,25 @@

import inspect

from sklearn import clone
from sklearn.base import BaseEstimator as _BaseEstimator
from sklearn.ensemble._base import _set_random_states

from sktime.exceptions import NotFittedError


class BaseEstimator(_BaseEstimator):
"""Base class for defining estimators in sktime. Extends scikit-learn's
BaseEstimator.
"""Base class for defining estimators in sktime.
Extends scikit-learn's BaseEstimator.
"""

def __init__(self):
self._is_fitted = False

@property
def is_fitted(self):
"""Has `fit` been called?"""
"""Whether `fit` has been called."""
return self._is_fitted

def check_is_fitted(self):
@@ -41,7 +44,7 @@ def check_is_fitted(self):

@classmethod
def _all_tags(cls):
"""Get tags from estimator class and all its parent classes"""
"""Get tags from estimator class and all its parent classes."""
# We here create a separate estimator tag interface in addition to the one in
# scikit-learn to make sure we do not interfere with scikit-learn's one
# when we inherit from scikit-learn classes. We also make estimator tags a
@@ -59,3 +62,12 @@ def _all_tags(cls):
collected_tags.update(more_tags)

return collected_tags


def _clone_estimator(base_estimator, random_state=None):
estimator = clone(base_estimator)

if random_state is not None:
_set_random_states(estimator, random_state)

return estimator
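
A usage sketch for the new _clone_estimator helper (the DecisionTreeClassifier base and the seed 0 are assumptions for illustration): with random_state=None it behaves like a plain sklearn clone, while passing an int or RandomState also seeds the clone, and any nested components exposing a random_state parameter, via _set_random_states.

from sklearn.tree import DecisionTreeClassifier
from sklearn.utils.validation import check_random_state

from sktime.base._base import _clone_estimator

base = DecisionTreeClassifier()

# Without a random_state the helper is just a parameter-preserving clone.
plain = _clone_estimator(base)

# With a shared RNG, every call draws a fresh seed, so repeated clones differ
# from each other while the whole sequence stays reproducible end to end.
rng = check_random_state(0)
seeded = [_clone_estimator(base, rng) for _ in range(5)]
print([est.random_state for est in seeded])
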
10 changes: 5 additions & 5 deletions sktime/classification/interval_based/tests/test_tsf.py
@@ -10,20 +10,20 @@
[
[1.0, 0.0],
[1.0, 0.0],
[0.9, 0.1],
[0.95, 0.05],
[1.0, 0.0],
[0.0, 1.0],
[0.95, 0.05],
[0.0, 1.0],
[0.9, 0.1],
[0.95, 0.05],
[1.0, 0.0],
[0.15, 0.85],
[0.8, 0.2],
[0.9, 0.1],
[1.0, 0.0],
[1.0, 0.0],
[0.25, 0.75],
[0.2, 0.8],
[1.0, 0.0],
[0.95, 0.05],
[0.9, 0.1],
[1.0, 0.0],
[1.0, 0.0],
[0.0, 1.0],
60 changes: 31 additions & 29 deletions sktime/series_as_features/base/estimators/interval_based/_tsf.py
@@ -1,7 +1,14 @@
# -*- coding: utf-8 -*-
"""Time Series Forest (TSF) Classifier."""

__author__ = ["Tony Bagnall", "kkoziara", "luiszugasti", "kanand77", "Markus Löning"]
__author__ = [
"Tony Bagnall",
"kkoziara",
"luiszugasti",
"kanand77",
"Markus Löning",
"Oleksii Kachaiev",
]
__all__ = [
"BaseTimeSeriesForest",
"_transform",
@@ -14,16 +21,17 @@
import numpy as np
from joblib import Parallel
from joblib import delayed
from sklearn.base import clone
from sklearn.utils.multiclass import class_distribution
from sklearn.utils.validation import check_random_state

from sktime.base._base import _clone_estimator
from sktime.utils.slope_and_trend import _slope
from sktime.utils.validation import check_n_jobs
from sktime.utils.validation.panel import check_X_y


class BaseTimeSeriesForest:
"""Base Time series forest classifier."""
"""Base time series forest classifier."""

# Capability tags
capabilities = {
@@ -66,17 +74,15 @@ def fit(self, X, y):
Parameters
----------
X : array-like or sparse matrix of shape = [n_instances,
series_length] or shape = [n_instances,n_columns]
The training input samples. If a Pandas data frame is passed it
must have a single column (i.e. univariate
classification. TSF has no bespoke method for multivariate
classification as yet.
y : array-like, shape = [n_instances] The class labels.
Xt: np.ndarray or pd.DataFrame
Panel training data.
y : np.ndarray
The class labels.
Returns
-------
self : object
An fitted instance of the classifier
"""
X, y = check_X_y(
X,
@@ -87,6 +93,8 @@
X = X.squeeze(1)
n_instances, self.series_length = X.shape

n_jobs = check_n_jobs(self.n_jobs)

rng = check_random_state(self.random_state)

self.n_classes = np.unique(y).shape[0]
@@ -103,13 +111,9 @@
for _ in range(self.n_estimators)
]

self.estimators_ = Parallel(n_jobs=self.n_jobs)(
self.estimators_ = Parallel(n_jobs=n_jobs)(
delayed(_fit_estimator)(
X,
y,
self.base_estimator,
self.intervals_[i],
self.random_state,
_clone_estimator(self.base_estimator, rng), X, y, self.intervals_[i]
)
for i in range(self.n_estimators)
)
@@ -119,16 +123,21 @@


def _transform(X, intervals):
"""Compute the mean, std_dev and slope for given intervals of input data X.
"""Transform X for given intervals.
Compute the mean, standard deviation and slope for given intervals of input data X.
Parameters
----------
X (Array-like, int or float): Time series data X
intervals (Array-like, int or float): Time range intervals for series X
Xt: np.ndarray or pd.DataFrame
Panel data to transform.
intervals : np.ndarray
Intervals containing start and end values.
Returns
-------
int32 Array: transformed_x containing mean, std_deviation and slope
Xt: np.ndarray or pd.DataFrame
Transformed X, containing the mean, std and slope for each interval
"""
n_instances, _ = X.shape
n_intervals, _ = intervals.shape
@@ -157,14 +166,7 @@ def _get_intervals(n_intervals, min_interval, series_length, rng):
return intervals


def _fit_estimator(X, y, base_estimator, intervals, random_state=None):
"""Fit an estimator.
- a clone of base_estimator - on input data (X, y)
transformed using the randomly generated intervals.
"""
estimator = clone(base_estimator)
estimator.set_params(random_state=random_state)

def _fit_estimator(estimator, X, y, intervals):
"""Fit an estimator on input data (X, y)."""
transformed_x = _transform(X, intervals)
return estimator.fit(transformed_x, y)
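
Putting the pieces together, the refactored fit resolves all randomness in the main process and sends ready-seeded clones to the joblib workers, which only call fit. A standalone sketch of that dispatch pattern (the toy data, the DecisionTreeClassifier base, and the simplified _fit_one helper, which skips the interval transform, are assumptions, not the library's code):

import numpy as np
from joblib import Parallel, delayed
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils.validation import check_random_state

from sktime.base._base import _clone_estimator


def _fit_one(estimator, X, y):
    # Workers receive an already-seeded clone and only have to fit it.
    return estimator.fit(X, y)


X = np.random.RandomState(0).rand(20, 8)
y = np.array([0, 1] * 10)

base = DecisionTreeClassifier()
rng = check_random_state(42)  # forest-level random_state, resolved once

# _clone_estimator(base, rng) runs in the main process while the generator is
# consumed, so each member is seeded from the shared RNG before dispatch.
estimators_ = Parallel(n_jobs=2)(
    delayed(_fit_one)(_clone_estimator(base, rng), X, y) for _ in range(4)
)
print([est.random_state for est in estimators_])
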
