Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add supports_multivariate property to ForecastingModel #1848

Merged
merged 12 commits into from
Jul 5, 2023
4 changes: 4 additions & 0 deletions darts/models/forecasting/arima.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ def __init__(
self._random_state = None
np.random.seed(random_state if random_state is not None else 0)

@property
def supports_multivariate(self) -> bool:
return False

def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None):
super()._fit(series, future_covariates)

Expand Down
4 changes: 4 additions & 0 deletions darts/models/forecasting/auto_arima.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def __init__(
self.model = PmdAutoARIMA(*autoarima_args, **autoarima_kwargs)
self.trend = self.model.trend

@property
def supports_multivariate(self) -> bool:
return False

def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None):
super()._fit(series, future_covariates)
self._assert_univariate(series)
Expand Down
16 changes: 16 additions & 0 deletions darts/models/forecasting/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ def __init__(self):
super().__init__()
self.mean_val = None

@property
def supports_multivariate(self) -> bool:
return True

def fit(self, series: TimeSeries):
super().fit(series)

Expand Down Expand Up @@ -59,6 +63,10 @@ def __init__(self, K: int = 1):
self.last_k_vals = None
self.K = K

@property
def supports_multivariate(self) -> bool:
return True

@property
def min_train_series_length(self):
return max(self.K, 3)
Expand Down Expand Up @@ -91,6 +99,10 @@ def __init__(self):
"""
super().__init__()

@property
def supports_multivariate(self) -> bool:
return True

def fit(self, series: TimeSeries):
super().fit(series)
assert series.n_samples == 1, "This model expects deterministic time series"
Expand Down Expand Up @@ -125,6 +137,10 @@ def __init__(self, input_chunk_length: int = 1):
self.input_chunk_length = input_chunk_length
self.rolling_window = None

@property
def supports_multivariate(self) -> bool:
return True

@property
def min_train_series_length(self):
return self.input_chunk_length
Expand Down
4 changes: 4 additions & 0 deletions darts/models/forecasting/block_rnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,10 @@ def __init__(
self.n_rnn_layers = n_rnn_layers
self.dropout = dropout

@property
def supports_multivariate(self) -> bool:
return True
felixdivo marked this conversation as resolved.
Show resolved Hide resolved

def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module:
# samples are made of (past_target, past_covariates, future_target)
input_dim = train_sample[0].shape[1] + (
Expand Down
4 changes: 4 additions & 0 deletions darts/models/forecasting/catboost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ def __init__(
use_static_covariates=use_static_covariates,
)

@property
def supports_multivariate(self) -> bool:
return True

felixdivo marked this conversation as resolved.
Show resolved Hide resolved
def fit(
self,
series: Union[TimeSeries, Sequence[TimeSeries]],
Expand Down
4 changes: 4 additions & 0 deletions darts/models/forecasting/croston.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ def __init__(

self.version = version

@property
def supports_multivariate(self) -> bool:
return False

def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None):
super()._fit(series, future_covariates)
self._assert_univariate(series)
Expand Down
4 changes: 4 additions & 0 deletions darts/models/forecasting/dlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,10 @@ def _create_model(
**self.pl_module_params,
)

@property
def supports_multivariate(self) -> bool:
return True

@property
def supports_static_covariates(self) -> bool:
return True
Expand Down
4 changes: 4 additions & 0 deletions darts/models/forecasting/ensemble_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,10 @@ def _models_are_probabilistic(self) -> bool:
def _is_probabilistic(self) -> bool:
return self._models_are_probabilistic()

@property
def supports_multivariate(self) -> bool:
return all([model.supports_multivariate for model in self.models])
felixdivo marked this conversation as resolved.
Show resolved Hide resolved

@property
def supports_past_covariates(self) -> bool:
return any([model.supports_past_covariates for model in self.models])
Expand Down
4 changes: 4 additions & 0 deletions darts/models/forecasting/exponential_smoothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ def predict(self, n, num_samples=1, verbose: bool = False):

return self._build_forecast_series(forecast)

@property
def supports_multivariate(self) -> bool:
return True
felixdivo marked this conversation as resolved.
Show resolved Hide resolved

def _is_probabilistic(self) -> bool:
return True

Expand Down
4 changes: 4 additions & 0 deletions darts/models/forecasting/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,10 @@ def __init__(
self.trend = trend
self.trend_poly_degree = trend_poly_degree

@property
def supports_multivariate(self) -> bool:
return False

def _exp_trend(self, x) -> Callable:
"""Helper function, used to make FFT model pickable."""
return np.exp(self.trend_coefficients[1]) * np.exp(
Expand Down
20 changes: 14 additions & 6 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def __init__(self, *args, **kwargs):
self.future_covariate_series: Optional[TimeSeries] = None
self.static_covariates: Optional[pd.DataFrame] = None

self._supports_multivariate = False
felixdivo marked this conversation as resolved.
Show resolved Hide resolved
self._expect_past_covariates, self._uses_past_covariates = False, False
self._expect_future_covariates, self._uses_future_covariates = False, False
# for static covariates there is the option to consider static covariates or ignore them
Expand Down Expand Up @@ -195,14 +196,21 @@ def _supports_non_retrainable_historical_forecasts(self) -> bool:
return False

@property
def supports_past_covariates(self):
@abstractmethod
def supports_multivariate(self) -> bool:
"""
Whether the model considers more than one variate in the time series.
felixdivo marked this conversation as resolved.
Show resolved Hide resolved
"""

@property
def supports_past_covariates(self) -> bool:
"""
Whether model supports past covariates
"""
return "past_covariates" in inspect.signature(self.fit).parameters.keys()

@property
def supports_future_covariates(self):
def supports_future_covariates(self) -> bool:
"""
Whether model supports future covariates
"""
Expand All @@ -216,28 +224,28 @@ def supports_static_covariates(self) -> bool:
return False

@property
def uses_past_covariates(self):
def uses_past_covariates(self) -> bool:
"""
Whether the model uses past covariates, once fitted.
"""
return self._uses_past_covariates

@property
def uses_future_covariates(self):
def uses_future_covariates(self) -> bool:
"""
Whether the model uses future covariates, once fitted.
"""
return self._uses_future_covariates

@property
def uses_static_covariates(self):
def uses_static_covariates(self) -> bool:
"""
Whether the model uses static covariates, once fitted.
"""
return self._uses_static_covariates

@property
def considers_static_covariates(self):
def considers_static_covariates(self) -> bool:
"""
Whether the model considers static covariates, if there are any.
"""
Expand Down
4 changes: 4 additions & 0 deletions darts/models/forecasting/kalman_forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,5 +135,9 @@ def _predict(

return filtered_series[-n:]

@property
def supports_multivariate(self) -> bool:
return True

def _is_probabilistic(self) -> bool:
return True
4 changes: 4 additions & 0 deletions darts/models/forecasting/lgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,10 @@ def _predict_and_sample(
else:
return super()._predict_and_sample(x, num_samples, **kwargs)

@property
def supports_multivariate(self) -> bool:
return True

felixdivo marked this conversation as resolved.
Show resolved Hide resolved
def _is_probabilistic(self) -> bool:
return self.likelihood is not None

Expand Down
4 changes: 4 additions & 0 deletions darts/models/forecasting/linear_regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,5 +220,9 @@ def _predict_and_sample(
else:
return super()._predict_and_sample(x, num_samples, **kwargs)

@property
def supports_multivariate(self) -> bool:
felixdivo marked this conversation as resolved.
Show resolved Hide resolved
return True

def _is_probabilistic(self) -> bool:
return self.likelihood is not None
4 changes: 4 additions & 0 deletions darts/models/forecasting/nbeats.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,10 @@ def __init__(
if isinstance(layer_widths, int):
self.layer_widths = [layer_widths] * self.num_stacks

@property
def supports_multivariate(self) -> bool:
return True

def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module:
# samples are made of (past_target, past_covariates, future_target)
input_dim = train_sample[0].shape[1] + (
Expand Down
4 changes: 4 additions & 0 deletions darts/models/forecasting/nhits.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,10 @@ def __init__(
if isinstance(layer_widths, int):
self.layer_widths = [layer_widths] * self.num_stacks

@property
def supports_multivariate(self) -> bool:
return True

@staticmethod
def _prepare_pooling_downsampling(
pooling_kernel_sizes, n_freq_downsample, in_len, out_len, num_blocks, num_stacks
Expand Down
4 changes: 4 additions & 0 deletions darts/models/forecasting/nlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,10 @@ def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module:
**self.pl_module_params,
)

@property
def supports_multivariate(self) -> bool:
return True

@property
def supports_static_covariates(self) -> bool:
return True
Expand Down
4 changes: 4 additions & 0 deletions darts/models/forecasting/prophet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,10 @@ def _generate_predict_df(
)
return predict_df

@property
def supports_multivariate(self) -> bool:
return False

def _is_probabilistic(self) -> bool:
return True

Expand Down
4 changes: 4 additions & 0 deletions darts/models/forecasting/random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,7 @@ def __init__(
model=RandomForestRegressor(**kwargs),
use_static_covariates=use_static_covariates,
)

@property
def supports_multivariate(self) -> bool:
return True
felixdivo marked this conversation as resolved.
Show resolved Hide resolved
8 changes: 6 additions & 2 deletions darts/models/forecasting/regression_ensemble_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def fit(
past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
):

super().fit(
series, past_covariates=past_covariates, future_covariates=future_covariates
)
Expand Down Expand Up @@ -189,7 +188,6 @@ def ensemble(
series: Optional[Sequence[TimeSeries]] = None,
num_samples: int = 1,
) -> Union[TimeSeries, Sequence[TimeSeries]]:

is_single_series = isinstance(series, TimeSeries) or series is None
predictions = series2seq(predictions)
series = series2seq(series) if series is not None else [None]
Expand Down Expand Up @@ -219,6 +217,12 @@ def extreme_lags(
extreme_lags_ = super().extreme_lags
return (extreme_lags_[0] - self.train_n_points,) + extreme_lags_[1:]

@property
def supports_multivariate(self) -> bool:
return super().supports_multivariate and getattr(
self.regression_model, "supports_multivariate", True
)
felixdivo marked this conversation as resolved.
Show resolved Hide resolved

def _is_probabilistic(self) -> bool:
"""
A RegressionEnsembleModel is probabilistic if its regression
Expand Down
4 changes: 4 additions & 0 deletions darts/models/forecasting/regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,10 @@ def extreme_lags(
max_future_cov_lag,
)

