-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add FilterFeaturesTransform * Update changelog * Refactor transform method * Fix tests, add checking on values in columns * Fix bugs, change error message * Fix docstring mistakes * Update changelog * Update etna-ts -> etna in changelog * Reformat code
- Loading branch information
1 parent
a4df644
commit fa981ba
Showing
4 changed files
with
176 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from typing import Optional | ||
from typing import Sequence | ||
|
||
import pandas as pd | ||
|
||
from etna.transforms.base import Transform | ||
|
||
|
||
class FilterFeaturesTransform(Transform):
    """Filters features in each segment of the dataframe."""

    def __init__(self, include: Optional[Sequence[str]] = None, exclude: Optional[Sequence[str]] = None):
        """Create instance of FilterFeaturesTransform.

        Parameters
        ----------
        include:
            list of columns to pass through filter
        exclude:
            list of columns to not pass through

        Raises
        ------
        ValueError:
            if both options are set or none of them
        """
        # Exactly one of the two options must be provided.
        if (include is None) == (exclude is None):
            raise ValueError("There should be exactly one option set: include or exclude")
        # De-duplicate with dict.fromkeys instead of set(): it keeps the caller's order,
        # whereas the iteration order of a set of strings varies with hash randomization.
        self.include = None if include is None else list(dict.fromkeys(include))
        self.exclude = None if exclude is None else list(dict.fromkeys(exclude))

    def fit(self, df: pd.DataFrame) -> "FilterFeaturesTransform":
        """Fit method does nothing and is kept for compatibility.

        Parameters
        ----------
        df:
            dataframe with data.

        Returns
        -------
        result: FilterFeaturesTransform
        """
        return self

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Filter features according to include/exclude parameters.

        Parameters
        ----------
        df:
            dataframe with data to transform; its columns must have levels
            named "segment" and "feature".

        Returns
        -------
        result: pd.DataFrame
            transformed dataframe

        Raises
        ------
        ValueError:
            if some of the requested features are not present in the dataframe
        """
        features = set(df.columns.get_level_values("feature"))

        # Validate before copying so a bad configuration fails without duplicating the data.
        requested = self.include if self.include is not None else self.exclude
        missing = set(requested or ()) - features
        if missing:
            raise ValueError(f"Features {missing} are not present in the dataset.")

        result = df.copy()
        if self.include is not None:
            segments = sorted(set(df.columns.get_level_values("segment")))
            result = result.loc[:, pd.IndexSlice[segments, self.include]]
        # An empty exclude list means there is nothing to drop.
        if self.exclude:
            result = result.drop(columns=self.exclude, level="feature")
        return result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
|
||
from etna.datasets import TSDataset | ||
from etna.transforms import FilterFeaturesTransform | ||
|
||
|
||
@pytest.fixture
def ts_with_features() -> TSDataset:
    """Build a daily two-segment dataset with a constant target and two constant exogenous columns."""
    dates = pd.date_range("2020-01-01", periods=100, freq="D")

    target_frames = [
        pd.DataFrame({"timestamp": dates, "segment": "segment_1", "target": 1}),
        pd.DataFrame({"timestamp": dates, "segment": "segment_2", "target": 2}),
    ]
    exog_frames = [
        pd.DataFrame({"timestamp": dates, "segment": "segment_1", "exog_1": 1, "exog_2": 2}),
        pd.DataFrame({"timestamp": dates, "segment": "segment_2", "exog_1": 3, "exog_2": 4}),
    ]

    wide_target = TSDataset.to_dataset(pd.concat(target_frames, ignore_index=False))
    wide_exog = TSDataset.to_dataset(pd.concat(exog_frames, ignore_index=False))

    return TSDataset(df=wide_target, df_exog=wide_exog, freq="D")
|
||
|
||
def test_set_only_include():
    """Check that the transform can be constructed when only ``include`` is given."""
    FilterFeaturesTransform(include=["exog_1", "exog_2"])
|
||
|
||
def test_set_only_exclude():
    """Check that the transform can be constructed when only ``exclude`` is given."""
    FilterFeaturesTransform(exclude=["exog_1", "exog_2"])
|
||
|
||
def test_set_include_and_exclude():
    """Check that giving both ``include`` and ``exclude`` is rejected with ValueError."""
    expected_message = "There should be exactly one option set: include or exclude"
    with pytest.raises(ValueError, match=expected_message):
        FilterFeaturesTransform(include=["exog_1"], exclude=["exog_2"])
|
||
|
||
def test_set_none():
    """Check that giving neither ``include`` nor ``exclude`` is rejected with ValueError."""
    expected_message = "There should be exactly one option set: include or exclude"
    with pytest.raises(ValueError, match=expected_message):
        FilterFeaturesTransform()
|
||
|
||
@pytest.mark.parametrize("include", [[], ["target"], ["exog_1"], ["exog_1", "exog_2", "target"]])
def test_include_filter(ts_with_features, include):
    """Check that only the features listed in ``include`` remain and their values are unchanged."""
    df_before = ts_with_features.to_pandas()
    ts_with_features.fit_transform([FilterFeaturesTransform(include=include)])
    df_after = ts_with_features.to_pandas()

    remaining = set(df_after.columns.get_level_values("feature"))
    assert remaining == set(include)

    segments = ts_with_features.segments
    for feature in remaining:
        # Compare the same (segments, feature) slice before and after the transform.
        selector = pd.IndexSlice[segments, feature]
        assert np.all(df_after.loc[:, selector] == df_before.loc[:, selector])
|
||
|
||
@pytest.mark.parametrize(
    "exclude, expected_columns",
    [
        ([], ["target", "exog_1", "exog_2"]),
        (["target"], ["exog_1", "exog_2"]),
        (["exog_1", "exog_2"], ["target"]),
        (["target", "exog_1", "exog_2"], []),
    ],
)
def test_exclude_filter(ts_with_features, exclude, expected_columns):
    """Check that exactly the features listed in ``exclude`` are removed and the rest are unchanged."""
    df_before = ts_with_features.to_pandas()
    ts_with_features.fit_transform([FilterFeaturesTransform(exclude=exclude)])
    df_after = ts_with_features.to_pandas()

    remaining = set(df_after.columns.get_level_values("feature"))
    assert remaining == set(expected_columns)

    segments = ts_with_features.segments
    for feature in remaining:
        # Compare the same (segments, feature) slice before and after the transform.
        selector = pd.IndexSlice[segments, feature]
        assert np.all(df_after.loc[:, selector] == df_before.loc[:, selector])
|
||
|
||
def test_include_filter_wrong_column(ts_with_features):
    """Check that an unknown feature name in ``include`` raises ValueError when the transform is applied."""
    # Construct outside the ``raises`` context so only the fit_transform call is under test.
    pipeline = [FilterFeaturesTransform(include=["non-existent-column"])]
    with pytest.raises(ValueError, match="Features {.*} are not present in the dataset"):
        ts_with_features.fit_transform(pipeline)
|
||
|
||
def test_exclude_filter_wrong_column(ts_with_features):
    """Check that an unknown feature name in ``exclude`` raises ValueError when the transform is applied."""
    # Construct outside the ``raises`` context so only the fit_transform call is under test.
    pipeline = [FilterFeaturesTransform(exclude=["non-existent-column"])]
    with pytest.raises(ValueError, match="Features {.*} are not present in the dataset"):
        ts_with_features.fit_transform(pipeline)