diff --git a/CHANGELOG.md b/CHANGELOG.md index 386ef9abe..9b7c79e5a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -52,6 +52,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add default `params_to_tune` for `HoltWintersModel`, `HoltModel` and `SimpleExpSmoothingModel` ([#1209](https://github.com/tinkoff-ai/etna/pull/1209)) - Add default `params_to_tune` for `RNNModel` and `MLPModel` ([#1218](https://github.com/tinkoff-ai/etna/pull/1218)) - Add default `params_to_tune` for `DateFlagsTransform`, `TimeFlagsTransform`, `SpecialDaysTransform` and `FourierTransform` ([#1228](https://github.com/tinkoff-ai/etna/pull/1228)) +- Add default `params_to_tune` for `MedianOutliersTransform`, `DensityOutliersTransform` and `PredictionIntervalOutliersTransform` ([#1231](https://github.com/tinkoff-ai/etna/pull/1231)) ### Fixed - Fix bug in `GaleShapleyFeatureSelectionTransform` with wrong number of remaining features ([#1110](https://github.com/tinkoff-ai/etna/pull/1110)) - `ProphetModel` fails with additional seasonality set ([#1157](https://github.com/tinkoff-ai/etna/pull/1157)) diff --git a/etna/transforms/outliers/point_outliers.py b/etna/transforms/outliers/point_outliers.py index 8b6dde739..f684048b7 100644 --- a/etna/transforms/outliers/point_outliers.py +++ b/etna/transforms/outliers/point_outliers.py @@ -5,6 +5,7 @@ from typing import Union import pandas as pd +from typing_extensions import Literal from etna import SETTINGS from etna.analysis import absolute_difference_distance @@ -18,6 +19,12 @@ if SETTINGS.prophet_required: from etna.models import ProphetModel +if SETTINGS.auto_required: + from optuna.distributions import BaseDistribution + from optuna.distributions import CategoricalDistribution + from optuna.distributions import IntUniformDistribution + from optuna.distributions import UniformDistribution + class MedianOutliersTransform(OutliersTransform): """Transform that uses :py:func:`~etna.analysis.outliers.median_outliers.get_anomalies_median` to find anomalies in data. @@ -59,6 +66,19 @@ def detect_outliers(self, ts: TSDataset) -> Dict[str, List[pd.Timestamp]]: """ return get_anomalies_median(ts=ts, in_column=self.in_column, window_size=self.window_size, alpha=self.alpha) + def params_to_tune(self) -> Dict[str, "BaseDistribution"]: + """Get default grid for tuning hyperparameters. + + Returns + ------- + : + Grid to tune. + """ + return { + "window_size": IntUniformDistribution(low=3, high=30), + "alpha": UniformDistribution(low=0.5, high=5), + } + class DensityOutliersTransform(OutliersTransform): """Transform that uses :py:func:`~etna.analysis.outliers.density_outliers.get_anomalies_density` to find anomalies in data. @@ -120,6 +140,20 @@ def detect_outliers(self, ts: TSDataset) -> Dict[str, List[pd.Timestamp]]: distance_func=self.distance_func, ) + def params_to_tune(self) -> Dict[str, "BaseDistribution"]: + """Get default grid for tuning hyperparameters. + + Returns + ------- + : + Grid to tune. + """ + return { + "window_size": IntUniformDistribution(low=3, high=30), + "distance_coef": UniformDistribution(low=0.5, high=5), + "n_neighbors": IntUniformDistribution(low=1, high=10), + } + class PredictionIntervalOutliersTransform(OutliersTransform): """Transform that uses :py:func:`~etna.analysis.outliers.prediction_interval_outliers.get_anomalies_prediction_interval` to find anomalies in data.""" @@ -127,7 +161,7 @@ class PredictionIntervalOutliersTransform(OutliersTransform): def __init__( self, in_column: str, - model: Union[Type["ProphetModel"], Type["SARIMAXModel"]], + model: Union[Literal["prophet"], Literal["sarimax"], Type["ProphetModel"], Type["SARIMAXModel"]], interval_width: float = 0.95, **model_kwargs, ): @@ -149,8 +183,20 @@ def __init__( self.model = model self.interval_width = interval_width self.model_kwargs = model_kwargs + self._model_type = self._get_model_type(model) super().__init__(in_column=in_column) + @staticmethod + def _get_model_type( + model: Union[Literal["prophet"], Literal["sarimax"], Type["ProphetModel"], Type["SARIMAXModel"]] + ) -> Union[Type["ProphetModel"], Type["SARIMAXModel"]]: + if isinstance(model, str): + if model == "prophet": + return ProphetModel + elif model == "sarimax": + return SARIMAXModel + return model + def detect_outliers(self, ts: TSDataset) -> Dict[str, List[pd.Timestamp]]: """Call :py:func:`~etna.analysis.outliers.prediction_interval_outliers.get_anomalies_prediction_interval` function with self parameters. @@ -166,12 +212,25 @@ def detect_outliers(self, ts: TSDataset) -> Dict[str, List[pd.Timestamp]]: """ return get_anomalies_prediction_interval( ts=ts, - model=self.model, + model=self._model_type, interval_width=self.interval_width, in_column=self.in_column, **self.model_kwargs, ) + def params_to_tune(self) -> Dict[str, "BaseDistribution"]: + """Get default grid for tuning hyperparameters. + + Returns + ------- + : + Grid to tune. + """ + return { + "interval_width": UniformDistribution(low=0.8, high=1.0), + "model": CategoricalDistribution(["prophet", "sarimax"]), + } + __all__ = [ "MedianOutliersTransform", diff --git a/tests/test_transforms/test_outliers/test_outliers_transform.py b/tests/test_transforms/test_outliers/test_outliers_transform.py index 70e7c45a8..ceb601441 100644 --- a/tests/test_transforms/test_outliers/test_outliers_transform.py +++ b/tests/test_transforms/test_outliers/test_outliers_transform.py @@ -11,6 +11,7 @@ from etna.transforms import DensityOutliersTransform from etna.transforms import MedianOutliersTransform from etna.transforms import PredictionIntervalOutliersTransform +from tests.test_transforms.utils import assert_sampling_is_valid from tests.test_transforms.utils import assert_transformation_equals_loaded_original from tests.utils import select_segments_subset @@ -226,3 +227,17 @@ def test_save_load(transform, outliers_solid_tsds): ) def test_save_load_prediction_interval(transform, outliers_solid_tsds): assert_transformation_equals_loaded_original(transform=transform, ts=outliers_solid_tsds) + + +@pytest.mark.parametrize( + "transform", + ( + MedianOutliersTransform(in_column="target"), + DensityOutliersTransform(in_column="target"), + PredictionIntervalOutliersTransform(in_column="target", model="sarimax"), + ), +) +def test_params_to_tune(transform, outliers_solid_tsds): + ts = outliers_solid_tsds + assert len(transform.params_to_tune()) > 0 + assert_sampling_is_valid(transform=transform, ts=ts)