feature/multivariate - step 4 (#111)
- Adds multivariate functionality to backtesting functions.
- Further improves backtest_regression by introducing new parameters.
- Improves the RNN notebook by using the new backtesting functionality (see the usage sketch below).
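
A minimal usage sketch of the new multivariate backtesting (illustrative only; timestamps, series lengths and TCN parameters mirror the updated tests, and the import paths are assumed from this commit's module layout):

import pandas as pd
from darts.utils.timeseries_generation import linear_timeseries
from darts.models import NaiveDrift, TCNModel
from darts.backtesting import backtest_forecasting

# two-component series built by stacking a linear series on itself
series_multi = linear_timeseries(length=50).stack(linear_timeseries(length=50))

# univariate model backtested on a single component of the multivariate series
pred_uni = backtest_forecasting(series_multi, NaiveDrift(), pd.Timestamp('20000201'), 3,
                                component_index=0)

# multivariate torch model predicting both components
tcn = TCNModel(batch_size=1, n_epochs=1, input_size=2, output_length=3, output_size=2)
pred_multi = backtest_forecasting(series_multi, tcn, pd.Timestamp('20000125'), 3,
                                  target_indices=[0, 1], use_full_output_length=True)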

Co-authored-by: pennfranc <flaessig@student.ethz.ch>
pennfranc committed Jul 7, 2020
1 parent 22702ab commit c6885ca
Showing 7 changed files with 606 additions and 204 deletions.
104 changes: 85 additions & 19 deletions darts/backtesting/backtesting.py
@@ -3,7 +3,7 @@
---------------------
"""

from typing import Iterable, Optional, Callable
from typing import Iterable, Optional, Callable, List
from itertools import product
import math
import time
@@ -13,7 +13,8 @@
import matplotlib.pyplot as plt

from ..timeseries import TimeSeries
from ..models.forecasting_model import ForecastingModel
from ..models.forecasting_model import ForecastingModel, UnivariateForecastingModel
from ..models.torch_forecasting_model import TorchForecastingModel
from ..models.regression_model import RegressionModel
from ..models import NaiveSeasonal, AutoARIMA, ExponentialSmoothing, FFT, Prophet, Theta
from .. import metrics
@@ -28,10 +29,28 @@

# TODO parameterize the moving window

def _create_parameter_dicts(model, target_indices, component_index, use_full_output_length):
fit_kwargs = {}
predict_kwargs = {}
if isinstance(model, UnivariateForecastingModel):
fit_kwargs['component_index'] = component_index
else:
fit_kwargs['target_indices'] = target_indices
if isinstance(model, TorchForecastingModel):
predict_kwargs['use_full_output_length'] = use_full_output_length

return fit_kwargs, predict_kwargs
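# Illustration of the kwargs this helper resolves, assuming NaiveDrift is a
# UnivariateForecastingModel and TCNModel a TorchForecastingModel (as in the tests):
#   _create_parameter_dicts(NaiveDrift(), None, 0, True)
#       -> ({'component_index': 0}, {})
#   _create_parameter_dicts(TCNModel(), [0, 1], None, True)
#       -> ({'target_indices': [0, 1]}, {'use_full_output_length': True})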


def backtest_forecasting(series: TimeSeries,
model: ForecastingModel,
start: pd.Timestamp,
fcast_horizon_n: int,
target_indices: Optional[List[int]] = None,
component_index: Optional[int] = None,
use_full_output_length: bool = True,
stride: int = 1,
retrain: bool = True,
trim_to_series: bool = True,
verbose: bool = False) -> TimeSeries:
""" A function for backtesting `ForecastingModel`'s.
@@ -45,9 +64,14 @@ def backtest_forecasting(series: TimeSeries,
forecast horizon, and then moves the end of the training set forward by one
time step. The resulting predictions are then returned.
This always re-trains the models on the entire available history,
Unless `retrain` is set to False, this always re-trains the models on the entire available history,
corresponding to an expanding window strategy.
If `retrain` is set to False (useful for models with many parameters, such as `TorchForecastingModel` instances),
the model will be trained only on the initial training window (up to the `start` time stamp),
and only if it has not been trained before. Then, at every iteration, the newly expanded 'training sequence'
will be fed to the model to produce the new output.
Parameters
----------
series
@@ -58,6 +82,20 @@ def backtest_forecasting(series: TimeSeries,
The first prediction time, at which a prediction is computed for a future time
fcast_horizon_n
The forecast horizon for the point predictions
target_indices
In case `series` is multivariate and `model` is a subclass of `MultivariateForecastingModel`,
a list of indices of components of `series` to be predicted by `model`.
component_index
In case `series` is multivariate and `model` is a subclass of `UnivariateForecastingModel`,
an integer index of the component of `series` to be predicted by `model`.
use_full_output_length
In case `model` is a subclass of `TorchForecastingModel`, this argument will be passed along
to the `predict` method of `model`.
stride
The number of time steps (the unit being the frequency of `series`) between two consecutive predictions.
retrain
Whether to retrain the model for every prediction or not. Currently, only `TorchForecastingModel`
instances passed as the `model` argument support setting `retrain` to `False`.
trim_to_series
Whether the predicted series has the end trimmed to match the end of the main series
verbose
@@ -70,30 +108,42 @@ def backtest_forecasting(series: TimeSeries,
the specified model with the specified forecast horizon.
"""

series._assert_univariate()
raise_if_not(start in series, 'The provided start timestamp is not in the time series.', logger)
raise_if_not(start != series.end_time(), 'The provided start timestamp is the last timestamp of the time series',
logger)
raise_if_not(fcast_horizon_n > 0, 'The provided forecasting horizon must be a positive integer.', logger)
raise_if_not(retrain or isinstance(model, TorchForecastingModel), "Only 'TorchForecastingModel' instances"
" support the option 'retrain=False'.", logger)

last_pred_time = series.time_index()[-fcast_horizon_n - 1] if trim_to_series else series.time_index()[-1]
last_pred_time = (
series.time_index()[-fcast_horizon_n - stride] if trim_to_series else series.time_index()[-stride - 1]
)

# specify the correct fit and predict keyword arguments for the given model
fit_kwargs, predict_kwargs = _create_parameter_dicts(model, target_indices, component_index, use_full_output_length)

# build the prediction times in advance (to be able to use tqdm)
pred_times = [start]
while pred_times[-1] <= last_pred_time:
pred_times.append(pred_times[-1] + series.freq())
pred_times.append(pred_times[-1] + series.freq() * stride)

# what we'll return
values = []
times = []

iterator = _build_tqdm_iterator(pred_times, verbose)

if ((not retrain) and (not model._fit_called)):
model.fit(series.drop_after(start), verbose=verbose, **fit_kwargs)

for pred_time in iterator:
train = series.drop_after(pred_time) # build the training series
model.fit(train)
pred = model.predict(fcast_horizon_n)
values.append(pred.univariate_values()[-1]) # store the N-th point
if (retrain):
model.fit(train, **fit_kwargs)
pred = model.predict(fcast_horizon_n, **predict_kwargs)
else:
pred = model.predict(fcast_horizon_n, input_series=train, **predict_kwargs)
values.append(pred.values()[-1]) # store the N-th point
times.append(pred.end_time()) # store the N-th timestamp
return TimeSeries.from_times_and_values(pd.DatetimeIndex(times), np.array(values))
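
A minimal sketch of the `retrain` and `stride` options documented above (hypothetical model configuration; assumes the default TCN input length fits within the initial training window and that the import paths match this commit's layout):

import pandas as pd
from darts.utils.timeseries_generation import linear_timeseries
from darts.models import TCNModel
from darts.backtesting import backtest_forecasting

series = linear_timeseries(length=50)
tcn = TCNModel(batch_size=1, n_epochs=1)

# fit once on the data before `start`, then only slide the input window, predicting every 2nd step
pred = backtest_forecasting(series, tcn, start=pd.Timestamp('20000125'), fcast_horizon_n=3,
                            stride=2, retrain=False)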

@@ -144,7 +194,6 @@ def backtest_regression(feature_series: Iterable[TimeSeries],
the specified model with the specified forecast horizon.
"""

raise_if_not(target_series.width == 1, "'target_series' must be univariate.", logger)
raise_if_not(all([s.has_same_time_as(target_series) for s in feature_series]), 'All provided time series must '
'have the same time index', logger)
raise_if_not(start in target_series, 'The provided start timestamp is not in the time series.', logger)
@@ -174,7 +223,7 @@ def backtest_regression(feature_series: Iterable[TimeSeries],

model.fit(train_features, train_target)
pred = model.predict(val_features)
values.append(pred.univariate_values()[-1]) # store the N-th point
values.append(pred.values()[-1]) # store the N-th point
times.append(pred.end_time()) # store the N-th timestamp

return TimeSeries.from_times_and_values(pd.DatetimeIndex(times), np.array(values))
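
With the univariate restriction on `target_series` dropped, a multivariate target can be backtested directly; a minimal sketch (hypothetical feature and target series, import paths assumed):

import pandas as pd
from darts.utils.timeseries_generation import gaussian_timeseries, sine_timeseries
from darts.models import StandardRegressionModel
from darts.backtesting import backtest_regression

features = [gaussian_timeseries(length=50), sine_timeseries(length=50)]
target = sine_timeseries(length=50)

# regress a two-component target (the sine series stacked on itself) on the feature series
pred = backtest_regression(features, target.stack(target), StandardRegressionModel(15),
                           pd.Timestamp('20000201'), 3)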
@@ -183,7 +232,7 @@ def backtest_regression(feature_series: Iterable[TimeSeries],
def forecasting_residuals(model: ForecastingModel,
series: TimeSeries,
fcast_horizon_n: int = 1,
verbose: bool = True) -> TimeSeries:
verbose: bool = False) -> TimeSeries:
""" A function for computing the residuals produced by a given model and univariate time series.
This function computes the difference between the actual observations from `series`
@@ -285,6 +334,9 @@ def backtest_gridsearch(model_class: type,
parameters: dict,
train_series: TimeSeries,
fcast_horizon_n: Optional[int] = None,
target_indices: Optional[List[int]] = None,
component_index: Optional[int] = None,
use_full_output_length: bool = True,
val_series: Optional[TimeSeries] = None,
num_predictions: int = 10,
metric: Callable[[TimeSeries, TimeSeries], float] = metrics.mape,
@@ -324,6 +376,15 @@
The univariate TimeSeries instance used for validation in split mode.
fcast_horizon_n
The integer value of the forecasting horizon used in expanding window mode.
target_indices
In case `series` is multivariate and `model` is a subclass of `MultivariateForecastingModel`,
a list of indices of components of `series` to be predicted by `model`.
component_index
In case `series` is multivariate and `model` is a subclass of `UnivariateForecastingModel`,
an integer index of the component of `series` to be predicted by `model`.
use_full_output_length
In case `model` is a subclass of `TorchForecastingModel`, this argument will be passed along
to the `predict` method of `model`.
num_predictions:
The number of train/prediction cycles performed in one iteration of expanding window mode.
metric:
@@ -337,13 +398,16 @@
An untrained 'model_class' instance with the best-performing hyperparameters from the given selection.
"""

train_series._assert_univariate()
if (val_series is not None):
val_series._assert_univariate()
raise_if_not(train_series.width == val_series.width, "Training and validation series require the same"
" number of components.", logger)

raise_if_not((fcast_horizon_n is None) ^ (val_series is None),
"Please pass exactly one of the arguments 'forecast_horizon_n' or 'val_series'.", logger)

fit_kwargs, predict_kwargs = _create_parameter_dicts(model_class(), target_indices, component_index,
use_full_output_length)

if val_series is None:
backtest_start_time = train_series.end_time() - (num_predictions + fcast_horizon_n) * train_series.freq()
min_error = float('inf')
@@ -358,11 +422,12 @@
param_combination_dict = dict(list(zip(parameters.keys(), param_combination)))
model = model_class(**param_combination_dict)
if val_series is None: # expanding window mode
backtest_forecast = backtest_forecasting(train_series, model, backtest_start_time, fcast_horizon_n)
backtest_forecast = backtest_forecasting(train_series, model, backtest_start_time, fcast_horizon_n,
target_indices, component_index, use_full_output_length)
error = metric(backtest_forecast, train_series)
else: # split mode
model.fit(train_series)
error = metric(model.predict(len(val_series)), val_series)
model.fit(train_series, **fit_kwargs)
error = metric(model.predict(len(val_series), **predict_kwargs), val_series)
if error < min_error:
min_error = error
best_param_combination = param_combination_dict
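
A sketch of the new arguments in expanding-window grid search, mirroring the multivariate grid exercised in the tests below (hypothetical series; import paths assumed):

from darts import metrics
from darts.utils.timeseries_generation import sine_timeseries, linear_timeseries
from darts.models import TCNModel
from darts.backtesting import backtest_gridsearch

series_multi = sine_timeseries(length=40, value_y_offset=10).stack(linear_timeseries(length=40, end_value=20))
tcn_params = {'n_epochs': [1], 'batch_size': [1], 'input_size': [2],
              'output_length': [3], 'output_size': [2], 'kernel_size': [2, 3, 4]}

best_tcn = backtest_gridsearch(TCNModel, tcn_params, series_multi, fcast_horizon_n=3,
                               metric=metrics.mape, target_indices=[0, 1],
                               use_full_output_length=True)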
@@ -376,8 +441,9 @@ def explore_models(train_series: TimeSeries,
metric: Callable[[TimeSeries, TimeSeries], float] = metrics.mape,
model_parameter_tuples: Optional[list] = None,
plot_width: int = 3,
verbose: bool = True):
""" A function for exploring the suitability of multiple models on a given train/validation/test split.
verbose: bool = False):
""" A function for exploring the suitability of multiple models on a given train/validation/test split
of a univariate series.
This function iterates through a list of models, training each on `train_series` and `val_series`
and evaluating them on `test_series`. Models with free hyperparameters are first
19 changes: 17 additions & 2 deletions darts/models/torch_forecasting_model.py
@@ -329,7 +332,10 @@ def fit(self,
tb_writer.flush()
tb_writer.close()

def predict(self, n: int, use_full_output_length: bool = False) -> TimeSeries:
def predict(self, n: int,
use_full_output_length: bool = False,
input_series: Optional[TimeSeries] = None) -> TimeSeries:

""" Predicts values for a certain number of time steps after the end of the training series
In the case of univariate training series, `n` can assume any integer value greater than 0.
@@ -353,6 +356,9 @@ def predict(self, n: int, use_full_output_length: bool = False) -> TimeSeries:
use_full_output_length
Boolean value indicating whether or not the full output sequence of the model prediction should be
used to produce the output of this function.
input_series
Optionally, the input TimeSeries instance fed to the trained TorchForecastingModel to produce the
prediction. If it is not passed, the training TimeSeries instance will be used as input.
Returns
-------
Expand All @@ -361,12 +367,21 @@ def predict(self, n: int, use_full_output_length: bool = False) -> TimeSeries:
"""
super().predict(n)

raise_if_not(input_series is None or input_series.width == self.training_series.width,
"'input_series' must have same width as series used to fit model.", logger)

raise_if_not(use_full_output_length or self.training_series.width == 1, "Please set 'use_full_output_length'"
" to 'True' and 'n' smaller or equal to 'output_length' when using a multivariate"
" TimeSeries instance as input.", logger)

# create input sequence for prediction
input_sequence = self.training_series.values()[-self.input_length:]
if input_series is None:
input_sequence = self.training_series.values()[-self.input_length:]
else:
raise_if_not(len(input_series) >= self.input_length,
"'input_series' must at least be as long as 'self.input_length'", logger)
input_sequence = input_series.values()[-self.input_length:]
super().fit(input_series, self.target_indices)
pred_in = torch.from_numpy(input_sequence).float().view(1, self.input_length, -1).to(self.device)

# iterate through predicting output and consuming it again until enough predictions are created
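A brief sketch of the new `input_series` argument (hypothetical univariate example; the generated series' default start date and daily frequency are assumed):

import pandas as pd
from darts.utils.timeseries_generation import sine_timeseries
from darts.models import TCNModel

series = sine_timeseries(length=100)
train = series.drop_after(pd.Timestamp('20000301'))

model = TCNModel(batch_size=1, n_epochs=1)
model.fit(train)

# forecast from a window that extends beyond the training series, rather than from the training series itself
pred = model.predict(5, input_series=series.drop_after(pd.Timestamp('20000331')))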
71 changes: 63 additions & 8 deletions darts/tests/test_backtesting.py
@@ -18,7 +18,15 @@
constant_timeseries as ct,
gaussian_timeseries as gt
)
from ..models import Theta, FFT, ExponentialSmoothing, NaiveSeasonal, StandardRegressionModel, NaiveDrift
from ..models import (
Theta,
FFT,
ExponentialSmoothing,
NaiveSeasonal,
StandardRegressionModel,
NaiveDrift,
TCNModel
)


def compare_best_against_random(model_class, params, series):
@@ -62,14 +70,48 @@ def setUpClass(cls):

def test_backtest_forecasting(self):
linear_series = lt(length=50)
linear_series_multi = linear_series.stack(linear_series)

# univariate series
# univariate model + univariate series
pred = backtest_forecasting(linear_series, NaiveDrift(), pd.Timestamp('20000201'), 3)
self.assertEqual(r2_score(pred, linear_series), 1.0)

# multivariate series
with self.assertRaises(AssertionError):
backtest_forecasting(linear_series.stack(linear_series), NaiveDrift(), pd.Timestamp('20000201'), 3)
# univariate model + multivariate series without component index argument
with self.assertRaises(ValueError):
backtest_forecasting(linear_series_multi, NaiveDrift(), pd.Timestamp('20000201'), 3)

# univariate model + multivariate series with component index argument
pred = backtest_forecasting(linear_series_multi, NaiveDrift(), pd.Timestamp('20000201'), 3,
component_index=0, verbose=False)
self.assertEqual(pred.width, 1)
self.assertEqual(r2_score(pred, linear_series), 1.0)
pred = backtest_forecasting(linear_series_multi, NaiveDrift(), pd.Timestamp('20000201'), 3,
component_index=1, verbose=False)
self.assertEqual(pred.width, 1)
self.assertEqual(r2_score(pred, linear_series), 1.0)

# multivariate model + univariate series
tcn_model = TCNModel(batch_size=1, n_epochs=1)
pred = backtest_forecasting(linear_series, tcn_model, pd.Timestamp('20000125'), 3, verbose=False)
self.assertEqual(pred.width, 1)

# multivariate model + multivariate series
with self.assertRaises(ValueError):
backtest_forecasting(linear_series_multi, tcn_model, pd.Timestamp('20000125'), 3, verbose=False)
tcn_model = TCNModel(batch_size=1, n_epochs=1, input_size=2, output_length=3)
with self.assertRaises(ValueError):
backtest_forecasting(linear_series_multi, tcn_model, pd.Timestamp('20000125'), 3, verbose=False,
use_full_output_length=False)
pred = backtest_forecasting(linear_series_multi, tcn_model, pd.Timestamp('20000125'), 1, target_indices=[0],
verbose=False)
self.assertEqual(pred.width, 1)
pred = backtest_forecasting(linear_series_multi, tcn_model, pd.Timestamp('20000125'), 3, verbose=False,
use_full_output_length=True, target_indices=[1])
self.assertEqual(pred.width, 1)
tcn_model = TCNModel(batch_size=1, n_epochs=1, input_size=2, output_length=3, output_size=2)
pred = backtest_forecasting(linear_series_multi, tcn_model, pd.Timestamp('20000125'), 3, verbose=False,
use_full_output_length=True, target_indices=[0, 1])
self.assertEqual(pred.width, 2)

def test_backtest_regression(self):
gaussian_series = gt(mean=2, length=50)
@@ -88,9 +130,9 @@ def test_backtest_regression(self):
self.assertEqual(r2_score(pred, target), 1.0)

# multivariate target
with self.assertRaises(ValueError):
backtest_regression(features_multivariate, target.stack(target),
StandardRegressionModel(15), pd.Timestamp('20000201'), 3)
pred = backtest_regression(features_multivariate, target.stack(target),
StandardRegressionModel(15), pd.Timestamp('20000201'), 3)
self.assertEqual(r2_score(pred, target.stack(target)), 1.0)

def test_backtest_gridsearch(self):

@@ -109,6 +151,19 @@ def test_backtest_gridsearch(self):
es_params = {'seasonal_periods': list(range(5, 10))}
self.assertTrue(compare_best_against_random(ExponentialSmoothing, es_params, dummy_series))

def test_backtest_gridsearch_multi(self):
dummy_series = st(length=40, value_y_offset=10).stack(lt(length=40, end_value=20))
tcn_params = {
'n_epochs': [1],
'batch_size': [1],
'input_size': [2],
'output_length': [3],
'output_size': [2],
'kernel_size': [2, 3, 4]
}
backtest_gridsearch(TCNModel, tcn_params, dummy_series, fcast_horizon_n=3, metric=mape,
use_full_output_length=True, target_indices=[0, 1])

def test_forecasting_residuals(self):
model = NaiveSeasonal(K=1)

