Skip to content

Commit

Permalink
GaleShapley feature selection (#284)
Browse files Browse the repository at this point in the history
* Gale-Shapley init

* MVP Gale-Shapley

* Fix style

* Fix greater_is_better

* Add tests

* Upd classes

* Fix bug with endless while loop

* Delete test

* Upd CHANGELOG

* Upd lock
  • Loading branch information
julia-shenshina committed Nov 17, 2021
1 parent 9cd56ab commit a4df644
Show file tree
Hide file tree
Showing 7 changed files with 1,036 additions and 39 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
### Added
- RelevanceTable returns rank ([#268](https://github.com/tinkoff-ai/etna/pull/268/))
- GaleShapleyFeatureSelectionTransform ([#284](https://github.com/tinkoff-ai/etna/pull/284))

### Changed
- Rename confidence interval to prediction interval, start working with quantiles instead of interval_width ([#285](https://github.com/tinkoff-ai/etna-ts/pull/285))
Expand Down
1 change: 1 addition & 0 deletions etna/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from etna.transforms.detrend import LinearTrendTransform
from etna.transforms.detrend import TheilSenTrendTransform
from etna.transforms.feature_importance import TreeFeatureSelectionTransform
from etna.transforms.gale_shapley import GaleShapleyFeatureSelectionTransform
from etna.transforms.imputation import TimeSeriesImputerTransform
from etna.transforms.lags import LagTransform
from etna.transforms.log import LogTransform
Expand Down
42 changes: 3 additions & 39 deletions etna/transforms/feature_importance.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import warnings
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

Expand All @@ -15,7 +14,7 @@
from sklearn.tree import ExtraTreeRegressor

from etna.datasets import TSDataset
from etna.transforms.base import Transform
from etna.transforms.feature_selection import BaseFeatureSelectionTransform

TreeBasedRegressor = Union[
DecisionTreeRegressor,
Expand All @@ -27,7 +26,7 @@
]


class TreeFeatureSelectionTransform(Transform):
class TreeFeatureSelectionTransform(BaseFeatureSelectionTransform):
"""Transform that selects regressors according to tree-based models feature importance."""

def __init__(self, model: TreeBasedRegressor, top_k: int):
Expand All @@ -44,19 +43,9 @@ def __init__(self, model: TreeBasedRegressor, top_k: int):
"""
if not isinstance(top_k, int) or top_k < 0:
raise ValueError("Parameter top_k should be positive integer")

super().__init__()
self.model = model
self.top_k = top_k
self.selected_regressors: Optional[List[str]] = None

@staticmethod
def _get_regressors(df: pd.DataFrame) -> List[str]:
"""Get list of regressors in the dataframe."""
result = set()
for column in df.columns.get_level_values("feature"):
if column.startswith("regressor_"):
result.add(column)
return sorted(list(result))

@staticmethod
def _get_train(df: pd.DataFrame) -> Tuple[np.array, np.array]:
Expand Down Expand Up @@ -105,28 +94,3 @@ def fit(self, df: pd.DataFrame) -> "TreeFeatureSelectionTransform":
weights = self._get_regressors_weights(df)
self.selected_regressors = self._select_top_k_regressors(weights, self.top_k)
return self

def transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Select top_k regressors.
Parameters
----------
df:
dataframe with all segments data
Returns
-------
result: pd.DataFrame
Dataframe with only selected regressors
"""
result = df.copy()
selected_columns = sorted(
[
column
for column in df.columns.get_level_values("feature").unique()
if not column.startswith("regressor_") or column in self.selected_regressors
]
)
result = result.loc[:, pd.IndexSlice[:, selected_columns]]
return result
47 changes: 47 additions & 0 deletions etna/transforms/feature_selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from abc import ABC
from typing import List

import pandas as pd

from etna.transforms import Transform


class BaseFeatureSelectionTransform(Transform, ABC):
    """Common base for transforms that keep only a chosen subset of regressors.

    ``selected_regressors`` holds the names of the regressors to keep;
    ``transform`` drops every other ``regressor_``-prefixed feature column.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent initializer and start with no selected regressors."""
        super().__init__(*args, **kwargs)
        # Empty until populated elsewhere, so ``transform`` keeps no regressors by default.
        self.selected_regressors: List[str] = []

    @staticmethod
    def _get_regressors(df: pd.DataFrame) -> List[str]:
        """Return the sorted unique "feature"-level column names that start with ``regressor_``."""
        feature_level = df.columns.get_level_values("feature")
        # A set collapses the per-segment repetition of each feature name.
        regressor_names = {name for name in feature_level if name.startswith("regressor_")}
        ordered = list(regressor_names)
        ordered.sort()
        return ordered

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Keep the non-regressor features and the selected regressors only.

        Parameters
        ----------
        df:
            dataframe with all segments data

        Returns
        -------
        result: pd.DataFrame
            Copy of the dataframe restricted to the non-regressor features
            and the regressors listed in ``selected_regressors``
        """
        kept_features = []
        for feature in df.columns.get_level_values("feature").unique():
            # Non-regressor columns always survive; regressors only when selected.
            if not feature.startswith("regressor_") or feature in self.selected_regressors:
                kept_features.append(feature)
        kept_features.sort()
        # Select the kept features across every segment (first column level).
        return df.copy().loc[:, pd.IndexSlice[:, kept_features]]
Loading

0 comments on commit a4df644

Please sign in to comment.