-
-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ENH] `sklearn` facing coercion utility for `pd.DataFrame`, to `str` columns (#5550)

This PR adds a `sklearn` facing coercion utility for `pd.DataFrame`, to `str` columns, and tests for it. As this is a second small module with `sklearn` themed functionality, the current module and its tests are moved one folder down together with the new content.
- Loading branch information
Showing 6 changed files with 86 additions and 0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
"""Sklearn related utility functionality.""" | ||
|
||
from sktime.utils.sklearn._adapt_df import prep_skl_df | ||
from sktime.utils.sklearn._scitype import ( | ||
is_sklearn_classifier, | ||
is_sklearn_clusterer, | ||
is_sklearn_estimator, | ||
is_sklearn_regressor, | ||
is_sklearn_transformer, | ||
sklearn_scitype, | ||
) | ||
|
||
# Names exported by this package: exactly the objects imported above from
# ``sktime.utils.sklearn._adapt_df`` and ``sktime.utils.sklearn._scitype``.
__all__ = [
    "prep_skl_df",
    "is_sklearn_estimator",
    "is_sklearn_transformer",
    "is_sklearn_classifier",
    "is_sklearn_regressor",
    "is_sklearn_clusterer",
    "sklearn_scitype",
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
"""Utility functions for adapting to sklearn.""" | ||
# copyright: sktime developers, BSD-3-Clause License (see LICENSE file) | ||
|
||
import numpy as np | ||
|
||
|
||
def prep_skl_df(df, copy_df=False):
    """Make df compatible with sklearn input expectations.

    Changes applied: the column index is replaced by its string-cast version,
    whenever at least one column label differs from its string form.

    Parameters
    ----------
    df : pd.DataFrame
        data frame whose column labels are to be coerced to strings
    copy_df : bool, default=False
        whether to work on a copy instead of mutating ``df``

        * if False, the columns of ``df`` are overwritten in place
        * if True, ``df`` is never mutated. If some column label differs from
          its string form, a copy is made and the columns of the copy are
          replaced. Otherwise, the original ``df`` is returned.

    Returns
    -------
    pd.DataFrame
        ``df`` itself, or a copy of it, with string-cast column labels
    """
    original_labels = df.columns
    labels_as_str = original_labels.astype(str)

    # every label already equals its string form: nothing to coerce
    if np.all(labels_as_str == original_labels):
        return df

    out = df.copy() if copy_df else df
    out.columns = labels_as_str
    return out
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Tests for sklearn related utility functionality.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
"""Tests for sklearn dataframe coercion.""" | ||
|
||
__author__ = ["fkiraly"] | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
|
||
from sktime.utils.sklearn._adapt_df import prep_skl_df | ||
|
||
|
||
@pytest.mark.parametrize("copy_df", [True, False])
def test_prep_skl_df_coercion(copy_df):
    """Test that prep_skl_df behaves correctly on the coercion case.

    The input has one non-string column label, so coercion must take place.
    Checks the coerced labels for both settings of ``copy_df``, and in addition:

    * ``copy_df=False``: the very same object is returned (mutated in place)
    * ``copy_df=True``: a different object is returned, input labels untouched
    """
    mixed_example = pd.DataFrame({0: [1, 2, 3], "b": [1, 2, 3]})

    res = prep_skl_df(mixed_example, copy_df=copy_df)

    assert np.all(res.columns == ["0", "b"])

    if copy_df:
        # the original must be left alone, so the result has to be a new object
        assert res is not mixed_example
        assert list(mixed_example.columns) == [0, "b"]
    else:
        assert res is mixed_example
|
||
|
||
@pytest.mark.parametrize("copy_df", [True, False])
def test_prep_skl_df_non_coercion(copy_df):
    """Test that prep_skl_df behaves correctly on the non-coercion case.

    All column labels are strings already, so the input object itself is
    expected back, for either value of ``copy_df``.
    """
    str_example = pd.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]})

    result = prep_skl_df(str_example, copy_df=copy_df)

    expected_cols = ["a", "b"]
    labels_match = result.columns == expected_cols
    assert np.all(labels_match)
    assert result is str_example
File renamed without changes.