Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Add support for logistic growth to Prophet #1419

Merged
68 changes: 67 additions & 1 deletion darts/models/forecasting/prophet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import logging
import re
from typing import List, Optional, Union
from typing import Callable, List, Optional, Sequence, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -98,6 +98,12 @@ def __init__(
self._execute_and_suppress_output = execute_and_suppress_output
self._model_builder = prophet.Prophet

self.is_logistic = False
if "growth" in prophet_kwargs and prophet_kwargs["growth"] == "logistic":
DavidKleindienst marked this conversation as resolved.
Show resolved Hide resolved
self.is_logistic = True
self._cap = None
self._floor = None

def __str__(self):
return "Prophet"

Expand All @@ -110,6 +116,14 @@ def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = Non
fit_df = pd.DataFrame(
data={"ds": series.time_index, "y": series.univariate_values()}
)
if self.is_logistic:
raise_if(
DavidKleindienst marked this conversation as resolved.
Show resolved Hide resolved
self._cap is None or self._floor is None,
"Cap and floor have to be set by calling `Prophet.set_capacity` "
"before fitting, when parameter `growth` is set to 'logistic'.",
logger,
)
fit_df = self._add_capacities_to_df(fit_df)

self.model = self._model_builder(**self.prophet_kwargs)

Expand Down Expand Up @@ -167,13 +181,31 @@ def _predict(

return self._build_forecast_series(forecast)

def _add_capacities_to_df(self, df: pd.DataFrame) -> pd.DataFrame:
dates = df["ds"]
try:
df["cap"] = self._cap(dates) if callable(self._cap) else self._cap
df["floor"] = self._floor(dates) if callable(self._floor) else self._floor
except ValueError as e:
raise_if(
"does not match length of index" in str(e),
"Callables supplied to `Prophet.set_capacity` as `cap` or `floor` "
"arguments have to return Sequences of identical length as their "
" input argument Sequence!",
logger,
)
raise
return df

def _generate_predict_df(
self, n: int, future_covariates: Optional[TimeSeries] = None
) -> pd.DataFrame:
"""Returns a pandas DataFrame in the format required for Prophet.predict() with `n` dates after the end of
the fitted TimeSeries"""

predict_df = pd.DataFrame(data={"ds": self._generate_new_dates(n)})
if self.is_logistic:
predict_df = self._add_capacities_to_df(predict_df)
if future_covariates is not None:
predict_df = predict_df.merge(
future_covariates.pd_dataframe(),
Expand Down Expand Up @@ -265,6 +297,40 @@ def add_seasonality(
}
self._store_add_seasonality_call(seasonality_call=function_call)

def set_capacity(
self,
cap: Union[
float, Callable[[Union[pd.DatetimeIndex, pd.RangeIndex]], Sequence[float]]
],
floor: Union[
float, Callable[[Union[pd.DatetimeIndex, pd.RangeIndex]], Sequence[float]]
] = 0,
) -> None:
"""Set carrying capacities for predicting with logistic growth.
These capacities are only used when `Prophet` was instantiated with `growth = 'logistic'`
See <https://facebook.github.io/prophet/docs/saturating_forecasts.html> for more information
on logistic forecasts.

The `cap` and `floor` parameters may be:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should add a blank line for the bullet-point list to render correctly in HTML (you can try building & rendering the doc with ./gradlew buildDocs to check the HTML files directly).

- a number, for constant carrying capacities
- a function taking a DatetimeIndex or RangeIndex and returning a corresponding a Sequence of numbers,
where each number indicates the carrying capacity at this index.

Parameters
----------
cap
The maximum carrying capacity
floor
The minimum carrying capacity, by default 0
"""
if not self.is_logistic:
logger.warning(
"Capacities were set although `growth` is not logistic. "
"The set capacities will be ignored."
)
self._cap = cap
self._floor = floor

def _store_add_seasonality_call(
self, seasonality_call: Optional[dict] = None
) -> None:
Expand Down
22 changes: 22 additions & 0 deletions darts/tests/models/forecasting/test_prophet.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,28 @@ def test_prophet_model_default_with_prophet_constructor(self):
model = Prophet()
assert model._model_builder == FBProphet, "model should use Facebook Prophet"

def test_prophet_model_with_logistic_growth(self):
model = Prophet(growth="logistic")
model.set_capacity(1, 0)

# Create timeseries with logistic function
times = tg.generate_index(
pd.Timestamp("20200101"), pd.Timestamp("20210101"), freq="D"
)
values = np.linspace(-10, 10, len(times))
f = np.vectorize(lambda x: 1 / (1 + np.exp(-x)))
values = f(values)
ts = TimeSeries.from_times_and_values(times, values, freq="D")
# split in the middle, so the only way of predicting the plateau correctly
# is using the capacity
train, val = ts.split_after(0.5)

model.fit(train)
pred = model.predict(len(val))

for val_i, pred_i in zip(val.univariate_values(), pred.univariate_values()):
self.assertAlmostEqual(val_i, pred_i, delta=0.1)

def helper_test_freq_coversion(self, test_cases):
for freq, period in test_cases.items():
ts_sine = tg.sine_timeseries(
Expand Down