Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] pytorch forecasting adapter with Global Forecasting API #6228

Merged
merged 92 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
92 commits
Select commit Hold shift + click to select a range
0278280
pytorch-forecasting first draft
Xinyu-Wu-0000 Mar 28, 2024
20c52b0
Merge branch 'main' into global_pytorch_forecasting
Xinyu-Wu-0000 Mar 28, 2024
4aaa022
set None params to default value
Xinyu-Wu-0000 Mar 29, 2024
7dd2a2f
convert X, y to TimeSeriesDataSet in fit
Xinyu-Wu-0000 Mar 30, 2024
670f53c
fix monotone_constaints is None
Xinyu-Wu-0000 Mar 31, 2024
f68fb1c
add dataset_params
Xinyu-Wu-0000 Mar 31, 2024
f0ac35c
add to_dataloader_params
Xinyu-Wu-0000 Mar 31, 2024
473d606
train validation split by max_prediction_length
Xinyu-Wu-0000 Mar 31, 2024
17b02e2
fix kwargs overwrite in model.from_dataset
Xinyu-Wu-0000 Mar 31, 2024
861182a
fix soft dependencies import error
Xinyu-Wu-0000 Mar 31, 2024
1e233bd
data convertion in predict
Xinyu-Wu-0000 Apr 1, 2024
1bd91f2
Merge branch 'main' into global_pytorch_forecasting
Xinyu-Wu-0000 Apr 1, 2024
ef946ca
fix unsupported operand type(s) for |
Xinyu-Wu-0000 Apr 1, 2024
abefa3a
fix kwargs loss after reset
Xinyu-Wu-0000 Apr 6, 2024
1341b36
fix output y name
Xinyu-Wu-0000 Apr 6, 2024
f78e635
add comments
Xinyu-Wu-0000 Apr 14, 2024
d3fb3c8
Merge branch 'main' into global_pytorch_forecasting
Xinyu-Wu-0000 Apr 14, 2024
8ee461a
rename GlobalBaseForecaster to BaseGlobalForecaster
Xinyu-Wu-0000 Apr 19, 2024
f72fc41
add global_forecaster tag
Xinyu-Wu-0000 May 5, 2024
af4fce4
add BaseGlobalForecaster to BASE_CLASS_REGISTER
Xinyu-Wu-0000 May 5, 2024
6525392
add TestAllGlobalForecasters
Xinyu-Wu-0000 May 5, 2024
c3c43de
BaseGlobalForecaster as an exception in test_inheritance
Xinyu-Wu-0000 May 5, 2024
17ec882
Merge branch 'main' into global_pytorch_forecasting
Xinyu-Wu-0000 May 5, 2024
e7e7205
set _tags before init (add "global_forecaster" tag)
Xinyu-Wu-0000 May 5, 2024
80e4005
add capability:global_forecasting to PytorchForecastingTFT
Xinyu-Wu-0000 May 8, 2024
ec800bf
fix global_forecasting check in base class
Xinyu-Wu-0000 May 8, 2024
4adf5b6
register capability:global_forecasting tag
Xinyu-Wu-0000 May 8, 2024
5383bcd
test_global_forecasting_tag
Xinyu-Wu-0000 May 8, 2024
b5de8f0
add pytorch-forecasting to pyproject.toml
Xinyu-Wu-0000 May 8, 2024
433acbc
test_pridect_signature
Xinyu-Wu-0000 May 9, 2024
71f1477
fix empty param dict not work
Xinyu-Wu-0000 May 14, 2024
9361b3c
fix no pd.Series support for y_inner_mtype
Xinyu-Wu-0000 May 14, 2024
1ada234
add requires_X tag for PytorchForecastingTFT
Xinyu-Wu-0000 May 14, 2024
0824d27
fix time index not integer error
Xinyu-Wu-0000 May 15, 2024
922a6c8
add get_test_params for PytorchForecastingTFT
Xinyu-Wu-0000 May 17, 2024
5e15bfd
test_global_fit_predict_insample
Xinyu-Wu-0000 May 17, 2024
9f6dc57
Merge branch 'main' into global_pytorch_forecasting
Xinyu-Wu-0000 May 17, 2024
d96e54e
fix soft dependencies in get_test_param
Xinyu-Wu-0000 May 17, 2024
8cc520d
add target to time_varying_unknown_reals
Xinyu-Wu-0000 May 20, 2024
ac9a211
add PytorchForecastingNBeats
Xinyu-Wu-0000 May 22, 2024
2769321
big fix to support defferent input
Xinyu-Wu-0000 May 22, 2024
aeb0862
Merge branch 'main' into global_pytorch_forecasting
Xinyu-Wu-0000 May 22, 2024
944f9b7
fix encoder length to long
Xinyu-Wu-0000 May 23, 2024
9097abf
fix absolute fh to max prediction length
Xinyu-Wu-0000 May 23, 2024
e642e53
fix x is pd.Series
Xinyu-Wu-0000 May 23, 2024
6f5ceda
add y to _vectorize in predict
Xinyu-Wu-0000 May 23, 2024
58f688f
dummy X for TFT if X is None
Xinyu-Wu-0000 May 23, 2024
ec7d9c7
fix QuantileLoss not pass test_set_params
Xinyu-Wu-0000 May 27, 2024
0e7e22a
self._y in predict if y=None
Xinyu-Wu-0000 May 27, 2024
386e01b
capability:pred_int:insample
Xinyu-Wu-0000 May 27, 2024
0adadd0
fix fh not continue
Xinyu-Wu-0000 May 29, 2024
7b55f08
fix column name in _series_to_frame
Xinyu-Wu-0000 May 29, 2024
f19e59e
fix overwrite self._X self._y
Xinyu-Wu-0000 May 29, 2024
fac46a0
move global_forecasting tag check to the top
Xinyu-Wu-0000 May 29, 2024
4addd77
extend y and concat x
Xinyu-Wu-0000 May 29, 2024
3852184
wordaround for max_prediction_length=1 problem in CI
Xinyu-Wu-0000 May 30, 2024
07acbd6
test_global_forecasting_multiindex_hier
Xinyu-Wu-0000 May 30, 2024
2ac25f1
test_global_forecasting_multiindex
Xinyu-Wu-0000 May 30, 2024
6aec47d
test_global_forecasting_series
Xinyu-Wu-0000 May 30, 2024
c2781a2
test_global_forecasting_no_X
Xinyu-Wu-0000 May 30, 2024
7511ed1
Merge branch 'main' into global_pytorch_forecasting
Xinyu-Wu-0000 May 30, 2024
df2303f
add version constraint
Xinyu-Wu-0000 May 31, 2024
9d940e7
clean fix
Xinyu-Wu-0000 Jun 2, 2024
9ca9709
Merge branch 'main' into global_pytorch_forecasting
Xinyu-Wu-0000 Jun 2, 2024
ce39727
Merge branch 'main' into pr/6228
fkiraly Jun 3, 2024
7c8907d
Update pytorchforecasting.py
fkiraly Jun 3, 2024
caa0b14
Update forecasting.rst
Xinyu-Wu-0000 Jun 3, 2024
a332c61
Merge branch 'main' into pr/6228
fkiraly Jun 3, 2024
b21a336
fix for review by benHeid
Xinyu-Wu-0000 Jun 4, 2024
7f3c9c1
Merge branch 'main' into global_pytorch_forecasting
Xinyu-Wu-0000 Jun 4, 2024
3a60a67
[AUTOMATED] update CONTRIBUTORS.md
Xinyu-Wu-0000 Jun 4, 2024
49978d0
Merge branch 'main' into pr/6228
fkiraly Jun 7, 2024
54530db
Merge branch 'global_pytorch_forecasting' of https://github.com/Xinyu…
fkiraly Jun 7, 2024
677ecdb
Merge branch 'main' into global_pytorch_forecasting
Xinyu-Wu-0000 Jun 10, 2024
3890cf8
global forecast docstring in predict
Xinyu-Wu-0000 Jun 11, 2024
bdab849
improve CI test
Xinyu-Wu-0000 Jun 11, 2024
aed8476
add underscore to BaseGlobalForecaster
Xinyu-Wu-0000 Jun 11, 2024
c6bfddd
NotImplementedError for in sample predict
Xinyu-Wu-0000 Jun 11, 2024
c6728f6
set tag dict explicitly
Xinyu-Wu-0000 Jun 11, 2024
87ba617
Merge branch 'main' into global_pytorch_forecasting
Xinyu-Wu-0000 Jun 12, 2024
d5ee476
pd.Dataframe inner type
Xinyu-Wu-0000 Jun 12, 2024
3587590
fix NaN in target column if fh is not continuous
Xinyu-Wu-0000 Jun 12, 2024
a7a36ed
fix FileNotFoundError in CI
Xinyu-Wu-0000 Jun 13, 2024
dcb5f7f
Merge branch 'main' into global_pytorch_forecasting
Xinyu-Wu-0000 Jun 13, 2024
fa20a44
Merge branch 'main' into pr/6228
fkiraly Jun 13, 2024
49638af
Merge branch 'global_pytorch_forecasting' of https://github.com/Xinyu…
fkiraly Jun 13, 2024
59a4b5c
fix non-continuous fh in extend_y
Xinyu-Wu-0000 Jun 14, 2024
4d2ce8c
Merge branch 'main' into global_pytorch_forecasting
Xinyu-Wu-0000 Jun 14, 2024
ab00927
random log dir in CI
Xinyu-Wu-0000 Jun 15, 2024
10a59ac
Merge branch 'main' into global_pytorch_forecasting
Xinyu-Wu-0000 Jun 19, 2024
284cbae
is_all_out_of_sample
Xinyu-Wu-0000 Jun 17, 2024
bbeaa56
fix no attribute '_random_log_dir'
Xinyu-Wu-0000 Jun 17, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ pip-log.txt
pip-delete-this-directory.txt
pip-wheel-metadata/

