diff --git a/CHANGELOG.md b/CHANGELOG.md index 80a5759ef..db776cfd8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,7 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - - - -- +- Fix `SegmentEncoderTransform` to work with subset of segments and raise error on new segments ([#1103](https://github.com/tinkoff-ai/etna/pull/1103)) - - ## [1.14.0] - 2022-12-16 diff --git a/etna/transforms/encoders/segment_encoder.py b/etna/transforms/encoders/segment_encoder.py index e899b8eac..09b9fbe70 100644 --- a/etna/transforms/encoders/segment_encoder.py +++ b/etna/transforms/encoders/segment_encoder.py @@ -1,3 +1,6 @@ +import reprlib + +import numpy as np import pandas as pd from sklearn import preprocessing @@ -44,12 +47,31 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame: ------- : result dataframe + + Raises + ------ + ValueError: + If transform isn't fitted. + ValueError: + If there are segments that weren't present during training. """ - encoded_matrix = self._le.transform(self._le.classes_) - encoded_matrix = encoded_matrix.reshape(len(self._le.classes_), -1).repeat(len(df), axis=1).T + segments = df.columns.get_level_values("segment").unique().tolist() + + try: + new_segments = set(segments) - set(self._le.classes_) + except AttributeError: + raise ValueError("The transform isn't fitted!") + + if len(new_segments) > 0: + raise ValueError( + f"This transform can't process segments that weren't present on train data: {reprlib.repr(new_segments)}" + ) + + encoded_matrix = self._le.transform(segments) + encoded_matrix = np.tile(encoded_matrix, (len(df), 1)) encoded_df = pd.DataFrame( encoded_matrix, - columns=pd.MultiIndex.from_product([self._le.classes_, ["segment_code"]], names=("segment", "feature")), + columns=pd.MultiIndex.from_product([segments, ["segment_code"]], names=("segment", "feature")), index=df.index, ) encoded_df = encoded_df.astype("category") diff --git a/tests/test_transforms/test_encoders/test_segment_encoder_transform.py b/tests/test_transforms/test_encoders/test_segment_encoder_transform.py index 8a8891bb3..5a599d70f 100644 --- a/tests/test_transforms/test_encoders/test_segment_encoder_transform.py +++ b/tests/test_transforms/test_encoders/test_segment_encoder_transform.py @@ -1,5 +1,6 @@ import numpy as np import pandas as pd +import pytest from etna.transforms import SegmentEncoderTransform from tests.test_transforms.utils import assert_transformation_equals_loaded_original @@ -21,6 +22,38 @@ def test_segment_encoder_transform(dummy_df): assert codes == {0, 1}, "Codes are not 0 and 1" +def test_subset_segments(dummy_df): + train_df = dummy_df + test_df = dummy_df.loc[:, pd.IndexSlice["Omsk", :]] + transform = SegmentEncoderTransform() + + transform.fit(train_df) + transformed_test_df = transform.transform(test_df) + + segments = sorted(transformed_test_df.columns.get_level_values("segment").unique()) + features = sorted(transformed_test_df.columns.get_level_values("feature").unique()) + assert segments == ["Omsk"] + assert features == ["segment_code", "target"] + values = transformed_test_df.loc[:, pd.IndexSlice[:, "segment_code"]] + assert np.all(values == values.iloc[0]) + + +def test_not_fitted_error(dummy_df): + encoder = SegmentEncoderTransform() + with pytest.raises(ValueError, match="The transform isn't fitted"): + encoder.transform(dummy_df) + + +def test_new_segments_error(dummy_df): + train_df = dummy_df.loc[:, pd.IndexSlice["Moscow", :]] + test_df = dummy_df.loc[:, pd.IndexSlice["Omsk", :]] + transform = SegmentEncoderTransform() + + transform.fit(train_df) + with pytest.raises(ValueError, match="This transform can't process segments that weren't present on train data"): + _ = transform.transform(test_df) + + def test_save_load(example_tsds): transform = SegmentEncoderTransform() assert_transformation_equals_loaded_original(transform=transform, ts=example_tsds)