Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MNT] remove private imports from sklearn - set_random_state #4672

Merged
merged 4 commits into from Jun 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions sktime/base/_base.py
Expand Up @@ -64,9 +64,9 @@ class name: BaseEstimator
from skbase.base import BaseObject as _BaseObject
from sklearn import clone
from sklearn.base import BaseEstimator as _BaseEstimator
from sklearn.ensemble._base import _set_random_states

from sktime.exceptions import NotFittedError
from sktime.utils.random_state import set_random_state


class BaseObject(_BaseObject):
Expand Down Expand Up @@ -508,6 +508,6 @@ def _clone_estimator(base_estimator, random_state=None):
estimator = clone(base_estimator)

if random_state is not None:
_set_random_states(estimator, random_state)
set_random_state(estimator, random_state)

return estimator
Expand Up @@ -269,7 +269,7 @@ def decide_prediction_safety(self, X, X_probabilities, state_info):

def _fit_estimator(self, X, y, i):
rs = 255 if self.random_state == 0 else self.random_state
rs = None if self.random_state is None else rs * 37 * (i + 1)
rs = None if self.random_state is None else rs * 37 * (i + 1) % 2**31
rng = check_random_state(rs)

estimator = _clone_estimator(
Expand Down
4 changes: 2 additions & 2 deletions sktime/classification/early_classification/_teaser.py
Expand Up @@ -352,7 +352,7 @@ def _get_next_idx(self, series_length):

def _fit_estimator(self, X, y, i):
rs = 255 if self.random_state == 0 else self.random_state
rs = None if self.random_state is None else rs * 37 * (i + 1)
rs = None if self.random_state is None else rs * 37 * (i + 1) % 2**31
rng = check_random_state(rs)

default = (
Expand Down Expand Up @@ -430,7 +430,7 @@ def _fit_estimator(self, X, y, i):

def _predict_proba_for_estimator(self, X, i):
rs = 255 if self.random_state == 0 else self.random_state
rs = None if self.random_state is None else rs * 37 * (i + 1)
rs = None if self.random_state is None else rs * 37 * (i + 1) % 2**31
rng = check_random_state(rs)

probas = self._estimators[i].predict_proba(
Expand Down
6 changes: 4 additions & 2 deletions sktime/classification/kernel_based/_arsenal.py
Expand Up @@ -229,7 +229,8 @@ def _fit(self, X, y):
if self.random_state is None
else (255 if self.random_state == 0 else self.random_state)
* 37
* (i + 1),
* (i + 1)
% 2**31,
),
X,
y,
Expand All @@ -253,7 +254,8 @@ def _fit(self, X, y):
if self.random_state is None
else (255 if self.random_state == 0 else self.random_state)
* 37
* (i + 1),
* (i + 1)
% 2**31,
),
X,
y,
Expand Down
2 changes: 1 addition & 1 deletion sktime/forecasting/compose/_bagging.py
Expand Up @@ -11,7 +11,6 @@
import pandas as pd
from sklearn import clone
from sklearn.utils import check_random_state
from sklearn.utils._testing import set_random_state

from sktime.datatypes._utilities import update_data
from sktime.forecasting.base import BaseForecaster
Expand All @@ -22,6 +21,7 @@
STLBootstrapTransformer,
)
from sktime.utils.estimators import MockForecaster
from sktime.utils.random_state import set_random_state


class BaggingForecaster(BaseForecaster):
Expand Down
4 changes: 2 additions & 2 deletions sktime/series_as_features/base/estimators/_ensemble.py
Expand Up @@ -14,7 +14,6 @@
from joblib import Parallel, delayed
from numpy import float64 as DOUBLE
from sklearn.base import clone
from sklearn.ensemble._base import _set_random_states
from sklearn.ensemble._forest import (
MAX_INT,
BaseForest,
Expand All @@ -25,6 +24,7 @@
from sklearn.utils import check_array, check_random_state, compute_sample_weight

from sktime.transformations.panel.summarize import RandomIntervalFeatureExtractor
from sktime.utils.random_state import set_random_state


def _parallel_build_trees(
Expand Down Expand Up @@ -112,7 +112,7 @@ def _make_estimator(self, append=True, random_state=None):
estimator.set_params(**{p: getattr(self, p) for p in self.estimator_params})

if random_state is not None:
_set_random_states(estimator, random_state)
set_random_state(estimator, random_state)

if append:
self.estimators_.append(estimator)
Expand Down
2 changes: 1 addition & 1 deletion sktime/tests/test_all_estimators.py
Expand Up @@ -18,7 +18,6 @@
import numpy as np
import pandas as pd
import pytest
from sklearn.utils._testing import set_random_state
from sklearn.utils.estimator_checks import (
check_get_params_invariance as _check_get_params_invariance,
)
Expand Down Expand Up @@ -55,6 +54,7 @@
_list_required_methods,
)
from sktime.utils._testing.scenarios_getter import retrieve_scenarios
from sktime.utils.random_state import set_random_state
from sktime.utils.sampling import random_partition
from sktime.utils.validation._dependencies import (
_check_dl_dependencies,
Expand Down
37 changes: 37 additions & 0 deletions sktime/utils/random_state.py
@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
"""Utilities for handling the random_state variable."""
# copied from scikit-learn to avoid dependency on sklearn private methods

import numpy as np
from sklearn.utils import check_random_state


def set_random_state(estimator, random_state=0):
    """Seed every ``random_state`` parameter of an estimator.

    Collects each parameter reported by ``estimator.get_params(deep=True)``
    whose name is ``random_state`` or ends in ``__random_state`` and assigns
    it an integer drawn from the generator derived from ``random_state``.

    Parameters
    ----------
    estimator : estimator supporting get_params, set_params
        Estimator whose randomness is controlled by random_state parameters.
        It is updated through ``estimator.set_params``.

    random_state : int, RandomState instance or None, default=0
        Seed or generator passed to ``check_random_state``; the resulting
        generator produces the integers written to the estimator.
        Pass an int for reproducible output across multiple function calls.

    Notes
    -----
    Only parameters exposed by ``estimator.get_params()`` are touched, so
    other sources of randomness inside the estimator may remain unseeded.
    """
    rng = check_random_state(random_state)
    upper_bound = np.iinfo(np.int32).max

    # Visit names in sorted order so that each parameter receives the same
    # draw from the generator regardless of the order get_params returns them.
    seeds = {
        name: rng.randint(upper_bound)
        for name in sorted(estimator.get_params(deep=True))
        if name == "random_state" or name.endswith("__random_state")
    }

    # Leave the estimator untouched when it exposes no random_state parameter.
    if not seeds:
        return
    estimator.set_params(**seeds)