Skip to content

Commit

Permalink
Merge pull request #3518 from pycaret/revert_ts_type
Browse files Browse the repository at this point in the history
Enhanced seasonality type detection
  • Loading branch information
ngupta23 committed Apr 30, 2023
2 parents e3ee67f + 408dd9c commit f300f02
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 30 deletions.
2 changes: 1 addition & 1 deletion pycaret/internal/plots/utils/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ def decomp_subplot(
except ValueError as exception:
logger.warning(exception)
logger.warning(
"Seasonal Decompose plot failed most likely sue to missing data"
"Seasonal Decompose plot failed most likely due to missing data"
)
return fig, None
elif plot == "decomp_stl":
Expand Down
48 changes: 21 additions & 27 deletions pycaret/time_series/forecasting/oop.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from sktime.transformations.compose import TransformerPipeline
from sktime.transformations.series.impute import Imputer
from sktime.utils.seasonality import autocorrelation_seasonality_test
from statsmodels.tsa.seasonal import seasonal_decompose

from pycaret.containers.metrics import get_all_ts_metric_containers
from pycaret.containers.models import get_all_ts_model_containers
Expand All @@ -36,7 +37,7 @@
from pycaret.internal.distributions import get_base_distributions
from pycaret.internal.logging import get_logger, redirect_output
from pycaret.internal.parallel.parallel_backend import ParallelBackend
from pycaret.internal.plots.time_series import _get_plot, plot_time_series_decomposition
from pycaret.internal.plots.time_series import _get_plot
from pycaret.internal.plots.utils.time_series import (
_clean_model_results_labels,
_get_data_types_to_plot,
Expand Down Expand Up @@ -1167,8 +1168,11 @@ def _set_seasonal_type(self) -> "TSForecastingExperiment":
"""Sets the seasonal type to be used in the models.
Decomposes data using additive and multiplicative seasonal decomposition
Then selects the seasonality type that has the least amount of variance
in the residuals.
Then selects the seasonality type based on seasonality strength per FPP
(https://otexts.com/fpp2/seasonal-strength.html).
NOTE: For Multiplicative, the denominator multiplies the seasonal and residual
components instead of adding them. Rest of the calculations remain the same.
Returns
-------
Expand All @@ -1184,37 +1188,27 @@ def _set_seasonal_type(self) -> "TSForecastingExperiment":
)
)

_, data_add = plot_time_series_decomposition(
data=data_to_use,
plot="decomp",
data_kwargs={
"seasonal_period": self.primary_sp_to_use,
"type": "additive",
},
fig_defaults={"template": "plotly_dark", "width": 1000, "height": 600},
data_label="sth",
decomp_add = seasonal_decompose(
data_to_use, period=self.primary_sp_to_use, model="additive"
)
_, data_mult = plot_time_series_decomposition(
data=data_to_use,
plot="decomp",
data_kwargs={
"seasonal_period": self.primary_sp_to_use,
"type": "multiplicative",
},
fig_defaults={"template": "plotly_dark", "width": 1000, "height": 600},
data_label="sth",
decomp_mult = seasonal_decompose(
data_to_use, period=self.primary_sp_to_use, model="multiplicative"
)

if data_add is None or data_mult is None:
# None is retuirned when decomposition fails
if decomp_add is None or decomp_mult is None:
# None is returned when decomposition fails
# Default to "add" since mul can give issues
seasonality_type = "add"
else:
key = list(data_mult.get("decomp").keys())[0]
std_add = np.std(data_add.get("decomp")[key].resid)
std_mul = np.std(data_mult.get("decomp")[key].resid)
var_r_add = (np.std(decomp_add.resid)) ** 2
var_rs_add = (np.std(decomp_add.resid + decomp_add.seasonal)) ** 2
var_r_mult = (np.std(decomp_mult.resid)) ** 2
var_rs_mult = (np.std(decomp_mult.resid * decomp_mult.seasonal)) ** 2

Fs_add = np.maximum(1 - var_r_add / var_rs_add, 0)
Fs_mult = np.maximum(1 - var_r_mult / var_rs_mult, 0)

if std_mul < std_add:
if Fs_mult > Fs_add:
seasonality_type = "mul"
else:
seasonality_type = "add"
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ ipython>=5.5.0
ipywidgets>=7.6.5 # required by pycaret.internal.display
tqdm>=4.62.0 # required by pycaret.internal.display
numpy>=1.21, <1.24 # Can't >=1.24 because of np.float deprecation
pandas>=1.3.0, <1.6.0 # Can't >=1.6 because of sktime
pandas>=1.3.0, <2.0.0 # Can't >=2.0.0 because of sktime
jinja2>=1.2 # Required by pycaret.internal.utils --> pandas.io.formats.style
scipy<2.0.0 # Can't >=2.0.0 due to sktime
joblib>=1.2.0 # joblib<1.2.0 is vulnerable to Arbitrary Code Execution (https://github.com/advisories/GHSA-6hrg-qmvc-2xh8)
Expand Down Expand Up @@ -35,6 +35,6 @@ plotly-resampler>=0.8.3.1

# Time-series
statsmodels>=0.12.1
sktime>=0.16.1,<=0.17.0 # Limited until this is fixed: https://github.com/sktime/sktime/issues/4468
sktime>=0.16.1,!=0.17.1,<0.17.2 # Due to this bug in 0.17.1: https://github.com/sktime/sktime/issues/4468
tbats>=1.1.3
pmdarima>=1.8.0,!=1.8.1,<3.0.0 # Matches sktime
57 changes: 57 additions & 0 deletions tests/test_time_series_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,3 +985,60 @@ def test_hyperparameter_splits():
exp2.setup(data=data, hyperparameter_split="train", fh=FH, fold=FOLD)

assert exp1.uppercase_d != exp2.uppercase_d


@pytest.mark.parametrize("index", ["RangeIndex", "DatetimeIndex"])
def test_seasonality_type(index: str):
    """Check seasonality type detection for different index types.

    Builds a synthetic series (linear trend plus a sinusoidal seasonal term),
    optionally re-indexes it with month-start dates, and hands the series to
    ``_test_seasonality_type``, which performs the actual assertions.

    Parameters
    ----------
    index : str
        Type of index. Options are: "RangeIndex" and "DatetimeIndex"
    """
    num_points = 100
    trend = np.arange(100, 100 + num_points)
    # sin() is bounded by [-1, 1], so (sin + 1) >= 0 and the seasonal term
    # never pushes the final series negative.
    seasonal = 100 * (np.sin(trend) + 1)
    series = pd.Series(trend + seasonal)

    # A freshly built Series already carries the default RangeIndex, so only
    # the DatetimeIndex case needs an explicit index.
    if index == "DatetimeIndex":
        series.index = pd.date_range(
            start="2020-01-01", periods=num_points, freq="MS"
        )

    _test_seasonality_type(series)


def _test_seasonality_type(y):
    """Assert the seasonality type detected for ``y`` and for the airline data.

    Parameters
    ----------
    y : pd.Series
        Series that is expected to be detected as having additive seasonality.
    """

    def _detected_seasonality(data):
        # Run a fresh experiment with a fixed session id and report the
        # seasonality type it settled on.
        experiment = TSForecastingExperiment()
        experiment.setup(data=data, session_id=42)
        return experiment.seasonality_type

    # -------------------------------------------------------------------------#
    # Test 1: Additive Seasonality
    # -------------------------------------------------------------------------#
    additive_msg = "Expected additive seasonality, got multiplicative"
    assert _detected_seasonality(y) == "add", additive_msg

    # # -------------------------------------------------------------------------#
    # # Test 2A: Multiplicative Seasonality (1)
    # # -------------------------------------------------------------------------#
    # NOTE(review): this disabled case references `y_trend`, `y_season` and
    # `dates`, none of which are defined inside this helper; it cannot be
    # re-enabled as written.
    # y = pd.Series(y_trend * y_season)
    # y.index = dates

    # err_msg = "Expected multiplicative seasonality, got additive (1)"
    # exp = TSForecastingExperiment()
    # exp.setup(data=y, session_id=42)
    # assert exp.seasonality_type == "mul", err_msg

    # -------------------------------------------------------------------------#
    # Test 2B: Multiplicative Seasonality (2)
    # -------------------------------------------------------------------------#
    airline = get_data("airline", verbose=False)
    multiplicative_msg = "Expected multiplicative seasonality, got additive (2)"
    assert _detected_seasonality(airline) == "mul", multiplicative_msg

0 comments on commit f300f02

Please sign in to comment.