Skip to content

Commit

Permalink
Merge pull request #18 from vonum/cross-validation
Browse files Browse the repository at this point in the history
Cross validation and performance metrics integration
  • Loading branch information
vonum committed May 21, 2021
2 parents e8966fb + e3eb0ea commit 7af3dd4
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 0 deletions.
12 changes: 12 additions & 0 deletions multi_prophet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,18 @@ def plot_components(self, forecasts, plotly=False, **kwargs):
for column, forecast in forecasts.items()
}

def cross_validation(self, horizon, **kwargs):
return {
column: model.cross_validation(horizon=horizon, **kwargs)
for column, model in self.model_pool.items()
}

def performance_metrics(self, horizon, **kwargs):
return {
column: model.performance_metrics(horizon=horizon, **kwargs)
for column, model in self.model_pool.items()
}

def _init_model_pool(self, columns, **kwargs):
return {c: Prophet(**kwargs) for c in columns}

Expand Down
1 change: 1 addition & 0 deletions multi_prophet/factories.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .prophet import Prophet
from .data_builder import DataFrameBuilder


def model_pool_factory(columns=None, config=None, regressors={}, **kwargs):
if config:
return _different_models_pool_factory(config, regressors)
Expand Down
1 change: 1 addition & 0 deletions multi_prophet/plots.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from fbprophet.plot import plot_plotly, plot_components_plotly


def plotly_plot(model, forecast, **kwargs):
return plot_plotly(model, forecast, **kwargs)

Expand Down
8 changes: 8 additions & 0 deletions multi_prophet/prophet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import fbprophet
from fbprophet.diagnostics import cross_validation, performance_metrics
from . import plots


Expand Down Expand Up @@ -36,3 +37,10 @@ def plot_components(self, forecast, plotly=False, **kwargs):
return plots.plotly_components_plot(self.prophet, forecast, **kwargs)
else:
return self.prophet.plot_components(forecast)

def cross_validation(self, horizon, **kwargs):
return cross_validation(self.prophet, horizon=horizon, **kwargs)

def performance_metrics(self, horizon, **kwargs):
cv_df = cross_validation(self.prophet, horizon=horizon, **kwargs)
return performance_metrics(cv_df)
27 changes: 27 additions & 0 deletions tests/test_multi_prophet.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,30 @@ def test_components_plot(self):

for plot in plots:
self.assertIsInstance(plot, matplotlib.figure.Figure)

def test_cross_validation(self):
mp = multi_prophet.MultiProphet(columns=PREDICTOR_COLUMNS)
mp.fit(self.df)

cross_validation_dfs = mp.cross_validation(horizon="365 days")
for c, cross_validation_df in cross_validation_dfs.items():
self.assertTrue("ds" in cross_validation_df.columns)
self.assertTrue("yhat" in cross_validation_df.columns)
self.assertTrue("yhat_lower" in cross_validation_df.columns)
self.assertTrue("yhat_upper" in cross_validation_df.columns)
self.assertTrue("y" in cross_validation_df.columns)
self.assertTrue("cutoff" in cross_validation_df.columns)

def test_performance_metrics(self):
mp = multi_prophet.MultiProphet(columns=PREDICTOR_COLUMNS)
mp.fit(self.df)

performance_metrics_dfs = mp.performance_metrics(horizon="365 days")
for c, performance_metrics_df in performance_metrics_dfs.items():
self.assertTrue("horizon" in performance_metrics_df.columns)
self.assertTrue("mse" in performance_metrics_df.columns)
self.assertTrue("rmse" in performance_metrics_df.columns)
self.assertTrue("mae" in performance_metrics_df.columns)
self.assertTrue("mape" in performance_metrics_df.columns)
self.assertTrue("mdape" in performance_metrics_df.columns)
self.assertTrue("coverage" in performance_metrics_df.columns)
29 changes: 29 additions & 0 deletions tests/test_prophet.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,32 @@ def test_plotly_components_plot(self):
mp.plot_components(forecast, plotly=True),
plotly.graph_objs.Figure
)

def test_cross_validation(self):
mp = multi_prophet.Prophet()
mp.fit(self.df)

cross_validation_df = mp.cross_validation(horizon="365 days")
self.assertIsInstance(cross_validation_df, pd.DataFrame)

self.assertTrue("ds" in cross_validation_df.columns)
self.assertTrue("yhat" in cross_validation_df.columns)
self.assertTrue("yhat_lower" in cross_validation_df.columns)
self.assertTrue("yhat_upper" in cross_validation_df.columns)
self.assertTrue("y" in cross_validation_df.columns)
self.assertTrue("cutoff" in cross_validation_df.columns)

def test_performance_metrics(self):
mp = multi_prophet.Prophet()
mp.fit(self.df)

performance_metrics_df = mp.performance_metrics(horizon="365 days")
self.assertIsInstance(performance_metrics_df, pd.DataFrame)

self.assertTrue("horizon" in performance_metrics_df.columns)
self.assertTrue("mse" in performance_metrics_df.columns)
self.assertTrue("rmse" in performance_metrics_df.columns)
self.assertTrue("mae" in performance_metrics_df.columns)
self.assertTrue("mape" in performance_metrics_df.columns)
self.assertTrue("mdape" in performance_metrics_df.columns)
self.assertTrue("coverage" in performance_metrics_df.columns)

0 comments on commit 7af3dd4

Please sign in to comment.