Skip to content

Commit

Permalink
Replace cycle over segments with vectorized expression in `TSDataset._check_endings` (#264)
Browse files Browse the repository at this point in the history

* Replace cycle over segments with vectorized expression

* Update changelog

* Rename test names

* Replace np.isnan with pd.isna in _check_endings

* Remove performance sanity check
  • Loading branch information
Mr-Geekman committed Nov 8, 2021
1 parent 7d47c58 commit 14aab47
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Update mypy checking and fix issues with it ([#248](https://github.com/tinkoff-ai/etna-ts/pull/248))
- Add histogram method in outliers notebook ([#252](https://github.com/tinkoff-ai/etna-ts/pull/252))
- Joblib parameters for backtest and ensembles ([#253](https://github.com/tinkoff-ai/etna-ts/pull/253))
- Replace cycle over segments with vectorized expression in TSDataset._check_endings ([#264](https://github.com/tinkoff-ai/etna-ts/pull/264))

### Fixed
- Fixed broken links in docs command section ([#223](https://github.com/tinkoff-ai/etna-ts/pull/223))
Expand Down
5 changes: 2 additions & 3 deletions etna/datasets/tsdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,8 @@ def _merge_exog(self, df: pd.DataFrame) -> pd.DataFrame:
def _check_endings(self):
"""Check that all targets end at the same timestamp."""
max_index = self.df.index.max()
for segment in self.df.columns.get_level_values("segment"):
if np.isnan(self.df.loc[max_index, pd.IndexSlice[segment, "target"]]):
raise ValueError(f"All segments should end at the same timestamp")
if np.any(pd.isna(self.df.loc[max_index, pd.IndexSlice[:, "target"]])):
raise ValueError(f"All segments should end at the same timestamp")

def inverse_transform(self):
"""Apply inverse transform method of transforms to the data.
Expand Down
10 changes: 6 additions & 4 deletions tests/test_datasets/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def df_and_regressors() -> Tuple[pd.DataFrame, pd.DataFrame]:
return df, df_exog


def test_same_ending_error_raise():
def test_check_endings_error_raise():
"""Check that _check_endings method raises exception if some segments end with NaN."""
timestamp = pd.date_range("2021-01-01", "2021-02-01")
df1 = pd.DataFrame({"timestamp": timestamp, "target": 11, "segment": "1"})
df2 = pd.DataFrame({"timestamp": timestamp[:-5], "target": 12, "segment": "2"})
Expand All @@ -61,17 +62,18 @@ def test_same_ending_error_raise():
ts = TSDataset(df=df, freq="D")

with pytest.raises(ValueError):
ts.fit_transform([])
ts._check_endings()


def test_same_ending_error_pass():
def test_check_endings_error_pass():
"""Check that _check_endings method passes if there are no NaNs at the end of all segments."""
timestamp = pd.date_range("2021-01-01", "2021-02-01")
df1 = pd.DataFrame({"timestamp": timestamp, "target": 11, "segment": "1"})
df2 = pd.DataFrame({"timestamp": timestamp, "target": 12, "segment": "2"})
df = pd.concat([df1, df2], ignore_index=True)
df = TSDataset.to_dataset(df)
ts = TSDataset(df=df, freq="D")
ts.fit_transform([])
ts._check_endings()


def test_categorical_after_call_to_pandas():
Expand Down

0 comments on commit 14aab47

Please sign in to comment.