-
Notifications
You must be signed in to change notification settings - Fork 81
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
GaleShapley feature selection (#284)
* Gale-Shapley init * MVP Gale-Shapley * Fix style * Fix greater_is_better * Add tests * Upd classes * Fix bug with endless while loop * Delete test * Upd CHANGELOG * Upd lock
- Loading branch information
1 parent
9cd56ab
commit a4df644
Showing
7 changed files
with
1,036 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from abc import ABC | ||
from typing import List | ||
|
||
import pandas as pd | ||
|
||
from etna.transforms import Transform | ||
|
||
|
||
class BaseFeatureSelectionTransform(Transform, ABC): | ||
"""Base class for feature selection transforms.""" | ||
|
||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self.selected_regressors = [] | ||
|
||
@staticmethod | ||
def _get_regressors(df: pd.DataFrame) -> List[str]: | ||
"""Get list of regressors in the dataframe.""" | ||
result = set() | ||
for column in df.columns.get_level_values("feature"): | ||
if column.startswith("regressor_"): | ||
result.add(column) | ||
return sorted(list(result)) | ||
|
||
def transform(self, df: pd.DataFrame) -> pd.DataFrame: | ||
"""Select top_k regressors. | ||
Parameters | ||
---------- | ||
df: | ||
dataframe with all segments data | ||
Returns | ||
------- | ||
result: pd.DataFrame | ||
Dataframe with with only selected regressors | ||
""" | ||
result = df.copy() | ||
selected_columns = sorted( | ||
[ | ||
column | ||
for column in df.columns.get_level_values("feature").unique() | ||
if not column.startswith("regressor_") or column in self.selected_regressors | ||
] | ||
) | ||
result = result.loc[:, pd.IndexSlice[:, selected_columns]] | ||
return result |
Oops, something went wrong.