diff --git a/CHANGELOG.md b/CHANGELOG.md index 717682764..16ef5a347 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Update mypy checking and fix issues with it ([#248](https://github.com/tinkoff-ai/etna-ts/pull/248)) - Add histogram method in outliers notebook ([#252](https://github.com/tinkoff-ai/etna-ts/pull/252)) - Joblib parameters for backtest and ensembles ([#253](https://github.com/tinkoff-ai/etna-ts/pull/253)) +- Replace cycle over segments with vectorized expression in TSDataset._check_endings ([#264](https://github.com/tinkoff-ai/etna-ts/pull/264)) ### Fixed - Fixed broken links in docs command section ([#223](https://github.com/tinkoff-ai/etna-ts/pull/223)) diff --git a/etna/datasets/tsdataset.py b/etna/datasets/tsdataset.py index ae56e84af..fa32362d3 100644 --- a/etna/datasets/tsdataset.py +++ b/etna/datasets/tsdataset.py @@ -249,9 +249,8 @@ def _merge_exog(self, df: pd.DataFrame) -> pd.DataFrame: def _check_endings(self): """Check that all targets ends at the same timestamp.""" max_index = self.df.index.max() - for segment in self.df.columns.get_level_values("segment"): - if np.isnan(self.df.loc[max_index, pd.IndexSlice[segment, "target"]]): - raise ValueError(f"All segments should end at the same timestamp") + if np.any(pd.isna(self.df.loc[max_index, pd.IndexSlice[:, "target"]])): + raise ValueError(f"All segments should end at the same timestamp") def inverse_transform(self): """Apply inverse transform method of transforms to the data. diff --git a/tests/test_datasets/test_dataset.py b/tests/test_datasets/test_dataset.py index ed5e6241d..97c6082b2 100644 --- a/tests/test_datasets/test_dataset.py +++ b/tests/test_datasets/test_dataset.py @@ -52,7 +52,8 @@ def df_and_regressors() -> Tuple[pd.DataFrame, pd.DataFrame]: return df, df_exog -def test_same_ending_error_raise(): +def test_check_endings_error_raise(): + """Check that _check_endings method raises exception if some segments end with nan.""" timestamp = pd.date_range("2021-01-01", "2021-02-01") df1 = pd.DataFrame({"timestamp": timestamp, "target": 11, "segment": "1"}) df2 = pd.DataFrame({"timestamp": timestamp[:-5], "target": 12, "segment": "2"}) @@ -61,17 +62,18 @@ def test_same_ending_error_raise(): ts = TSDataset(df=df, freq="D") with pytest.raises(ValueError): - ts.fit_transform([]) + ts._check_endings() -def test_same_ending_error_pass(): +def test_check_endings_error_pass(): + """Check that _check_endings method passes if there is no nans at the end of all segments.""" timestamp = pd.date_range("2021-01-01", "2021-02-01") df1 = pd.DataFrame({"timestamp": timestamp, "target": 11, "segment": "1"}) df2 = pd.DataFrame({"timestamp": timestamp, "target": 12, "segment": "2"}) df = pd.concat([df1, df2], ignore_index=True) df = TSDataset.to_dataset(df) ts = TSDataset(df=df, freq="D") - ts.fit_transform([]) + ts._check_endings() def test_categorical_after_call_to_pandas():