# Training logs
lightning_logs/

# folder created by `make test`
testdir/

Expand Down
9 changes: 9 additions & 0 deletions docs/source/api_reference/forecasting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,15 @@ Deep learning based forecasters
NeuralForecastRNN
NeuralForecastLSTM

.. currentmodule:: sktime.forecasting.pytorchforecasting

.. autosummary::
:toctree: auto_generated/
:template: class.rst

PytorchForecastingTFT
PytorchForecastingNBeats

.. currentmodule:: sktime.forecasting.pykan_forecaster

.. autosummary::
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,8 @@ dl = [
'tensorflow<2.17,>=2; python_version < "3.12"',
'torch; python_version < "3.12"',
'transformers[torch]<4.41.0; python_version < "3.12"',
'pykan; python_version > "3.9.7"'
'pykan; python_version > "3.9.7"',
'pytorch-forecasting>=1.0.0; python_version < "3.11"',
]
mlflow = [
"mlflow",
Expand Down
3 changes: 2 additions & 1 deletion sktime/forecasting/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
__all__ = [
"ForecastingHorizon",
"BaseForecaster",
"_BaseGlobalForecaster",
]

from sktime.forecasting.base._base import BaseForecaster
from sktime.forecasting.base._base import BaseForecaster, _BaseGlobalForecaster
from sktime.forecasting.base._fh import ForecastingHorizon
147 changes: 146 additions & 1 deletion sktime/forecasting/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class name: BaseForecaster

__author__ = ["mloning", "big-o", "fkiraly", "sveameyer13", "miraep8", "ciaran-g"]

__all__ = ["BaseForecaster"]
__all__ = ["BaseForecaster", "_BaseGlobalForecaster"]

from copy import deepcopy
from itertools import product
Expand Down Expand Up @@ -2540,6 +2540,151 @@ def _get_columns(self, method="predict", **kwargs):
BaseForecaster._init_dynamic_doc()


class _BaseGlobalForecaster(BaseForecaster):
"""Base global forecaster template class.

This class is a temporal solution, might be merged into BaseForecaster later.

The base forecaster specifies the methods and method signatures that all
global forecasters have to implement.

Specific implementations of these methods is deferred to concrete forecasters.

"""

_tags = {"object_type": ["global_forecaster", "forecaster"]}

def predict(self, fh=None, X=None, y=None):
"""Forecast time series at future horizon.

State required:
Requires state to be "fitted", i.e., ``self.is_fitted=True``.

Accesses in self:

* Fitted model attributes ending in "_".
* ``self.cutoff``, ``self.is_fitted``

Writes to self:
Stores ``fh`` to ``self.fh`` if ``fh`` is passed and has not been passed
previously.

Parameters
----------
fh : int, list, np.array or ``ForecastingHorizon``, optional (default=None)
The forecasting horizon encoding the time stamps to forecast at.
Should not be passed if has already been passed in ``fit``.
If has not been passed in fit, must be passed, not optional

X : time series in ``sktime`` compatible format, optional (default=None)
Exogeneous time series to use in prediction.
Should be of same scitype (``Series``, ``Panel``, or ``Hierarchical``)
as ``y`` in ``fit``.
If ``self.get_tag("X-y-must-have-same-index")``,
``X.index`` must contain ``fh`` index reference.
If ``y`` is not passed (not performing global forecasting), ``X`` should
only contain the time points to be predicted.
If ``y`` is passed (performing global forecasting), ``X`` must contain
all historical values and the time points to be predicted.

y : time series in ``sktime`` compatible format, optional (default=None)
Historical values of the time series that should be predicted.
If not None, global forecasting will be performed.
Only pass the historical values not the time points to be predicted.

Returns
-------
y_pred : time series in sktime compatible data container format
Point forecasts at ``fh``, with same index as ``fh``.
``y_pred`` has same type as the ``y`` that has been passed most recently:
``Series``, ``Panel``, ``Hierarchical`` scitype, same format (see above)

Notes
-----
If ``y`` is not None, global forecast will be performed.
In global forecast mode,
``X`` should contain all historical values and the time points to be predicted,
while ``y`` should only contain historical values
not the time points to be predicted.

If ``y`` is None, non global forecast will be performed.
In non global forecast mode,
``X`` should only contain the time points to be predicted,
while ``y`` should only contain historical values
not the time points to be predicted.
"""
# check global forecasting tag
gf = self.get_tag(
"capability:global_forecasting", tag_value_default=False, raise_error=False
)
if not gf and y is not None:
ValueError("no global forecasting support!")

# handle inputs
self.check_is_fitted()
if y is None:
self._global_forecasting = False
else:
self._global_forecasting = True
# check and convert X/y
X_inner, y_inner = self._check_X_y(X=X, y=y)

# this also updates cutoff from y
# be cautious, in fit self._X and self._y is also updated but not here!
if y_inner is not None:
self._set_cutoff_from_y(y_inner)

# check fh and coerce to ForecastingHorizon, if not already passed in fit
fh = self._check_fh(fh)

# we call the ordinary _predict if no looping/vectorization needed
if not self._is_vectorized:
y_pred = self._predict(fh=fh, X=X_inner, y=y_inner)
else:
# otherwise we call the vectorized version of predict
y_pred = self._vectorize("predict", y=y_inner, X=X_inner, fh=fh)

# convert to output mtype, identical with last y mtype seen
y_out = convert_to(
y_pred,
self._y_metadata["mtype"],
store=self._converter_store_y,
store_behaviour="freeze",
)

return y_out

def _predict(self, fh, X, y):
"""Forecast time series at future horizon.

private _predict containing the core logic, called from predict

State required:
Requires state to be "fitted".

Accesses in self:
Fitted model attributes ending in "_"
self.cutoff

Parameters
----------
fh : guaranteed to be ForecastingHorizon or None, optional (default=None)
The forecasting horizon with the steps ahead to to predict.
If not passed in _fit, guaranteed to be passed here
X : optional (default=None)
guaranteed to be of a type in self.get_tag("X_inner_mtype")
Exogeneous time series for the forecast
y : time series in ``sktime`` compatible format, optional (default=None)
Historical values of the time series that should be predicted.

Returns
-------
y_pred : pd.Series
Point predictions
"""
raise NotImplementedError("abstract method")


def _format_moving_cutoff_predictions(y_preds, cutoffs):
"""Format moving-cutoff predictions.

Expand Down
4 changes: 4 additions & 0 deletions sktime/forecasting/base/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"_StatsForecastAdapter",
"_GeneralisedStatsForecastAdapter",
"_NeuralForecastAdapter",
"_PytorchForecastingAdapter",
]

from sktime.forecasting.base.adapters._fbprophet import _ProphetAdapter
Expand All @@ -17,6 +18,9 @@
)
from sktime.forecasting.base.adapters._neuralforecast import _NeuralForecastAdapter
from sktime.forecasting.base.adapters._pmdarima import _PmdArimaAdapter
from sktime.forecasting.base.adapters._pytorchforecasting import (
_PytorchForecastingAdapter,
)
from sktime.forecasting.base.adapters._statsforecast import _StatsForecastAdapter
from sktime.forecasting.base.adapters._statsmodels import _StatsModelsAdapter
from sktime.forecasting.base.adapters._tbats import _TbatsAdapter
Loading
Loading