Skip to content

Commit

Permalink
Merge pull request #3476 from pycaret/stlf
Browse files Browse the repository at this point in the history
added STLForecaster
  • Loading branch information
ngupta23 committed Apr 21, 2023
2 parents f1615f8 + 849d9a6 commit 8648314
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 0 deletions.
60 changes: 60 additions & 0 deletions pycaret/containers/models/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1298,6 +1298,66 @@ def _set_tune_distributions(self) -> dict:
}


class STLFContainer(TimeSeriesContainer):
model_type = TSModelTypes.CLASSICAL

def __init__(self, experiment) -> None:
self.logger = get_logger()
np.random.seed(experiment.seed)
self.gpu_imported = False

from sktime.forecasting.trend import STLForecaster

# Disable container if certain features are not supported but enforced ----
dummy = STLForecaster()
self.active = _check_enforcements(forecaster=dummy, experiment=experiment)
if not self.active:
return

self.seasonality_present = experiment.seasonality_present
self.sp = experiment.primary_sp_to_use
self.strictly_positive = experiment.strictly_positive

if self.sp == 1:
self.active = False
return

args = self._set_args
tune_args = self._set_tune_args
tune_grid = self._set_tune_grid
tune_distributions = self._set_tune_distributions
leftover_parameters_to_categorical_distributions(tune_grid, tune_distributions)

super().__init__(
id="stlf",
name="STLF",
class_def=STLForecaster,
args=args,
tune_grid=tune_grid,
tune_distribution=tune_distributions,
tune_args=tune_args,
is_gpu_enabled=self.gpu_imported,
)

@property
def _set_args(self) -> Dict[str, Any]:
args = {"sp": self.sp}
return args

@property
def _set_tune_grid(self) -> Dict[str, List[Any]]:
# TODO: There may be other hyperparameters to tune. Check:
# http://www.sktime.net/en/latest/api_reference/auto_generated/sktime.forecasting.trend.STLForecaster.html
tune_grid = {
"sp": [self.sp],
"seasonal_deg": [0, 1],
"trend_deg": [0, 1],
"low_pass_deg": [0, 1],
"robust": [True, False],
}
return tune_grid


#################################
# REGRESSION BASED MODELS ####
#################################
Expand Down
1 change: 1 addition & 0 deletions pycaret/time_series/forecasting/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,7 @@ def create_model(
* 'arima' - ARIMA family of models (ARIMA, SARIMA, SARIMAX)
* 'auto_arima' - Auto ARIMA
* 'exp_smooth' - Exponential Smoothing
* 'stlf' - STL Forecaster
* 'croston' - Croston Forecaster
* 'ets' - ETS
* 'theta' - Theta Forecaster
Expand Down
1 change: 1 addition & 0 deletions pycaret/time_series/forecasting/oop.py
Original file line number Diff line number Diff line change
Expand Up @@ -2278,6 +2278,7 @@ def create_model(
* 'arima' - ARIMA family of models (ARIMA, SARIMA, SARIMAX)
* 'auto_arima' - Auto ARIMA
* 'exp_smooth' - Exponential Smoothing
* 'stlf' - STL Forecaster
* 'croston' - Croston Forecaster
* 'ets' - ETS
* 'theta' - Theta Forecaster
Expand Down

0 comments on commit 8648314

Please sign in to comment.