# Patch summary (free-form header; everything before the first "diff --git" line is commentary).
#
# Files touched by this patch:
#   1. CHANGELOG.md
#      Replaces an empty bullet under "### Fixed" with an entry for PR #1110:
#      a fix in GaleShapleyFeatureSelectionTransform for a wrong number of remaining features.
#   2. etna/transforms/feature_selection/gale_shapley.py (hunk inside `fit`)
#      The branch that calls `self._process_last_step(...)` previously ran whenever
#      `step == gale_shapley_steps_number - 1`; it now additionally requires
#      `last_step_features_number != 0`.
#      NOTE(review): the alternative branch is outside this hunk; presumably a last step with
#      zero features to pick is then handled like a regular step - confirm in the full file.
#   3. tests/test_transforms/test_feature_selection/test_gale_shapley_transform.py
#      - New fixture `ts_with_exog_galeshapley`: seeds NumPy with `random_seed`, builds two segments
#        ("segment_1", "segment_2") of 30 timestamps starting 2020-01-15 with uniform random targets
#        (ranges 10..20 and -15..5), and attaches exogenous columns `weekday`, `monthday`, `month`
#        and `year` derived from the timestamp; the dataset is created with freq="D".
#      - New test `test_right_number_features_with_integer_division`: sets `top_k` to the number of
#        segments, fits and applies the transform with `StatisticsRelevanceTable()`, and asserts that
#        the number of unique values in the "feature" column level equals `top_k + 1`.
#        NOTE(review): the extra 1 is presumably the `target` column - confirm.
#
# NOTE(review): line breaks in the patch body below appear to have been collapsed, so the hunk
# line counts in the "@@" headers cannot be checked against the body as shown; restore the
# original line structure before trying to apply it.
diff --git a/CHANGELOG.md b/CHANGELOG.md index 2db5e1b52..854e628fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,7 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - ### Fixed -- +- Fix bug in `GaleShapleyFeatureSelectionTransform` with wrong number of remaining features ([#1110](https://github.com/tinkoff-ai/etna/pull/1110)) - ## [1.15.0] - 2023-01-31 diff --git a/etna/transforms/feature_selection/gale_shapley.py b/etna/transforms/feature_selection/gale_shapley.py index 61b3b003c..e0bc491d6 100644 --- a/etna/transforms/feature_selection/gale_shapley.py +++ b/etna/transforms/feature_selection/gale_shapley.py @@ -371,7 +371,7 @@ def fit(self, df: pd.DataFrame) -> "GaleShapleyFeatureSelectionTransform": segment_features_ranking=segment_features_ranking, feature_segments_ranking=feature_segments_ranking, ) - if step == gale_shapley_steps_number - 1: + if step == gale_shapley_steps_number - 1 and last_step_features_number != 0: selected_features = self._process_last_step( matches=matches, relevance_table=relevance_table, diff --git a/tests/test_transforms/test_feature_selection/test_gale_shapley_transform.py b/tests/test_transforms/test_feature_selection/test_gale_shapley_transform.py index ba92e786d..891b77f8f 100644 --- a/tests/test_transforms/test_feature_selection/test_gale_shapley_transform.py +++ b/tests/test_transforms/test_feature_selection/test_gale_shapley_transform.py @@ -19,6 +19,32 @@ from tests.test_transforms.utils import assert_transformation_equals_loaded_original +@pytest.fixture +def ts_with_exog_galeshapley(random_seed) -> TSDataset: + np.random.seed(random_seed) + + periods = 30 + df_1 = pd.DataFrame({"timestamp": pd.date_range("2020-01-15", periods=periods)}) + df_1["segment"] = "segment_1" + df_1["target"] = np.random.uniform(10, 20, size=periods) + + df_2 = pd.DataFrame({"timestamp": pd.date_range("2020-01-15", periods=periods)}) + df_2["segment"] = "segment_2" + df_2["target"] = 
np.random.uniform(-15, 5, size=periods) + + df = pd.concat([df_1, df_2]).reset_index(drop=True) + df = TSDataset.to_dataset(df) + tsds = TSDataset(df, freq="D") + df = tsds.to_pandas(flatten=True) + df_exog = df.copy().drop(columns=["target"]) + df_exog["weekday"] = df_exog["timestamp"].dt.weekday + df_exog["monthday"] = df_exog["timestamp"].dt.day + df_exog["month"] = df_exog["timestamp"].dt.month + df_exog["year"] = df_exog["timestamp"].dt.year + ts = TSDataset(df=TSDataset.to_dataset(df), df_exog=TSDataset.to_dataset(df_exog), freq="D") + return ts + + @pytest.fixture def ts_with_large_regressors_number(random_seed) -> TSDataset: df = generate_periodic_df(periods=100, start_time="2020-01-01", n_segments=3, period=7, scale=10) @@ -622,3 +648,14 @@ def test_work_with_non_regressors(ts_with_exog): ) def test_save_load(transform, ts_with_large_regressors_number): assert_transformation_equals_loaded_original(transform=transform, ts=ts_with_large_regressors_number) + + +def test_right_number_features_with_integer_division(ts_with_exog_galeshapley): + top_k = len(ts_with_exog_galeshapley.segments) + transform = GaleShapleyFeatureSelectionTransform(relevance_table=StatisticsRelevanceTable(), top_k=top_k) + + transform.fit(ts_with_exog_galeshapley.to_pandas()) + df = transform.transform(ts_with_exog_galeshapley.to_pandas()) + + remaining_columns = df.columns.get_level_values("feature").unique().tolist() + assert len(remaining_columns) == top_k + 1