Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/backtest multiple series #1517

Merged
merged 2 commits into from
Jan 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 2 additions & 3 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,11 +1052,10 @@ def backtest(
errors = errors[0]
backtest_list.append(errors)
else:

errors = [
[metric_f(series, f) for metric_f in metric]
[metric_f(target_ts, f) for metric_f in metric]
if len(metric) > 1
else metric[0](series, f)
else metric[0](target_ts, f)
for f in forecasts[idx]
]

Expand Down
29 changes: 28 additions & 1 deletion darts/tests/models/forecasting/test_backtesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,17 @@
import pytest

from darts import TimeSeries
from darts.datasets import AirPassengersDataset, MonthlyMilkDataset
from darts.logging import get_logger
from darts.metrics import mape, r2_score
from darts.models import ARIMA, FFT, ExponentialSmoothing, NaiveDrift, Theta
from darts.models import (
ARIMA,
FFT,
ExponentialSmoothing,
NaiveDrift,
NaiveSeasonal,
Theta,
)
from darts.tests.base_test_class import DartsBaseTestClass
from darts.utils.timeseries_generation import gaussian_timeseries as gt
from darts.utils.timeseries_generation import linear_timeseries as lt
Expand Down Expand Up @@ -257,6 +265,25 @@ def test_backtest_forecasting(self):
self.assertEqual(pred.width, 2)
self.assertEqual(pred.end_time(), linear_series.end_time())

def test_backtest_multiple_series(self):
series = [AirPassengersDataset().load(), MonthlyMilkDataset().load()]
model = NaiveSeasonal(K=1)

error = model.backtest(
series,
train_length=30,
forecast_horizon=2,
stride=1,
retrain=True,
last_points_only=False,
verbose=False,
)

expected = [11.63104, 6.09458]
self.assertEqual(len(error), 2)
self.assertAlmostEqual(error[0], expected[0], places=4)
self.assertAlmostEqual(error[1], expected[1], places=4)

@unittest.skipUnless(TORCH_AVAILABLE, "requires torch")
def test_backtest_regression(self):
np.random.seed(4)
Expand Down