Skip to content

Fix SegmentEncoderTransform to pass inference tests #1103

Merged
merged 5 commits into from
Feb 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
-
-
- Fix `SegmentEncoderTransform` to work with a subset of segments and raise an error on new segments ([#1103](https://github.com/tinkoff-ai/etna/pull/1103))
-
-
## [1.14.0] - 2022-12-16
Expand Down
28 changes: 25 additions & 3 deletions etna/transforms/encoders/segment_encoder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import reprlib

import numpy as np
import pandas as pd
from sklearn import preprocessing

Expand Down Expand Up @@ -44,12 +47,31 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
-------
:
result dataframe

Raises
------
ValueError:
If transform isn't fitted.
ValueError:
If there are segments that weren't present during training.
"""
encoded_matrix = self._le.transform(self._le.classes_)
encoded_matrix = encoded_matrix.reshape(len(self._le.classes_), -1).repeat(len(df), axis=1).T
segments = df.columns.get_level_values("segment").unique().tolist()

try:
new_segments = set(segments) - set(self._le.classes_)
except AttributeError:
raise ValueError("The transform isn't fitted!")

if len(new_segments) > 0:
raise ValueError(
f"This transform can't process segments that weren't present on train data: {reprlib.repr(new_segments)}"
)

encoded_matrix = self._le.transform(segments)
encoded_matrix = np.tile(encoded_matrix, (len(df), 1))
encoded_df = pd.DataFrame(
encoded_matrix,
columns=pd.MultiIndex.from_product([self._le.classes_, ["segment_code"]], names=("segment", "feature")),
columns=pd.MultiIndex.from_product([segments, ["segment_code"]], names=("segment", "feature")),
index=df.index,
)
encoded_df = encoded_df.astype("category")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pandas as pd
import pytest

from etna.transforms import SegmentEncoderTransform
from tests.test_transforms.utils import assert_transformation_equals_loaded_original
Expand All @@ -21,6 +22,38 @@ def test_segment_encoder_transform(dummy_df):
assert codes == {0, 1}, "Codes are not 0 and 1"


def test_subset_segments(dummy_df):
    """Fit on the whole fixture, then transform only the "Omsk" segment.

    The result must contain just that segment, expose the ``segment_code``
    and ``target`` features, and hold the same code in every row.
    """
    idx = pd.IndexSlice
    full_df = dummy_df
    omsk_df = dummy_df.loc[:, idx["Omsk", :]]
    encoder = SegmentEncoderTransform()

    encoder.fit(full_df)
    result = encoder.transform(omsk_df)

    result_columns = result.columns
    found_segments = sorted(result_columns.get_level_values("segment").unique())
    found_features = sorted(result_columns.get_level_values("feature").unique())
    assert found_segments == ["Omsk"]
    assert found_features == ["segment_code", "target"]

    # Every row of the encoded column(s) must match the first row.
    codes = result.loc[:, idx[:, "segment_code"]]
    first_row = codes.iloc[0]
    assert np.all(codes == first_row)


def test_not_fitted_error(dummy_df):
    """Calling ``transform`` without a prior ``fit`` must raise ``ValueError``."""
    expected_message = "The transform isn't fitted"
    unfitted = SegmentEncoderTransform()
    with pytest.raises(ValueError, match=expected_message):
        unfitted.transform(dummy_df)


def test_new_segments_error(dummy_df):
    """Fit on "Moscow" only, then expect ``ValueError`` when transforming "Omsk"."""
    idx = pd.IndexSlice
    moscow_df = dummy_df.loc[:, idx["Moscow", :]]
    omsk_df = dummy_df.loc[:, idx["Omsk", :]]
    encoder = SegmentEncoderTransform()

    encoder.fit(moscow_df)

    expected_message = "This transform can't process segments that weren't present on train data"
    with pytest.raises(ValueError, match=expected_message):
        encoder.transform(omsk_df)


def test_save_load(example_tsds):
    """Run the shared save/load equality check on a fresh ``SegmentEncoderTransform``."""
    assert_transformation_equals_loaded_original(
        transform=SegmentEncoderTransform(),
        ts=example_tsds,
    )