Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/model use static cov #1700

Merged
merged 15 commits into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion darts/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from darts.models.forecasting.random_forest import RandomForest
from darts.models.forecasting.regression_ensemble_model import RegressionEnsembleModel
from darts.models.forecasting.regression_model import RegressionModel
from darts.models.forecasting.tbats import BATS, TBATS
from darts.models.forecasting.tbats_model import BATS, TBATS
from darts.models.forecasting.theta import FourTheta, Theta
from darts.models.forecasting.varima import VARIMA

Expand Down
4 changes: 0 additions & 4 deletions darts/models/forecasting/block_rnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,10 +321,6 @@ def __init__(
self.n_rnn_layers = n_rnn_layers
self.dropout = dropout

@staticmethod
def _supports_static_covariates() -> bool:
return False

def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module:
# samples are made of (past_target, past_covariates, future_target)
input_dim = train_sample[0].shape[1] + (
Expand Down
6 changes: 6 additions & 0 deletions darts/models/forecasting/catboost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
quantiles: List = None,
random_state: Optional[int] = None,
multi_models: Optional[bool] = True,
use_static_covariates: bool = True,
**kwargs,
):
"""CatBoost Model
Expand Down Expand Up @@ -87,6 +88,10 @@ def __init__(
multi_models
If True, a separate model will be trained for each future lag to predict. If False, a single model is
trained to predict at step 'output_chunk_length' in the future. Default: True.
use_static_covariates
Whether the model should use static covariate information in case the input `series` passed to ``fit()``
contain static covariates. If ``True``, and static covariates are available at fitting time, will enforce
that all target `series` have the same static covariate dimensionality in ``fit()`` and ``predict()`.
**kwargs
Additional keyword arguments passed to `catboost.CatBoostRegressor`.
"""
Expand Down Expand Up @@ -132,6 +137,7 @@ def __init__(
add_encoders=add_encoders,
multi_models=multi_models,
model=CatBoostRegressor(**kwargs),
use_static_covariates=use_static_covariates,
)

def fit(
Expand Down
10 changes: 8 additions & 2 deletions darts/models/forecasting/dlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def __init__(
shared_weights: bool = False,
kernel_size: int = 25,
const_init: bool = True,
use_static_covariates: bool = True,
**kwargs,
):
"""An implementation of the DLinear model, as presented in [1]_.
Expand Down Expand Up @@ -265,6 +266,10 @@ def __init__(
const_init
Whether to initialize the weights to 1/in_len. If False, the default PyTorch
initialization is used (default='True').
use_static_covariates
Whether the model should use static covariate information in case the input `series` passed to ``fit()``
contain static covariates. If ``True``, and static covariates are available at fitting time, will enforce
that all target `series` have the same static covariate dimensionality in ``fit()`` and ``predict()`.
**kwargs
Optional arguments to initialize the pytorch_lightning.Module, pytorch_lightning.Trainer, and
Darts' :class:`TorchForecastingModel`.
Expand Down Expand Up @@ -406,6 +411,7 @@ def __init__(
self.shared_weights = shared_weights
self.kernel_size = kernel_size
self.const_init = const_init
self._considers_static_covariates = use_static_covariates

def _create_model(
self, train_sample: MixedCovariatesTrainTensorType
Expand Down Expand Up @@ -446,6 +452,6 @@ def _create_model(
**self.pl_module_params,
)

@staticmethod
def _supports_static_covariates() -> bool:
@property
def supports_static_covariates(self) -> bool:
return True
71 changes: 63 additions & 8 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,11 @@ def __init__(self, *args, **kwargs):
self.future_covariate_series: Optional[TimeSeries] = None
self.static_covariates: Optional[pd.DataFrame] = None

self._expect_past_covariates, self._expect_future_covariates = False, False
self._uses_past_covariates, self._uses_future_covariates = False, False
self._expect_past_covariates, self._uses_past_covariates = False, False
self._expect_future_covariates, self._uses_future_covariates = False, False
# for static covariates there is the option to consider static covariates or ignore them
self._considers_static_covariates = False
self._expect_static_covariates, self._uses_static_covariates = False, False

# state; whether the model has been fit (on a single time series)
self._fit_called = False
Expand Down Expand Up @@ -192,12 +195,25 @@ def _supports_non_retrainable_historical_forecasts(self) -> bool:

@property
def supports_past_covariates(self):
"""
Whether model supports past covariates
"""
return "past_covariates" in inspect.signature(self.fit).parameters.keys()

@property
def supports_future_covariates(self):
"""
Whether model supports future covariates
"""
return "future_covariates" in inspect.signature(self.fit).parameters.keys()

@property
def supports_static_covariates(self) -> bool:
"""
Whether model supports static covariates
"""
return False

@property
def uses_past_covariates(self):
"""
Expand All @@ -212,6 +228,20 @@ def uses_future_covariates(self):
"""
return self._uses_future_covariates

@property
def uses_static_covariates(self):
"""
Whether the model uses static covariates, once fitted.
"""
return self._uses_static_covariates

@property
def considers_static_covariates(self):
"""
Whether the model considers static covariates, if there are any.
"""
return self._considers_static_covariates

@abstractmethod
def predict(self, n: int, num_samples: int = 1) -> TimeSeries:
"""Forecasts values for `n` time steps after the end of the training series.
Expand Down Expand Up @@ -1994,24 +2024,39 @@ def fit(
if isinstance(series, TimeSeries):
# if only one series is provided, save it for prediction time (including covariates, if available)
self.training_series = series
self.static_covariates = series.static_covariates
if past_covariates is not None:
self.past_covariate_series = past_covariates
if future_covariates is not None:
self.future_covariate_series = future_covariates
if (
series.static_covariates is not None
and self.supports_static_covariates
and self.considers_static_covariates
):
self.static_covariates = series.static_covariates
else:
self.static_covariates = series[0].static_covariates

if past_covariates is not None:
self._expect_past_covariates = True
if future_covariates is not None:
self._expect_future_covariates = True
if (
get_single_series(series).static_covariates is not None
and self.supports_static_covariates
and self.considers_static_covariates
):
self.static_covariates = series[0].static_covariates
self._expect_static_covariates = True

if past_covariates is not None:
self._uses_past_covariates = True
if future_covariates is not None:
self._uses_future_covariates = True

if (
get_single_series(series).static_covariates is not None
and self.supports_static_covariates
and self.considers_static_covariates
):
self._uses_static_covariates = True
self._fit_called = True

@abstractmethod
Expand Down Expand Up @@ -2067,20 +2112,30 @@ def predict(
a sequence where each element contains the corresponding `n` points forecasts.
"""
super().predict(n, num_samples)
if self._expect_past_covariates and past_covariates is None:
if self.uses_past_covariates and past_covariates is None:
raise_log(
ValueError(
"The model has been trained with past covariates. Some matching past_covariates "
"have to be provided to `predict()`."
)
)
if self._expect_future_covariates and future_covariates is None:
if self.uses_future_covariates and future_covariates is None:
raise_log(
ValueError(
"The model has been trained with future covariates. Some matching future_covariates "
"have to be provided to `predict()`."
)
)
if (
self.uses_static_covariates
and get_single_series(series).static_covariates is None
):
raise_log(
ValueError(
"The model has been trained with static covariates. Some matching static covariates "
"must be embedded in the target `series` passed to `predict()`."
)
)

def _predict_wrapper(
self,
Expand Down
6 changes: 6 additions & 0 deletions darts/models/forecasting/lgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
quantiles: List[float] = None,
random_state: Optional[int] = None,
multi_models: Optional[bool] = True,
use_static_covariates: bool = True,
categorical_past_covariates: Optional[Union[str, List[str]]] = None,
categorical_future_covariates: Optional[Union[str, List[str]]] = None,
categorical_static_covariates: Optional[Union[str, List[str]]] = None,
Expand Down Expand Up @@ -93,6 +94,10 @@ def __init__(
multi_models
If True, a separate model will be trained for each future lag to predict. If False, a single model is
trained to predict at step 'output_chunk_length' in the future. Default: True.
use_static_covariates
Whether the model should use static covariate information in case the input `series` passed to ``fit()``
contain static covariates. If ``True``, and static covariates are available at fitting time, will enforce
that all target `series` have the same static covariate dimensionality in ``fit()`` and ``predict()`.
categorical_past_covariates
Optionally, component name or list of component names specifying the past covariates that should be treated
as categorical by the underlying `lightgbm.LightGBMRegressor`. It's recommended that the components that
Expand Down Expand Up @@ -137,6 +142,7 @@ def __init__(
add_encoders=add_encoders,
multi_models=multi_models,
model=lgb.LGBMRegressor(**self.kwargs),
use_static_covariates=use_static_covariates,
categorical_past_covariates=categorical_past_covariates,
categorical_future_covariates=categorical_future_covariates,
categorical_static_covariates=categorical_static_covariates,
Expand Down
6 changes: 6 additions & 0 deletions darts/models/forecasting/linear_regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
quantiles: List[float] = None,
random_state: Optional[int] = None,
multi_models: Optional[bool] = True,
use_static_covariates: bool = True,
**kwargs,
):
"""Linear regression model.
Expand Down Expand Up @@ -87,6 +88,10 @@ def __init__(
multi_models
If True, a separate model will be trained for each future lag to predict. If False, a single model is
trained to predict at step 'output_chunk_length' in the future. Default: True.
use_static_covariates
Whether the model should use static covariate information in case the input `series` passed to ``fit()``
contain static covariates. If ``True``, and static covariates are available at fitting time, will enforce
that all target `series` have the same static covariate dimensionality in ``fit()`` and ``predict()`.
**kwargs
Additional keyword arguments passed to `sklearn.linear_model.LinearRegression` (by default), to
`sklearn.linear_model.PoissonRegressor` (if `likelihood="poisson"`), or to
Expand Down Expand Up @@ -122,6 +127,7 @@ def __init__(
add_encoders=add_encoders,
model=model,
multi_models=multi_models,
use_static_covariates=use_static_covariates,
)

def fit(
Expand Down
4 changes: 0 additions & 4 deletions darts/models/forecasting/nbeats.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,10 +755,6 @@ def __init__(
if isinstance(layer_widths, int):
self.layer_widths = [layer_widths] * self.num_stacks

@staticmethod
def _supports_static_covariates() -> bool:
return False

def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module:
# samples are made of (past_target, past_covariates, future_target)
input_dim = train_sample[0].shape[1] + (
Expand Down
4 changes: 0 additions & 4 deletions darts/models/forecasting/nhits.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,10 +750,6 @@ def _check_sizes(tup, name):

return pooling_kernel_sizes, n_freq_downsample

@staticmethod
def _supports_static_covariates() -> bool:
return False

def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module:
# samples are made of (past_target, past_covariates, future_target)
input_dim = train_sample[0].shape[1] + (
Expand Down
11 changes: 8 additions & 3 deletions darts/models/forecasting/nlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def __init__(
shared_weights: bool = False,
const_init: bool = True,
normalize: bool = False,
use_static_covariates: bool = True,
**kwargs,
):
"""An implementation of the NLinear model, as presented in [1]_.
Expand Down Expand Up @@ -216,7 +217,10 @@ def __init__(
.. note::
This cannot be applied to probabilistic models.
..

use_static_covariates
Whether the model should use static covariate information in case the input `series` passed to ``fit()``
contain static covariates. If ``True``, and static covariates are available at fitting time, will enforce
that all target `series` have the same static covariate dimensionality in ``fit()`` and ``predict()`.
**kwargs
Optional arguments to initialize the pytorch_lightning.Module, pytorch_lightning.Trainer, and
Darts' :class:`TorchForecastingModel`.
Expand Down Expand Up @@ -356,6 +360,7 @@ def __init__(
self.shared_weights = shared_weights
self.const_init = const_init
self.normalize = normalize
self._considers_static_covariates = use_static_covariates

raise_if(
"likelihood" in self.model_params
Expand Down Expand Up @@ -402,6 +407,6 @@ def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module:
**self.pl_module_params,
)

@staticmethod
def _supports_static_covariates() -> bool:
@property
def supports_static_covariates(self) -> bool:
return True
6 changes: 6 additions & 0 deletions darts/models/forecasting/random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
n_estimators: Optional[int] = 100,
max_depth: Optional[int] = None,
multi_models: Optional[bool] = True,
use_static_covariates: bool = True,
**kwargs,
):
"""Random Forest Model
Expand Down Expand Up @@ -85,6 +86,10 @@ def __init__(
multi_models
If True, a separate model will be trained for each future lag to predict. If False, a single model is
trained to predict at step 'output_chunk_length' in the future. Default: True.
use_static_covariates
Whether the model should use static covariate information in case the input `series` passed to ``fit()``
contain static covariates. If ``True``, and static covariates are available at fitting time, will enforce
that all target `series` have the same static covariate dimensionality in ``fit()`` and ``predict()`.
**kwargs
Additional keyword arguments passed to `sklearn.ensemble.RandomForest`.
"""
Expand All @@ -102,4 +107,5 @@ def __init__(
add_encoders=add_encoders,
multi_models=multi_models,
model=RandomForestRegressor(**kwargs),
use_static_covariates=use_static_covariates,
)