@property
def supports_multivariate(self) -> bool:
felixdivo marked this conversation as resolved.
Show resolved Hide resolved
return True

@property
def min_train_series_length(self) -> int:
return max(
Expand Down
4 changes: 4 additions & 0 deletions darts/models/forecasting/rnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,10 @@ def _verify_train_dataset_type(self, train_dataset: TrainingDataset):
"RNNModel requires a shifted training dataset with shift=1.",
)

@property
def supports_multivariate(self) -> bool:
return True

@property
def min_train_series_length(self) -> int:
return self.training_length + 1
4 changes: 4 additions & 0 deletions darts/models/forecasting/sf_auto_arima.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ def _predict(

return self._build_forecast_series(samples)

@property
def supports_multivariate(self) -> bool:
return False

@property
def min_train_series_length(self) -> int:
return 10
Expand Down
4 changes: 4 additions & 0 deletions darts/models/forecasting/sf_auto_ces.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ def predict(

return self._build_forecast_series(mu)

@property
def supports_multivariate(self) -> bool:
return False

@property
def min_train_series_length(self) -> int:
return 10
Expand Down
4 changes: 4 additions & 0 deletions darts/models/forecasting/sf_auto_ets.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ def _predict(
samples = mu
return self._build_forecast_series(samples)

@property
def supports_multivariate(self) -> bool:
return False

@property
def min_train_series_length(self) -> int:
return 10
Expand Down
4 changes: 4 additions & 0 deletions darts/models/forecasting/sf_auto_theta.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ def predict(

return self._build_forecast_series(samples)

@property
def supports_multivariate(self) -> bool:
return False

@property
def min_train_series_length(self) -> int:
return 10
Expand Down
4 changes: 4 additions & 0 deletions darts/models/forecasting/tbats_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,10 @@ def predict(self, n, num_samples=1, verbose: bool = False):

return self._build_forecast_series(samples)

@property
def supports_multivariate(self) -> bool:
return False

def _is_probabilistic(self) -> bool:
return True

Expand Down