Skip to content

Commit

Permalink
[ENH] sklearn facing coercion utility for pd.DataFrame, to str
Browse files Browse the repository at this point in the history
…columns (#5550)

This PR adds a `sklearn` facing coercion utility for `pd.DataFrame`, to
`str` columns, and tests for it.

As this is a second small module with `sklearn` themed functionality,
the current module and its tests are moved one folder down together with
the new content.
  • Loading branch information
fkiraly committed Nov 24, 2023
1 parent ea98ce0 commit 7c3cc4b
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 0 deletions.
21 changes: 21 additions & 0 deletions sktime/utils/sklearn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Sklearn related utility functionality."""

from sktime.utils.sklearn._adapt_df import prep_skl_df
from sktime.utils.sklearn._scitype import (
is_sklearn_classifier,
is_sklearn_clusterer,
is_sklearn_estimator,
is_sklearn_regressor,
is_sklearn_transformer,
sklearn_scitype,
)

__all__ = [
"prep_skl_df",
"is_sklearn_estimator",
"is_sklearn_transformer",
"is_sklearn_classifier",
"is_sklearn_regressor",
"is_sklearn_clusterer",
"sklearn_scitype",
]
31 changes: 31 additions & 0 deletions sktime/utils/sklearn/_adapt_df.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Utility functions for adapting to sklearn."""
# copyright: sktime developers, BSD-3-Clause License (see LICENSE file)

import numpy as np


def prep_skl_df(df, copy_df=False):
"""Make df compatible with sklearn input expectations.
Changes:
turns column index into a list of strings
Parameters
----------
df : pd.DataFrame
list of indices to sample from
copy_df : bool, default=False
whether to mutate df or return a copy
if False, index of df is mutated
if True, original df is not mutated. If index is not a list of strings,
a copy is made and the copy is mutated. Otherwise, the original df is returned.
"""
cols = df.columns
str_cols = cols.astype(str)

if not np.all(str_cols == cols):
if copy_df:
df = df.copy()
df.columns = str_cols

return df
File renamed without changes.
1 change: 1 addition & 0 deletions sktime/utils/sklearn/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Tests for sklearn related utility functionality."""
33 changes: 33 additions & 0 deletions sktime/utils/sklearn/tests/test_sklearn_df_adapt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Tests for sklearn dataframe coercion."""

__author__ = ["fkiraly"]

import numpy as np
import pandas as pd
import pytest

from sktime.utils.sklearn._adapt_df import prep_skl_df


@pytest.mark.parametrize("copy_df", [True, False])
def test_prep_skl_df_coercion(copy_df):
"""Test that prep_skl_df behaves correctly on the coercion case."""
mixed_example = pd.DataFrame({0: [1, 2, 3], "b": [1, 2, 3]})

res = prep_skl_df(mixed_example, copy_df=copy_df)

assert np.all(res.columns == ["0", "b"])

if not copy_df:
assert res is mixed_example


@pytest.mark.parametrize("copy_df", [True, False])
def test_prep_skl_df_non_coercion(copy_df):
"""Test that prep_skl_df behaves correctly on the non-coercion case."""
mixed_example = pd.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]})

res = prep_skl_df(mixed_example, copy_df=copy_df)

assert np.all(res.columns == ["a", "b"])
assert res is mixed_example
File renamed without changes.

0 comments on commit 7c3cc4b

Please sign in to comment.