Skip to content

Commit

Permalink
Integrate performance metrics to multi prophet model
Browse files Browse the repository at this point in the history
  • Loading branch information
vonum committed May 21, 2021
1 parent b2299fd commit de7ae9c
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
6 changes: 6 additions & 0 deletions multi_prophet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ def cross_validation(self, horizon, **kwargs):
for column, model in self.model_pool.items()
}

def performance_metrics(self, horizon, **kwargs):
return {
column: model.performance_metrics(horizon=horizon, **kwargs)
for column, model in self.model_pool.items()
}

def _init_model_pool(self, columns, **kwargs):
return {c: Prophet(**kwargs) for c in columns}

Expand Down
14 changes: 14 additions & 0 deletions tests/test_multi_prophet.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,17 @@ def test_cross_validation(self):
self.assertTrue("yhat_upper" in cross_validation_df.columns)
self.assertTrue("y" in cross_validation_df.columns)
self.assertTrue("cutoff" in cross_validation_df.columns)

def test_performance_metrics(self):
mp = multi_prophet.MultiProphet(columns=PREDICTOR_COLUMNS)
mp.fit(self.df)

performance_metrics_dfs = mp.performance_metrics(horizon="365 days")
for c, performance_metrics_df in performance_metrics_dfs.items():
self.assertTrue("horizon" in performance_metrics_df.columns)
self.assertTrue("mse" in performance_metrics_df.columns)
self.assertTrue("rmse" in performance_metrics_df.columns)
self.assertTrue("mae" in performance_metrics_df.columns)
self.assertTrue("mape" in performance_metrics_df.columns)
self.assertTrue("mdape" in performance_metrics_df.columns)
self.assertTrue("coverage" in performance_metrics_df.columns)

0 comments on commit de7ae9c

Please sign in to comment.