-
Notifications
You must be signed in to change notification settings - Fork 81
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add class * Add fix * Upd * Add relevance * Add repr * Fix fixture typing
- Loading branch information
1 parent
a51a363
commit 35db08a
Showing
6 changed files
with
68 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
from etna.analysis.feature_relevance.relevance import RelevanceTable | ||
from etna.analysis.feature_relevance.relevance import StatisticsRelevanceTable | ||
from etna.analysis.feature_relevance.relevance_table import get_statistics_relevance_table |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from abc import ABC | ||
from abc import abstractmethod | ||
|
||
import pandas as pd | ||
|
||
from etna.analysis.feature_relevance.relevance_table import get_statistics_relevance_table | ||
from etna.core.mixins import BaseMixin | ||
|
||
|
||
class RelevanceTable(ABC, BaseMixin):
    """Base interface for callables that compute a feature relevance table."""

    def __init__(self, greater_is_better: bool):
        """Remember how values of the relevance table should be interpreted.

        Parameters
        ----------
        greater_is_better:
            bool flag; if True, the biggest value in the relevance table
            corresponds to the most important exog feature
        """
        self.greater_is_better = greater_is_better

    @abstractmethod
    def __call__(self, df: pd.DataFrame, df_exog: pd.DataFrame, **kwargs) -> pd.DataFrame:
        """Compute the relevance table; concrete subclasses must override this.

        For every series in ``df`` the relevance of the corresponding series
        in ``df_exog`` is computed.

        Parameters
        ----------
        df:
            dataframe with series that will be used as target
        df_exog:
            dataframe with series to compute relevance for ``df``

        Returns
        -------
        relevance table: pd.DataFrame
            dataframe of shape n_segment x n_exog_series, where
            relevance_table[i][j] contains relevance of j-th ``df_exog``
            series to i-th ``df`` series
        """
|
||
|
||
class StatisticsRelevanceTable(RelevanceTable):
    """Relevance table built from tsfresh statistics via ``get_statistics_relevance_table``."""

    def __init__(self):
        """Initialise the base class with ``greater_is_better=False``."""
        super().__init__(greater_is_better=False)

    def __call__(self, df: pd.DataFrame, df_exog: pd.DataFrame, **kwargs) -> pd.DataFrame:
        """Delegate the computation to ``etna.analysis.get_statistics_relevance_table``.

        Parameters
        ----------
        df:
            dataframe passed through as the ``df`` argument
        df_exog:
            dataframe passed through as the ``df_exog`` argument
        **kwargs:
            accepted for interface compatibility; not forwarded to the helper

        Returns
        -------
        pd.DataFrame
            whatever ``get_statistics_relevance_table`` returns for the given inputs
        """
        return get_statistics_relevance_table(df=df, df_exog=df_exog)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from etna.analysis.feature_relevance import StatisticsRelevanceTable | ||
|
||
|
||
def test_statistics_relevance_table(simple_df_relevance):
    """Check the ``greater_is_better`` flag and the shape of the computed table."""
    relevance_table = StatisticsRelevanceTable()
    assert not relevance_table.greater_is_better

    # The fixture yields a (target dataframe, exogenous dataframe) pair.
    target_df, exog_df = simple_df_relevance
    computed = relevance_table(df=target_df, df_exog=exog_df)
    assert computed.shape == (2, 2)