Skip to content

Fix SegmentEncoderTransform to pass inference tests #1103

Merged
merged 5 commits into from
Feb 8, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
-
-
- Fix `SegmentEncoderTransform` to work with a subset of segments and to raise an error on new segments ([#1103](https://github.com/tinkoff-ai/etna/pull/1103))
-
-
## [1.14.0] - 2022-12-16
Expand Down
25 changes: 22 additions & 3 deletions etna/transforms/encoders/segment_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,31 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
-------
:
result dataframe

Raises
------
ValueError:
If transform isn't fitted.
ValueError:
If there are segments that weren't present during training.
"""
encoded_matrix = self._le.transform(self._le.classes_)
encoded_matrix = encoded_matrix.reshape(len(self._le.classes_), -1).repeat(len(df), axis=1).T
segments = df.columns.get_level_values("segment").unique().tolist()

try:
new_segments = set(segments) - set(self._le.classes_)
except AttributeError:
raise ValueError("The transform isn't fitted!")

if len(new_segments) > 0:
raise ValueError(
f"This transform can't process segments that weren't present on train data: {new_segments}"
brsnw250 marked this conversation as resolved.
Show resolved Hide resolved
)

encoded_matrix = self._le.transform(segments)
encoded_matrix = encoded_matrix.reshape(len(segments), -1).repeat(len(df), axis=1).T
brsnw250 marked this conversation as resolved.
Show resolved Hide resolved
encoded_df = pd.DataFrame(
encoded_matrix,
columns=pd.MultiIndex.from_product([self._le.classes_, ["segment_code"]], names=("segment", "feature")),
columns=pd.MultiIndex.from_product([segments, ["segment_code"]], names=("segment", "feature")),
index=df.index,
)
encoded_df = encoded_df.astype("category")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pandas as pd
import pytest

from etna.transforms import SegmentEncoderTransform
from tests.test_transforms.utils import assert_transformation_equals_loaded_original
Expand All @@ -21,6 +22,38 @@ def test_segment_encoder_transform(dummy_df):
assert codes == {0, 1}, "Codes are not 0 and 1"


def test_subset_segments(dummy_df):
    """Check that a transform fitted on the full frame can transform a single-segment slice.

    The transform is fitted on ``dummy_df`` and applied to its ``"Omsk"`` slice.
    The result must contain only that segment, exactly the ``segment_code`` and
    ``target`` features, and a ``segment_code`` whose rows all equal the first row.
    """
    full_df = dummy_df
    omsk_df = dummy_df.loc[:, pd.IndexSlice["Omsk", :]]
    encoder = SegmentEncoderTransform()

    encoder.fit(full_df)
    result = encoder.transform(omsk_df)

    # Inspect both levels of the resulting column MultiIndex.
    column_index = result.columns
    found_segments = sorted(column_index.get_level_values("segment").unique())
    found_features = sorted(column_index.get_level_values("feature").unique())
    assert found_segments == ["Omsk"]
    assert found_features == ["segment_code", "target"]

    # Every row of the encoded column has to match the first one.
    codes = result.loc[:, pd.IndexSlice[:, "segment_code"]]
    first_row = codes.iloc[0]
    assert np.all(codes == first_row)


def test_not_fitted_error(dummy_df):
    """Check that ``transform`` on a never-fitted instance raises ``ValueError``."""
    unfitted = SegmentEncoderTransform()
    expected_message = "The transform isn't fitted"
    with pytest.raises(ValueError, match=expected_message):
        unfitted.transform(dummy_df)


def test_new_segments_error(dummy_df):
    """Check that transforming a segment absent from the fit data raises ``ValueError``.

    The transform is fitted on the ``"Moscow"`` slice only and then applied to
    the ``"Omsk"`` slice.
    """
    moscow_df = dummy_df.loc[:, pd.IndexSlice["Moscow", :]]
    omsk_df = dummy_df.loc[:, pd.IndexSlice["Omsk", :]]
    encoder = SegmentEncoderTransform()
    expected_message = "This transform can't process segments that weren't present on train data"

    encoder.fit(moscow_df)
    with pytest.raises(ValueError, match=expected_message):
        encoder.transform(omsk_df)


def test_save_load(example_tsds):
    """Run the shared save/load comparison helper on a default-constructed transform."""
    assert_transformation_equals_loaded_original(transform=SegmentEncoderTransform(), ts=example_tsds)