Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Add support for logistic growth to Prophet #1419

Merged
68 changes: 67 additions & 1 deletion darts/models/forecasting/prophet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import logging
import re
from typing import List, Optional, Union
from typing import Callable, List, Optional, Sequence, Union

import numpy as np
import pandas as pd
Expand All @@ -28,6 +28,12 @@ def __init__(
country_holidays: Optional[str] = None,
suppress_stdout_stderror: bool = True,
add_encoders: Optional[dict] = None,
cap: Union[
float, Callable[[Union[pd.DatetimeIndex, pd.RangeIndex]], Sequence[float]]
] = None,
floor: Union[
float, Callable[[Union[pd.DatetimeIndex, pd.RangeIndex]], Sequence[float]]
] = None,
**prophet_kwargs,
):
"""Facebook Prophet
Expand Down Expand Up @@ -92,6 +98,26 @@ def __init__(
'transformer': Scaler()
}
..
cap
Parameter specifying the maximum carrying capacity when predicting with logistic growth.
Mandatory when `growth = 'logistic'`, otherwise ignored.
See <https://facebook.github.io/prophet/docs/saturating_forecasts.html> for more information
on logistic forecasts.
Can be either

- a number, for constant carrying capacities
- a function taking a DatetimeIndex or RangeIndex and returning a corresponding Sequence of numbers,
where each number indicates the carrying capacity at this index.
floor
Parameter specifying the minimum carrying capacity when predicting with logistic growth.
Optional when `growth = 'logistic'` (defaults to 0), otherwise ignored.
See <https://facebook.github.io/prophet/docs/saturating_forecasts.html> for more information
on logistic forecasts.
Can be either

- a number, for constant carrying capacities
- a function taking a DatetimeIndex or RangeIndex and returning a corresponding Sequence of numbers,
where each number indicates the carrying capacity at this index.
prophet_kwargs
Some optional keyword arguments for Prophet.
For information about the parameters see:
Expand Down Expand Up @@ -119,6 +145,26 @@ def __init__(
self._execute_and_suppress_output = execute_and_suppress_output
self._model_builder = prophet.Prophet

self._cap = cap
self._floor = floor
self.is_logistic = (
"growth" in prophet_kwargs and prophet_kwargs["growth"] == "logistic"
)
if not self.is_logistic and (cap is not None or floor is not None):
logger.warning(
"Parameters `cap` and/or `floor` were set although `growth` is not "
"logistic. The set capacities will be ignored."
)
if self.is_logistic:
raise_if(
cap is None,
"Parameter `cap` has to be set when `growth` is logistic",
logger,
)
if floor is None:
# Use 0 as default value
self._floor = 0

def __str__(self):
return "Prophet"

Expand All @@ -131,6 +177,8 @@ def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = Non
fit_df = pd.DataFrame(
data={"ds": series.time_index, "y": series.univariate_values()}
)
if self.is_logistic:
fit_df = self._add_capacities_to_df(fit_df)

self.model = self._model_builder(**self.prophet_kwargs)

Expand Down Expand Up @@ -188,13 +236,31 @@ def _predict(

return self._build_forecast_series(forecast)

def _add_capacities_to_df(self, df: pd.DataFrame) -> pd.DataFrame:
dates = df["ds"]
try:
df["cap"] = self._cap(dates) if callable(self._cap) else self._cap
df["floor"] = self._floor(dates) if callable(self._floor) else self._floor
except ValueError as e:
raise_if(
"does not match length of index" in str(e),
"Callables supplied to `Prophet.set_capacity` as `cap` or `floor` "
"arguments have to return Sequences of identical length as their "
" input argument Sequence!",
logger,
)
raise
return df

def _generate_predict_df(
self, n: int, future_covariates: Optional[TimeSeries] = None
) -> pd.DataFrame:
"""Returns a pandas DataFrame in the format required for Prophet.predict() with `n` dates after the end of
the fitted TimeSeries"""

predict_df = pd.DataFrame(data={"ds": self._generate_new_dates(n)})
if self.is_logistic:
predict_df = self._add_capacities_to_df(predict_df)
if future_covariates is not None:
predict_df = predict_df.merge(
future_covariates.pd_dataframe(),
Expand Down
21 changes: 21 additions & 0 deletions darts/tests/models/forecasting/test_prophet.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,27 @@ def test_prophet_model_default_with_prophet_constructor(self):
model = Prophet()
assert model._model_builder == FBProphet, "model should use Facebook Prophet"

def test_prophet_model_with_logistic_growth(self):
    """Check that a logistic-growth Prophet model forecasts a sigmoid's plateau."""

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    # Build a daily series that follows the logistic function.
    time_index = tg.generate_index(
        pd.Timestamp("20200101"), pd.Timestamp("20210101"), freq="D"
    )
    logistic_values = np.vectorize(sigmoid)(np.linspace(-10, 10, len(time_index)))
    series = TimeSeries.from_times_and_values(time_index, logistic_values, freq="D")

    # Splitting in the middle means the plateau can only be predicted
    # correctly by making use of the capacity.
    train, val = series.split_after(0.5)

    model = Prophet(growth="logistic", cap=1)
    model.fit(train)
    forecast = model.predict(len(val))

    expected_values = val.univariate_values()
    predicted_values = forecast.univariate_values()
    for expected, predicted in zip(expected_values, predicted_values):
        self.assertAlmostEqual(expected, predicted, delta=0.1)

def helper_test_freq_coversion(self, test_cases):
for freq, period in test_cases.items():
ts_sine = tg.sine_timeseries(
Expand Down