feature/multivariate - step 4 #111

Merged
merged 185 commits into from
Jul 7, 2020
Changes from 184 commits
a1ea75b
feature(tcn): started with implementation
May 11, 2020
7ccc04e
feature(tcn): merged with new develop branch, fixed conflicts
May 11, 2020
ec0c496
feature(torch forecasting): created TorchForecastingModel class and m…
May 11, 2020
65f3ceb
feature(torch forecasting): updated docstrings
May 11, 2020
d046bdd
feature(torch forecasting): polished arguments, fixed test case
May 11, 2020
13b9afe
feature(tcn): changes to torch module
May 12, 2020
f66e7be
merged with current develop
May 14, 2020
b4a3322
feature(tcn): implemented residual block architecture, variable outpu…
May 14, 2020
4116ef3
feature(tcn): added weight normalization option
May 14, 2020
fee47e8
feature(tcn): updated docstrings
May 15, 2020
77568b6
feature(examples): updated FFT notebook
May 15, 2020
9aa93ff
merged with develop
May 15, 2020
d397e5b
feature(examples): updated FFT notebook
May 15, 2020
500a38c
feature(torch forecasting): added docstring for _TimeSeriesDataset1DS…
May 15, 2020
5fbb44c
fix(formatting): solved linter issues
May 15, 2020
e91a508
fix(testing): fixed import issue
May 15, 2020
844017c
fix(formatting): solved more linter issues
May 15, 2020
03560e2
feature(torch forecasting): updated documentation
May 15, 2020
80cc905
feature(torch forecasting): improved docstrings
May 15, 2020
b71571f
fix(formatting): solved sphinx issue
May 15, 2020
f2838b2
feature(examples): updated RNN example
May 15, 2020
6580bd2
fix(tcn): fixed variable names
May 25, 2020
0bcb174
fix(tcn): fixed variable names
May 25, 2020
27c599e
feature(testing): added TCN tests
May 26, 2020
4e1cdd2
fix(tcn): removed if statement parentheses
May 26, 2020
b7a5797
fix(tcn): fixed edge case when num_layers=1, removed last ReLU
May 26, 2020
2db6474
merged with develop
May 26, 2020
776a439
fix(testing): fixed TCN fit test
May 26, 2020
1dfdea8
fix(formatting): solved linter issues
May 26, 2020
8da1df5
fix(tcn_model): fixed automatic number of layer computation
May 28, 2020
7648a51
feature(TCN): added receptive field coverage testing
May 28, 2020
5d5f1c8
fix(TCN): minor formatting changes
May 28, 2020
8212bba
fix(TCN): fixed typo
May 28, 2020
336758c
fix(tcn): fixed output_length being ignored
May 29, 2020
2546145
feature(TorchForecastingModel): added first_prediction_index property
May 29, 2020
7247562
fix(formatting): solved linter issues
May 29, 2020
2ccf35f
feature(TorchForecastingModel): added new mode to predict that uses a…
May 29, 2020
81b6018
feature(backtesting): added no-retrain option and real_time_plot func…
May 29, 2020
224df58
feature(backtesting) added docstring descriptions of new backtest_for…
May 29, 2020
1681591
merged with develop
Jun 2, 2020
7ad1d01
feature(backtesting): plotting only relevant part in real_time_plot, …
Jun 3, 2020
67ead17
feature(univariate): started rewriting the TimeSeries class to be bas…
Jun 3, 2020
e71b1b8
feature(multivariate): first working version of multivariate TimeSeri…
Jun 4, 2020
a4937cf
feature(multivariate): first working version of multivariate TimeSeri…
Jun 4, 2020
2a2603d
feature(multivariate): adapted all existing code to use univariate in…
Jun 4, 2020
25f0545
fix(fft): fixed condition when cropping training set
Jun 4, 2020
a985cb6
fix(TimeSeries): fixed plot method
Jun 4, 2020
c01bea7
feature(multivariate): adapted TimeSeries.update and corresponding un…
Jun 5, 2020
ef00405
fix(formatting): fixed linter issues
Jun 5, 2020
f97a3f2
fix(ForecastingModel): fixed typo
Jun 5, 2020
04aa1c5
feature(formatting): fixed final linter problem
Jun 5, 2020
a26eac1
fix(TimeSeries): fixed a couple of docstrings
Jun 5, 2020
b3445a4
feature(tests): added separate multivariate test file for TimeSeries …
Jun 5, 2020
2152306
feature(multivariate): added TimeSeries.stack
Jun 5, 2020
df81f12
feature(multivariate): added TimeSeries.width property and TimeSeries…
Jun 5, 2020
f3006c0
fix(formatting): solved linter issues
Jun 5, 2020
5713095
feature(TimeSeries): added TimeSeries.add_datetime_attribute function
Jun 8, 2020
6e6fcc0
feature(timeseries_generation): updated docstrings
Jun 8, 2020
b3ff9c5
feature(TimeSeries): added TimeSeries.add_holidays by readapting time…
Jun 8, 2020
fa2674d
fix(formatting): solved linter issues
Jun 8, 2020
1677001
feature(multivariate): updated TimeSeries.from_dataframe to accept a …
Jun 8, 2020
342610a
fix(TimeSeries): fixed docstring
Jun 8, 2020
9e71d84
feature(multivariate): adapted tests to be reused in multivariate con…
Jun 8, 2020
fe98f40
fix(TimeSeries): fixed TimeSeries.update
Jun 8, 2020
9b2eef5
fix(TimeSeries): fixed variable name typo
Jun 8, 2020
7ac40fa
fix(formatting): solved linter issues
Jun 8, 2020
d7ac73c
fix(formatting): solved linter issues
Jun 8, 2020
7d1c8e1
feature(multivariate): added multivariate support to torch models, ad…
Jun 8, 2020
0fe7a39
feature(multivariate): added multivariate input support for RNN
Jun 9, 2020
54dc266
feature(multivariate): added multivariate support to _TimeSeriesDatas…
Jun 9, 2020
a61bfc9
feature(timeseries_generation): moved main add_datetime_attribute and…
Jun 9, 2020
6981403
fix(timeseries_generation): fixed mistakes
Jun 9, 2020
1ae7dc8
fix(formatting): solved linter issues
Jun 9, 2020
d3043ff
merged with multivariate-2
Jun 9, 2020
a8b6d7a
feature(TorchForecastingModel): implemented 'datetime_enhancement' an…
Jun 9, 2020
130a8c0
fix(formatting): solved linter issues
Jun 9, 2020
86f7948
feature(multivariate): requiring dataframe in main TimeSeries constru…
Jun 10, 2020
49dd263
fix(formatting): solved linter issues
Jun 10, 2020
517a987
feature(multivariate): added automatic input_size adjustment when pas…
Jun 10, 2020
cbf5e3b
feature(multivariate): added multivariate support to predict in use_f…
Jun 10, 2020
78378d0
feature(examples): added simple example showcasing datetime enhanceme…
Jun 10, 2020
ac1ab6e
feature(logging): added raise_if function
Jun 11, 2020
738a316
feature(multivariate): changed attribute name _series to _df
Jun 11, 2020
9fbfe11
feature(multivariate): made time series naming more consistent, remov…
Jun 11, 2020
894b8bb
fix(testing): adapted test to new logging message
Jun 11, 2020
be63e27
merged with feature/multivariate (multivariate-1)
Jun 11, 2020
e2fc0db
feature(testing): added test case for TimeSeries._assert_multivariate()
Jun 11, 2020
b75d80f
feature(multivariate): added TimeSeries.first_values() and TimeSeries…
Jun 11, 2020
0fb18f0
fix(formatting): fixed linter issue
Jun 11, 2020
86aaa3b
feature(multivariate): added width check to TimeSeries.plot()
Jun 11, 2020
cf9e2d0
merged with multivariate-2
Jun 11, 2020
4a9876f
feature(multivariate): renamed notebook
Jun 12, 2020
422f1a8
feature(multivariate): adapted scaler wrapper to multivariate
Jun 12, 2020
be78170
fix(timeseries_generation): fixed one-hot encoding of datetime_attrib…
Jun 12, 2020
756aebb
feature(examples): created multivariate functionality notebook
Jun 12, 2020
fe771b4
merged with multivariate-2
Jun 12, 2020
bf2f75b
feature(multivariate): added multivariate support for regression feat…
Jun 15, 2020
14385f0
feature(multivariate): added multivariate target support to StandardR…
Jun 15, 2020
a92ec7a
fix(formatting): solved linter issues
Jun 15, 2020
f559426
feature(multivariate): introduced univariate and multivariate subclas…
Jun 15, 2020
29952c9
fix(formatting): solved linter issue
Jun 15, 2020
72ed3eb
merged with multivariate-2
Jun 15, 2020
7d9cbb7
feature(multivariate): removed 'enhancement' arguments for torch mode…
Jun 15, 2020
41fada9
fix(TorchForecastingModel): fixed bug
Jun 15, 2020
e538d92
feature(multivariate): added metrics support for multivariate time se…
Jun 15, 2020
6668bc4
feature(metrics): one more tweak for multivariate support, more tests
Jun 15, 2020
0eaffbf
fix(testing): fixed metrics test case
Jun 15, 2020
663b7d7
fix(formatting): fixed linter issue
Jun 15, 2020
67d209c
merged with multivariate-2
Jun 15, 2020
88f1ef8
feature(multivariate): added multivariate target sequence support to rnn
Jun 15, 2020
59d20dc
fix(formatting): fixed linter
Jun 15, 2020
e658ab2
feature(TorchForecastingModel): improved value checks
Jun 16, 2020
6c45ad2
feature(testing): added multivariate tests for TCN and RNN
Jun 16, 2020
cdb9cbf
feature(multivariate): added univariate checks to backtesting
Jun 16, 2020
e8e2232
feature(testing): added more backtesting tests
Jun 16, 2020
d810d9d
fix(TorchForecastingModel): fixed predict function in use_full_output…
Jun 16, 2020
8d1b638
feature(multivariate): added short example notebook
Jun 16, 2020
033eea4
feature(docs): added multivariate-example to examples.rst
Jun 16, 2020
ac47b52
fix(formatting): solved linter issues
Jun 16, 2020
15a0293
merged with develop
Jun 16, 2020
74f3957
merged with feature/multivariate
Jun 16, 2020
7e42775
merged with multivariate-2
Jun 16, 2020
9d14112
feature(metrics): added width check for every metric function
Jun 17, 2020
b8d5f28
fix(examples): updated imports of multivariate examples
Jun 17, 2020
8c9cbf0
merged with develop
Jun 17, 2020
87dfb7c
merged with feature/multivariate
Jun 17, 2020
a547ff4
merged with multivariate-2
Jun 17, 2020
667285a
feature(tests): added test file and test case for utils.py
Jun 18, 2020
0ac4103
feature(multivariate): added checks that input is univariate in stati…
Jun 18, 2020
2511083
fix(formatting): linter
Jun 18, 2020
0997bdf
feature(backtesting): added more backtesting options, including multi…
Jun 18, 2020
5e6299b
feature(examples): started changing TCN example
Jun 18, 2020
78f965d
merged with develop
Jun 18, 2020
4bc025c
merged with feature/multivariate
Jun 18, 2020
4fac6a4
merged with multivariate-2
Jun 18, 2020
a97d7fa
fix(TCN): forgot to include a change
Jun 18, 2020
effdecd
fix(examples): fixed notebooks to work with new multivariate function…
Jun 18, 2020
8198b56
merged with multivariate-3
Jun 19, 2020
cbcebe4
Merge branch 'develop' into feature/multivariate
pennfranc Jun 22, 2020
b9a5bcb
feature(multivariate): improved backtesting, added more tests
Jun 22, 2020
1e5c4d8
merged with develop
Jun 22, 2020
7eb54e5
merged with feature/multivariate
Jun 22, 2020
904e404
merged with multivariate-2
Jun 22, 2020
148566c
merged with multivariate-3
Jun 22, 2020
749b3d2
feature(multivariate): added more backtesting tests
Jun 22, 2020
648aeeb
merged with develop
Jun 22, 2020
58b688a
merged with feature/multivariate
Jun 22, 2020
9cbae0c
merged with multivariate-2
Jun 22, 2020
698eb0c
fix(multivariate): forgot to pass on frequency to new time series obt…
Jun 22, 2020
9c238f6
merged with multivariate-
Jun 22, 2020
c0c3955
feature(multivariate): for MultivariateForecastingModel.fit added che…
Jun 22, 2020
d05d94b
fix(formatting): fixed linter issue
Jun 22, 2020
fb9f944
feature(multivariate): added multivariate support for missing_values …
Jun 22, 2020
917408d
merged with multivariate-3
Jun 22, 2020
8c80ad8
feature(backtesting): updated docstring of backtest_forecasting
Jun 22, 2020
05c9abc
feature(multivariate): added more tests for backtest_forecasting
Jun 22, 2020
74d3081
fix(formatting): linter
Jun 22, 2020
7da6efc
feature(multivariate): added multivariate support to backtest_regression
Jun 22, 2020
4a1d9d4
feature(multivariate): added multivariate support to backtest_gridsearch
Jun 22, 2020
199b50b
fix(formatting): linter
Jun 22, 2020
6dacd70
merged with develop
Jun 23, 2020
2cb6d53
merged with feature/multivariate
Jun 23, 2020
4f54a8d
merged with multivariate-2
Jun 23, 2020
ec6f01c
merged with multivariate-3
Jun 23, 2020
a3f1025
merged with develop
Jun 25, 2020
03a5157
fix(TorchForecastingModel, TCNModel, backtesting): fixed some issues …
Jun 25, 2020
57029af
fix(backtesting): removed mutable default arguments
Jun 25, 2020
9d629af
fix(TorchForecastingModel): fixed predict function
Jun 25, 2020
1722873
feature(examples): updated TCN notebook
Jun 25, 2020
81ab93f
fix(backtesting): fixed indexing error relating to stride
Jun 25, 2020
57bcc6e
fix(TimeSeries): changed constructor not to throw an error when given…
Jun 25, 2020
8abbf6c
feature(examples): updated TCN notebook
Jun 25, 2020
3e5931b
fix(TimeSeries): replaced deprecated warn method
Jun 25, 2020
f1d6e54
fix(tests): fixed TimeSeries test to reflect new frequency mismatch h…
Jun 25, 2020
bab99e0
fix(tests): fixed backtesting test case
Jun 25, 2020
9ecbe30
feature(examples): updated RNN notebook
Jun 25, 2020
044840b
fix(formatting): linter
Jun 25, 2020
d4ea330
fix(logging, TorchForecastingModel): fixed merging mistakes
Jun 26, 2020
5143028
feature(backtesting): improved docstrings
Jun 26, 2020
b2d99a2
merged with develop
Jun 26, 2020
332415d
fix(backtesting): added backticks to docstring
Jul 3, 2020
c6141ef
fix(backtesting): moved argument check up in backtest_forecasting
Jul 3, 2020
6274acc
fix(examples): reverted FFT examples execution counts
Jul 3, 2020
5804e41
fix(examples): fixed missing FFT notebook change
Jul 3, 2020
22526ad
merged with develop
Jul 7, 2020
104 changes: 85 additions & 19 deletions darts/backtesting/backtesting.py
@@ -3,7 +3,7 @@
---------------------
"""

from typing import Iterable, Optional, Callable
from typing import Iterable, Optional, Callable, List
from itertools import product
import math
import time
@@ -13,7 +13,8 @@
import matplotlib.pyplot as plt

from ..timeseries import TimeSeries
from ..models.forecasting_model import ForecastingModel
from ..models.forecasting_model import ForecastingModel, UnivariateForecastingModel
from ..models.torch_forecasting_model import TorchForecastingModel
from ..models.regression_model import RegressionModel
from ..models import NaiveSeasonal, AutoARIMA, ExponentialSmoothing, FFT, Prophet, Theta
from .. import metrics
@@ -28,10 +29,28 @@

# TODO parameterize the moving window

def _create_parameter_dicts(model, target_indices, component_index, use_full_output_length):
Contributor

I believe this is what @guillaumeraille was asking on the daily in the morning.

@pennfranc this looks good. Do you think we can also make it automatic somehow?

    fit_kwargs = {}
    predict_kwargs = {}
    if isinstance(model, UnivariateForecastingModel):
        fit_kwargs['component_index'] = component_index
    else:
        fit_kwargs['target_indices'] = target_indices
    if isinstance(model, TorchForecastingModel):
        predict_kwargs['use_full_output_length'] = use_full_output_length

    return fit_kwargs, predict_kwargs
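
One possible way to make this keyword selection automatic, as asked in the review comment above, is to inspect the signatures of `fit` and `predict` instead of checking the model's class. The following is only a sketch under stated assumptions: `StubUnivariateModel`, `StubTorchModel`, and `select_supported_kwargs` are hypothetical names that do not exist in the library, and the stub signatures merely mirror the keyword arguments used in this diff.

```python
import inspect


class StubUnivariateModel:
    """Hypothetical stand-in for a univariate model (not a library class)."""

    def fit(self, series, component_index=None):
        pass

    def predict(self, n):
        pass


class StubTorchModel:
    """Hypothetical stand-in for a torch-based multivariate model."""

    def fit(self, series, target_indices=None, verbose=False):
        pass

    def predict(self, n, use_full_output_length=False, input_series=None):
        pass


def select_supported_kwargs(method, candidates):
    """Keep only the candidate keyword arguments that `method` declares."""
    accepted = inspect.signature(method).parameters
    return {name: value for name, value in candidates.items() if name in accepted}


def create_parameter_dicts(model, target_indices, component_index, use_full_output_length):
    """Build fit/predict kwargs from the method signatures rather than the class."""
    fit_candidates = {'target_indices': target_indices, 'component_index': component_index}
    predict_candidates = {'use_full_output_length': use_full_output_length}
    return (select_supported_kwargs(model.fit, fit_candidates),
            select_supported_kwargs(model.predict, predict_candidates))
```

A method that accepts `**kwargs` would need extra handling, because its signature does not list the individual names it supports.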


def backtest_forecasting(series: TimeSeries,
model: ForecastingModel,
start: pd.Timestamp,
fcast_horizon_n: int,
target_indices: Optional[List[int]] = None,
Contributor

Quite a lot of arguments are passed to this function. As an optional improvement, consider moving them to the class level and wrapping the whole backtesting module in a class.

component_index: Optional[int] = None,
use_full_output_length: bool = True,
stride: int = 1,
retrain: bool = True,
trim_to_series: bool = True,
verbose: bool = False) -> TimeSeries:
""" A function for backtesting `ForecastingModel`'s.
@@ -45,9 +64,14 @@ def backtest_forecasting(series: TimeSeries,
forecast horizon, and then moves the end of the training set forward by one
time step. The resulting predictions are then returned.

This always re-trains the models on the entire available history,
Unless `retrain` is set to False, this always re-trains the models on the entire available history,
corresponding to an expanding window strategy.

If `retrain` is set to False (useful for models with many parameters such as `TorchForecastingModel` instances),
the model will be trained only on the initial training window (up to `start` time stamp),
and only if it has not been trained before. Then, at every iteration, the newly expanded 'training sequence'
will be fed to the model to produce the new output.

Parameters
----------
series
@@ -58,6 +82,20 @@ def backtest_forecasting(series: TimeSeries,
The first prediction time, at which a prediction is computed for a future time
fcast_horizon_n
The forecast horizon for the point predictions
target_indices
In case `series` is multivariate and `model` is a subclass of `MultivariateForecastingModel`,
a list of indices of components of `series` to be predicted by `model`.
component_index
In case `series` is multivariate and `model` is a subclass of `UnivariateForecastingModel`,
an integer index of the component of `series` to be predicted by `model`.
use_full_output_length
In case `model` is a subclass of `TorchForecastingModel`, this argument will be passed along
as argument to the predict method of `model`.
stride
Contributor

Very nice. It would also be nice to have the option of configuring the training set length. When set, this would use a moving window; when not set, it would use an expanding window (this can wait for a future PR, though :)

Contributor Author

Yes for sure, @guillaumeraille also suggested this!

The number of time steps (the unit being the frequency of `series`) between two consecutive predictions.
retrain
Whether to retrain the model for every prediction or not. Currently only `TorchForecastingModel`
instances as `model` argument support setting `retrain` to `False`.
trim_to_series
Whether the predicted series has the end trimmed to match the end of the main series
verbose
@@ -70,30 +108,42 @@ def backtest_forecasting(series: TimeSeries,
the specified model with the specified forecast horizon.
"""

series._assert_univariate()
raise_if_not(start in series, 'The provided start timestamp is not in the time series.', logger)
raise_if_not(start != series.end_time(), 'The provided start timestamp is the last timestamp of the time series',
logger)
raise_if_not(fcast_horizon_n > 0, 'The provided forecasting horizon must be a positive integer.', logger)
raise_if_not(retrain or isinstance(model, TorchForecastingModel), "Only 'TorchForecastingModel' instances"
" support the option 'retrain=False'.", logger)

last_pred_time = series.time_index()[-fcast_horizon_n - 1] if trim_to_series else series.time_index()[-1]
last_pred_time = (
series.time_index()[-fcast_horizon_n - stride] if trim_to_series else series.time_index()[-stride - 1]
)

# specify the correct fit and predict keyword arguments for the given model
fit_kwargs, predict_kwargs = _create_parameter_dicts(model, target_indices, component_index, use_full_output_length)

# build the prediction times in advance (to be able to use tqdm)
pred_times = [start]
while pred_times[-1] <= last_pred_time:
pred_times.append(pred_times[-1] + series.freq())
pred_times.append(pred_times[-1] + series.freq() * stride)

# what we'll return
values = []
times = []

iterator = _build_tqdm_iterator(pred_times, verbose)

if ((not retrain) and (not model._fit_called)):
model.fit(series.drop_after(start), verbose=verbose, **fit_kwargs)

for pred_time in iterator:
train = series.drop_after(pred_time) # build the training series
model.fit(train)
pred = model.predict(fcast_horizon_n)
values.append(pred.univariate_values()[-1]) # store the N-th point
if (retrain):
model.fit(train, **fit_kwargs)
pred = model.predict(fcast_horizon_n, **predict_kwargs)
else:
pred = model.predict(fcast_horizon_n, input_series=train, **predict_kwargs)
values.append(pred.values()[-1]) # store the N-th point
times.append(pred.end_time()) # store the N-th timestamp
return TimeSeries.from_times_and_values(pd.DatetimeIndex(times), np.array(values))
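
The way `stride` spaces out the prediction times, and how a training window could be cut for each of them, can be sketched with plain pandas. This is an illustration under assumptions rather than the library's implementation: `build_prediction_times` copies the loop from the hunk above, while `training_window` and its `train_length` argument are hypothetical (the moving-window option was only suggested in the review). The sketch also assumes that the training window ends just before the prediction time.

```python
import pandas as pd


def build_prediction_times(time_index, start, fcast_horizon_n, stride=1, trim_to_series=True):
    """Step from `start` in increments of `stride` time steps, as in the diff."""
    last_pred_time = (
        time_index[-fcast_horizon_n - stride] if trim_to_series else time_index[-stride - 1]
    )
    pred_times = [start]
    while pred_times[-1] <= last_pred_time:
        pred_times.append(pred_times[-1] + time_index.freq * stride)
    return pred_times


def training_window(series, pred_time, train_length=None):
    """Return the history before `pred_time`.

    `train_length` is a hypothetical argument: None keeps the whole history
    (expanding window); an integer keeps only the most recent points (moving window).
    """
    history = series[series.index < pred_time]  # assumption: pred_time itself is excluded
    return history if train_length is None else history.iloc[-train_length:]


index = pd.date_range('2000-01-01', periods=10, freq='D')
series = pd.Series(range(10), index=index)
pred_times = build_prediction_times(index, pd.Timestamp('2000-01-03'), fcast_horizon_n=3, stride=2)
```

With `stride=2`, consecutive prediction times are two days apart, so fewer fit/predict cycles are needed than with the default `stride=1`.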

@@ -144,7 +194,6 @@ def backtest_regression(feature_series: Iterable[TimeSeries],
the specified model with the specified forecast horizon.
"""

raise_if_not(target_series.width == 1, "'target_series' must be univariate.", logger)
raise_if_not(all([s.has_same_time_as(target_series) for s in feature_series]), 'All provided time series must '
'have the same time index', logger)
raise_if_not(start in target_series, 'The provided start timestamp is not in the time series.', logger)
@@ -174,7 +223,7 @@ def backtest_regression(feature_series: Iterable[TimeSeries],

model.fit(train_features, train_target)
pred = model.predict(val_features)
values.append(pred.univariate_values()[-1]) # store the N-th point
values.append(pred.values()[-1]) # store the N-th point
times.append(pred.end_time()) # store the N-th timestamp

return TimeSeries.from_times_and_values(pd.DatetimeIndex(times), np.array(values))
@@ -183,7 +232,7 @@ def backtest_regression(feature_series: Iterable[TimeSeries],
def forecasting_residuals(model: ForecastingModel,
series: TimeSeries,
fcast_horizon_n: int = 1,
verbose: bool = True) -> TimeSeries:
verbose: bool = False) -> TimeSeries:
""" A function for computing the residuals produced by a given model and univariate time series.

This function computes the difference between the actual observations from `series`
@@ -285,6 +334,9 @@ def backtest_gridsearch(model_class: type,
parameters: dict,
train_series: TimeSeries,
fcast_horizon_n: Optional[int] = None,
target_indices: Optional[List[int]] = None,
component_index: Optional[int] = None,
use_full_output_length: bool = True,
val_series: Optional[TimeSeries] = None,
num_predictions: int = 10,
metric: Callable[[TimeSeries, TimeSeries], float] = metrics.mape,
@@ -324,6 +376,15 @@ def backtest_gridsearch(model_class: type,
The univariate TimeSeries instance used for validation in split mode.
fcast_horizon_n
The integer value of the forecasting horizon used in expanding window mode.
target_indices
In case `train_series` is multivariate and `model_class` is a subclass of `MultivariateForecastingModel`,
a list of indices of components of `train_series` to be predicted by the model.
component_index
In case `train_series` is multivariate and `model_class` is a subclass of `UnivariateForecastingModel`,
an integer index of the component of `train_series` to be predicted by the model.
use_full_output_length
In case `model_class` is a subclass of `TorchForecastingModel`, this argument will be passed along
as an argument to the predict method of the model.
num_predictions:
The number of train/prediction cycles performed in one iteration of expanding window mode.
metric:
@@ -337,13 +398,16 @@ def backtest_gridsearch(model_class: type,
An untrained 'model_class' instance with the best-performing hyperparameters from the given selection.
"""

train_series._assert_univariate()
if (val_series is not None):
val_series._assert_univariate()
raise_if_not(train_series.width == val_series.width, "Training and validation series require the same"
" number of components.", logger)

raise_if_not((fcast_horizon_n is None) ^ (val_series is None),
"Please pass exactly one of the arguments 'fcast_horizon_n' or 'val_series'.", logger)

fit_kwargs, predict_kwargs = _create_parameter_dicts(model_class(), target_indices, component_index,
use_full_output_length)

if val_series is None:
backtest_start_time = train_series.end_time() - (num_predictions + fcast_horizon_n) * train_series.freq()
min_error = float('inf')
@@ -358,11 +422,12 @@ def backtest_gridsearch(model_class: type,
param_combination_dict = dict(list(zip(parameters.keys(), param_combination)))
model = model_class(**param_combination_dict)
if val_series is None: # expanding window mode
backtest_forecast = backtest_forecasting(train_series, model, backtest_start_time, fcast_horizon_n)
backtest_forecast = backtest_forecasting(train_series, model, backtest_start_time, fcast_horizon_n,
target_indices, component_index, use_full_output_length)
error = metric(backtest_forecast, train_series)
else: # split mode
model.fit(train_series)
error = metric(model.predict(len(val_series)), val_series)
model.fit(train_series, **fit_kwargs)
error = metric(model.predict(len(val_series), **predict_kwargs), val_series)
if error < min_error:
min_error = error
best_param_combination = param_combination_dict
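
The search loop in this hunk enumerates every combination of the values in `parameters` and keeps the one with the lowest error. A minimal sketch of that pattern, independent of any model class, might look as follows. `iter_param_combinations`, `grid_search`, and the `evaluate` callback are hypothetical names introduced for illustration only; in the diff, the error comes from backtesting or from a train/validation split.

```python
from itertools import product


def iter_param_combinations(parameters):
    """Yield one dict per point of the Cartesian product of the value lists."""
    keys = list(parameters.keys())
    for combination in product(*parameters.values()):
        yield dict(zip(keys, combination))


def grid_search(parameters, evaluate):
    """Return the combination with the lowest error and that error.

    `evaluate` is a hypothetical callback mapping a parameter dict to a float.
    """
    best_params, min_error = None, float('inf')
    for params in iter_param_combinations(parameters):
        error = evaluate(params)
        if error < min_error:
            min_error, best_params = error, params
    return best_params, min_error


# Toy error function: pretend that kernel_size=3 gives the best fit.
best, err = grid_search({'kernel_size': [2, 3, 4], 'n_epochs': [1]},
                        lambda p: abs(p['kernel_size'] - 3))
```

The number of evaluations grows as the product of the list lengths, which is why the test in this PR keeps every list except `kernel_size` at a single value.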
@@ -376,8 +441,9 @@ def explore_models(train_series: TimeSeries,
metric: Callable[[TimeSeries, TimeSeries], float] = metrics.mape,
model_parameter_tuples: Optional[list] = None,
plot_width: int = 3,
verbose: bool = True):
""" A function for exploring the suitability of multiple models on a given train/validation/test split.
verbose: bool = False):
""" A function for exploring the suitability of multiple models on a given train/validation/test split
of a univariate series.

This function iterates through a list of models, training each on `train_series` and `val_series`
and evaluating them on `test_series`. Models with free hyperparameters are first
19 changes: 17 additions & 2 deletions darts/models/torch_forecasting_model.py
@@ -328,7 +328,10 @@ def fit(self,
tb_writer.flush()
tb_writer.close()

def predict(self, n: int, use_full_output_length: bool = False) -> TimeSeries:
def predict(self, n: int,
use_full_output_length: bool = False,
input_series: Optional[TimeSeries] = None) -> TimeSeries:

""" Predicts values for a certain number of time steps after the end of the training series

In the case of univariate training series, `n` can assume any integer value greater than 0.
@@ -352,6 +355,9 @@ def predict(self, n: int, use_full_output_length: bool = False) -> TimeSeries:
use_full_output_length
Boolean value indicating whether or not the full output sequence of the model prediction should be
used to produce the output of this function.
input_series
Optionally, the input TimeSeries instance fed to the trained TorchForecastingModel to produce the
prediction. If it is not passed, the training TimeSeries instance will be used as input.

Returns
-------
@@ -360,12 +366,21 @@ def predict(self, n: int, use_full_output_length: bool = False) -> TimeSeries:
"""
super().predict(n)

raise_if_not(input_series is None or input_series.width == self.training_series.width,
"'input_series' must have same width as series used to fit model.", logger)

raise_if_not(use_full_output_length or self.training_series.width == 1, "Please set 'use_full_output_length'"
" to 'True' and 'n' smaller than or equal to 'output_length' when using a multivariate "
"TimeSeries instance as input.", logger)

# create input sequence for prediction
input_sequence = self.training_series.values()[-self.input_length:]
if input_series is None:
input_sequence = self.training_series.values()[-self.input_length:]
else:
raise_if_not(len(input_series) >= self.input_length,
"'input_series' must at least be as long as 'self.input_length'", logger)
input_sequence = input_series.values()[-self.input_length:]
super().fit(input_series, self.target_indices)
pred_in = torch.from_numpy(input_sequence).float().view(1, self.input_length, -1).to(self.device)

# iterate through predicting output and consuming it again until enough predictions are created
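
The input selection above can be illustrated without PyTorch. This sketch assumes that `values()` returns a 2-D array of shape `(time steps, width)`; `build_model_input` is a hypothetical helper, and NumPy's `reshape` stands in for the `view` call on the tensor.

```python
import numpy as np


def build_model_input(training_values, input_length, input_values=None):
    """Pick the last `input_length` rows and add a leading batch dimension.

    Falls back to the training values when no separate input is given,
    mirroring the `input_series is None` branch in the diff.
    """
    source = training_values if input_values is None else input_values
    if source.shape[1] != training_values.shape[1]:
        raise ValueError("input must have the same width as the training data")
    if len(source) < input_length:
        raise ValueError("input must be at least 'input_length' time steps long")
    window = source[-input_length:]
    # Resulting shape: (batch=1, input_length, width)
    return window.astype(np.float32).reshape(1, input_length, -1)


training = np.arange(20).reshape(10, 2)   # 10 time steps, 2 components
new_input = np.arange(12).reshape(6, 2)   # a different series of the same width
from_training = build_model_input(training, input_length=4)
from_new_input = build_model_input(training, input_length=4, input_values=new_input)
```

Passing a separate input is what lets `backtest_forecasting` reuse a model that was fitted once when `retrain=False`.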
71 changes: 63 additions & 8 deletions darts/tests/test_backtesting.py
@@ -18,7 +18,15 @@
constant_timeseries as ct,
gaussian_timeseries as gt
)
from ..models import Theta, FFT, ExponentialSmoothing, NaiveSeasonal, StandardRegressionModel, NaiveDrift
from ..models import (
Theta,
FFT,
ExponentialSmoothing,
NaiveSeasonal,
StandardRegressionModel,
NaiveDrift,
TCNModel
)


def compare_best_against_random(model_class, params, series):
@@ -62,14 +70,48 @@ def setUpClass(cls):

def test_backtest_forecasting(self):
linear_series = lt(length=50)
linear_series_multi = linear_series.stack(linear_series)

# univariate series
# univariate model + univariate series
pred = backtest_forecasting(linear_series, NaiveDrift(), pd.Timestamp('20000201'), 3)
self.assertEqual(r2_score(pred, linear_series), 1.0)

# multivariate series
with self.assertRaises(AssertionError):
backtest_forecasting(linear_series.stack(linear_series), NaiveDrift(), pd.Timestamp('20000201'), 3)
# univariate model + multivariate series without component index argument
with self.assertRaises(ValueError):
backtest_forecasting(linear_series_multi, NaiveDrift(), pd.Timestamp('20000201'), 3)

# univariate model + multivariate series with component index argument
pred = backtest_forecasting(linear_series_multi, NaiveDrift(), pd.Timestamp('20000201'), 3,
component_index=0, verbose=False)
self.assertEqual(pred.width, 1)
self.assertEqual(r2_score(pred, linear_series), 1.0)
pred = backtest_forecasting(linear_series_multi, NaiveDrift(), pd.Timestamp('20000201'), 3,
component_index=1, verbose=False)
self.assertEqual(pred.width, 1)
self.assertEqual(r2_score(pred, linear_series), 1.0)

# multivariate model + univariate series
tcn_model = TCNModel(batch_size=1, n_epochs=1)
pred = backtest_forecasting(linear_series, tcn_model, pd.Timestamp('20000125'), 3, verbose=False)
self.assertEqual(pred.width, 1)

# multivariate model + multivariate series
with self.assertRaises(ValueError):
backtest_forecasting(linear_series_multi, tcn_model, pd.Timestamp('20000125'), 3, verbose=False)
tcn_model = TCNModel(batch_size=1, n_epochs=1, input_size=2, output_length=3)
with self.assertRaises(ValueError):
backtest_forecasting(linear_series_multi, tcn_model, pd.Timestamp('20000125'), 3, verbose=False,
use_full_output_length=False)
pred = backtest_forecasting(linear_series_multi, tcn_model, pd.Timestamp('20000125'), 1, target_indices=[0],
verbose=False)
self.assertEqual(pred.width, 1)
pred = backtest_forecasting(linear_series_multi, tcn_model, pd.Timestamp('20000125'), 3, verbose=False,
use_full_output_length=True, target_indices=[1])
self.assertEqual(pred.width, 1)
tcn_model = TCNModel(batch_size=1, n_epochs=1, input_size=2, output_length=3, output_size=2)
pred = backtest_forecasting(linear_series_multi, tcn_model, pd.Timestamp('20000125'), 3, verbose=False,
use_full_output_length=True, target_indices=[0, 1])
self.assertEqual(pred.width, 2)

def test_backtest_regression(self):
gaussian_series = gt(mean=2, length=50)
@@ -88,9 +130,9 @@ def test_backtest_regression(self):
self.assertEqual(r2_score(pred, target), 1.0)

# multivariate target
with self.assertRaises(ValueError):
backtest_regression(features_multivariate, target.stack(target),
StandardRegressionModel(15), pd.Timestamp('20000201'), 3)
pred = backtest_regression(features_multivariate, target.stack(target),
StandardRegressionModel(15), pd.Timestamp('20000201'), 3)
self.assertEqual(r2_score(pred, target.stack(target)), 1.0)

def test_backtest_gridsearch(self):

@@ -109,6 +151,19 @@ def test_backtest_gridsearch(self):
es_params = {'seasonal_periods': list(range(5, 10))}
self.assertTrue(compare_best_against_random(ExponentialSmoothing, es_params, dummy_series))

def test_backtest_gridsearch_multi(self):
dummy_series = st(length=40, value_y_offset=10).stack(lt(length=40, end_value=20))
tcn_params = {
'n_epochs': [1],
'batch_size': [1],
'input_size': [2],
'output_length': [3],
'output_size': [2],
'kernel_size': [2, 3, 4]
}
backtest_gridsearch(TCNModel, tcn_params, dummy_series, fcast_horizon_n=3, metric=mape,
use_full_output_length=True, target_indices=[0, 1])

def test_forecasting_residuals(self):
model = NaiveSeasonal(K=1)
