Skip to content

Commit

Permalink
Merge branch 'main' into pr/6265
Browse files Browse the repository at this point in the history
  • Loading branch information
fkiraly committed May 22, 2024
2 parents 379e077 + 1986820 commit 030678d
Show file tree
Hide file tree
Showing 13 changed files with 587 additions and 269 deletions.
2 changes: 2 additions & 0 deletions docs/source/api_reference/forecasting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,8 @@ Deep learning based forecasters
LTSFNLinearForecaster

.. currentmodule:: sktime.forecasting.hf_transformers_forecaster

.. autosummary::
    :toctree: auto_generated/
    :template: class.rst

Expand Down
1 change: 1 addition & 0 deletions docs/source/api_reference/split.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ They have tag ``"split_type"="temporal"``.
SingleWindowSplitter
SlidingWindowSplitter
ExpandingWindowSplitter
ExpandingCutoffSplitter
ExpandingGreedySplitter
TemporalTrainTestSplitter

Expand Down
41 changes: 36 additions & 5 deletions docs/source/developer_guide/deprecation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,44 @@
Deprecation
===========

Before we can make changes to sktime's user interface, we need to make sure that users have time to make the necessary adjustments in their code.
For this reason, we first need to deprecate functionality, and change it only in a subsequent release.
``sktime`` aims to be stable and reliable towards its users.
Our high-level policy to ensure this is:

.. note::

    "``sktime`` should never break user code without a clear and actionable warning
    given at least one (MINOR) release cycle in advance."

Here, "break" expressly includes a change to abstract logic, such as the algorithm
being used, not just changes that lead to exceptions or performance degradation.

For instance, if a user has code

.. code:: python

    from sktime.forecasting.foo import BarForecaster
    bar = BarForecaster(42, x=43)
    bar.fit(y_train, fh=[1, 2, 3])
    y_pred = bar.predict()

then no release of ``sktime`` should change, without warning:

* import location of ``BarForecaster``
* argument signature of ``BarForecaster``, including name, order, and defaults of arguments
* the abstract algorithm that ``BarForecaster`` carries out for the given arguments

Changes that can be carried out without warning:

* adding more arguments at the end of the argument list, with a default value that retains prior behaviour,
  as long as the new arguments are well-documented
* pure refactoring of internal code, as long as the public API remains the same
* changing the implementation without changing the abstract algorithm, e.g., for performance reasons

The deprecation policy outlined in this document provides details on how to carry out
changes that need change or deprecation handling, in a user-friendly and reliable way.

It is accompanied by formulaic patterns for developers, with examples,
and a process for release managers, to make the policy easy to follow.

For upcoming changes and next releases, see our `Milestones <https://github.com/sktime/sktime/milestones?direction=asc&sort=due_date&state=open>`_.
For our long-term plan, see our :ref:`roadmap`.

Deprecation policy
==================
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ all_extras = [
'numba<0.60,>=0.53',
'pmdarima!=1.8.1,<3.0.0,>=1.8; python_version < "3.12"',
'prophet>=1.1; python_version < "3.12"',
"pycatch22<0.4.4",
"pycatch22<0.4.6",
"pykalman<0.10,>=0.9.5",
'pyod>=0.8; python_version < "3.11"',
"pyts<0.14.0; python_version < '3.12'",
Expand Down Expand Up @@ -136,7 +136,7 @@ all_extras_pandas2 = [
'numba<0.60,>=0.53',
'pmdarima!=1.8.1,<3.0.0,>=1.8; python_version < "3.12"',
'prophet>=1.1; python_version < "3.12"',
"pycatch22<0.4.4",
"pycatch22<0.4.6",
"pykalman<0.10,>=0.9.5",
'pyod>=0.8; python_version < "3.11"',
"scikit_posthocs>=0.6.5",
Expand Down Expand Up @@ -207,7 +207,7 @@ regression = [
transformations = [
'esig<0.10,>=0.9.7; python_version < "3.11"',
"filterpy<1.5,>=1.4.5",
"holidays>=0.29,<0.49",
"holidays>=0.29,<0.50",
"mne>=1.5,<1.8",
'numba<0.60,>=0.53',
"pycatch22>=0.4,<0.4.6",
Expand Down
249 changes: 0 additions & 249 deletions sktime/networks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,7 @@

from abc import ABC, abstractmethod

import numpy as np
import pandas as pd

from sktime.base import BaseObject
from sktime.forecasting.base import BaseForecaster
from sktime.utils.validation._dependencies import _check_soft_dependencies

if _check_soft_dependencies("torch", severity="none"):
import torch


class BaseDeepNetwork(BaseObject, ABC):
Expand All @@ -35,244 +27,3 @@ def build_network(self, input_shape, **kwargs):
output_layer : a keras layer
"""
...


class BaseDeepNetworkPyTorch(BaseForecaster, ABC):
    """Abstract base class for deep learning forecasters using torch.nn.

    Implements the training loop (``_fit``), prediction (``_predict``) and the
    construction of the torch ``DataLoader`` objects used by both.

    The following members are read by this class but not defined in it, so
    concrete subclasses have to provide them:

    * ``_build_network(fh)`` - returns the network to train; the returned
      object is called on batches, and its ``parameters()``, ``train()``,
      ``eval()``, ``seq_len`` and ``pred_len`` members are used
    * ``criterion`` and ``criterions`` - the selected loss identifier and a
      mapping from supported identifiers to loss classes
    * ``optimizers`` - mapping from supported identifiers to optimizer classes
    * ``custom_dataset_train`` and ``custom_dataset_pred`` - optional dataset
      objects exposing a callable ``build_dataset(y)``; if falsy, the default
      ``PyTorchTrainDataset`` / ``PyTorchPredDataset`` are used

    Parameters
    ----------
    num_epochs : int, default=16
        Number of passes over the training dataloader in ``_fit``.
    batch_size : int, default=8
        Batch size passed to the torch ``DataLoader``.
    in_channels : int, default=1
        Stored on the instance; not used by this base class itself.
    individual : bool, default=False
        Stored on the instance; not used by this base class itself.
    criterion_kwargs : dict, optional (default=None)
        Keyword arguments passed to the constructor of the selected criterion.
    optimizer : optional (default=None)
        Key into ``self.optimizers``; if falsy, ``torch.optim.Adam`` is used.
    optimizer_kwargs : dict, optional (default=None)
        Extra keyword arguments passed to the constructor of the selected
        optimizer, in addition to ``lr``.
    lr : float, default=0.001
        Learning rate passed to the optimizer.
    """

    _tags = {
        "python_dependencies": "torch",
        "y_inner_mtype": "pd.DataFrame",
        "capability:insample": False,
        "capability:pred_int:insample": False,
        "scitype:y": "both",
        "ignores-exogeneous-X": True,
    }

    def __init__(
        self,
        num_epochs=16,
        batch_size=8,
        in_channels=1,
        individual=False,
        criterion_kwargs=None,
        optimizer=None,
        optimizer_kwargs=None,
        lr=0.001,
    ):
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.in_channels = in_channels
        self.individual = individual
        self.criterion_kwargs = criterion_kwargs
        self.optimizer = optimizer
        self.optimizer_kwargs = optimizer_kwargs
        self.lr = lr

        super().__init__()

    def _fit(self, y, fh, X=None):
        """Fit the network.

        Changes to state:
            writes ``self._fh``, ``self._y``, ``self.network``,
            ``self._criterion`` and ``self._optimizer``, and updates the
            parameters of ``self.network`` by training.

        Parameters
        ----------
        y : pd.DataFrame
            Series to fit the forecaster to.
        fh : ForecastingHorizon
            Forecasting horizon; converted to relative form if it is absolute.
        X : optional (default=None)
            Not used in fitting.
        """
        from sktime.forecasting.base import ForecastingHorizon

        # keep the horizon in relative form: it is reused by the dataloaders
        # and serves as the default horizon in _predict
        if not fh.is_relative:
            fh = fh.to_relative(self.cutoff)
        self._fh = fh
        self._y = y

        # isinstance instead of an exact type comparison, so that subclasses
        # of ForecastingHorizon are handled like ForecastingHorizon itself
        if isinstance(fh, ForecastingHorizon):
            self.network = self._build_network(fh._values[-1])
        else:
            self.network = self._build_network(fh)

        self._criterion = self._instantiate_criterion()
        self._optimizer = self._instantiate_optimizer()

        dataloader = self.build_pytorch_train_dataloader(y)
        self.network.train()

        for _ in range(self.num_epochs):
            # batch names are distinct from the ``y`` argument on purpose,
            # so the training series is not rebound inside the loop
            for x_batch, y_batch in dataloader:
                y_pred = self.network(x_batch)
                loss = self._criterion(y_pred, y_batch)
                self._optimizer.zero_grad()
                loss.backward()
                self._optimizer.step()

    def _instantiate_criterion(self):
        """Return the loss selected by ``self.criterion``, MSELoss if unset.

        Raises
        ------
        TypeError
            If ``self.criterion`` is set but is not a key of
            ``self.criterions``.
        """
        if not self.criterion:
            import torch

            # default criterion
            return torch.nn.MSELoss()

        if self.criterion not in self.criterions:
            raise TypeError(
                f"Please pass one of {self.criterions.keys()} for `criterion`."
            )

        criterion_kwargs = self.criterion_kwargs or {}
        return self.criterions[self.criterion](**criterion_kwargs)

    def _instantiate_optimizer(self):
        """Return the optimizer selected by ``self.optimizer``, Adam if unset.

        The optimizer is constructed over ``self.network.parameters()``, so
        ``self.network`` has to be built before this is called.

        Raises
        ------
        TypeError
            If ``self.optimizer`` is set but is not a key of
            ``self.optimizers``.
        """
        if not self.optimizer:
            import torch

            # default optimizer
            return torch.optim.Adam(self.network.parameters(), lr=self.lr)

        if self.optimizer not in self.optimizers:
            raise TypeError(
                f"Please pass one of {self.optimizers.keys()} for `optimizer`."
            )

        optimizer_kwargs = self.optimizer_kwargs or {}
        return self.optimizers[self.optimizer](
            self.network.parameters(), lr=self.lr, **optimizer_kwargs
        )

    def _predict(self, X=None, fh=None):
        """Predict with fitted model.

        Parameters
        ----------
        X : optional (default=None)
            If passed, it is windowed in place of the series seen in ``_fit``.
        fh : ForecastingHorizon, optional (default=None)
            Relative forecasting horizon; defaults to the one stored in
            ``_fit``.

        Returns
        -------
        y_pred : pd.DataFrame
            Forecasts with the columns of the series seen in ``_fit``,
            indexed by the absolute horizon.

        Raises
        ------
        ValueError
            If ``fh`` contains a step larger than ``self.network.pred_len``
            or a negative step.
        """
        import torch

        if fh is None:
            fh = self._fh

        # NOTE(review): a step of 0 passes this check and is then mapped to
        # index -1 below - confirm that callers can never pass it
        if max(fh._values) > self.network.pred_len or min(fh._values) < 0:
            raise ValueError(
                f"fh of {fh} passed to {self.__class__.__name__} is not "
                "within `pred_len`. Please use a fh that aligns with the `pred_len` of "
                "the forecaster."
            )

        # NOTE(review): X replaces y as network input here, although the
        # "ignores-exogeneous-X" tag is set - confirm this is intended
        source = self._y if X is None else X
        dataloader = self.build_pytorch_pred_dataloader(source, fh)

        # inference mode: disables train-only layer behaviour such as dropout
        # and skips building the autograd graph
        self.network.eval()
        with torch.no_grad():
            batches = [self.network(x).detach() for x, _ in dataloader]

        y_pred = torch.cat(batches, dim=0).view(-1, batches[0].shape[-1]).numpy()
        # relative step k is looked up at row k - 1 of the network output
        y_pred = y_pred[fh._values.values - 1]
        y_pred = pd.DataFrame(
            y_pred, columns=self._y.columns, index=fh.to_absolute_index(self.cutoff)
        )

        return y_pred

    def build_pytorch_train_dataloader(self, y):
        """Build PyTorch DataLoader for training.

        Parameters
        ----------
        y : pd.DataFrame
            Training series.

        Returns
        -------
        torch.utils.data.DataLoader
            Loader over ``self.custom_dataset_train`` if that is set,
            otherwise over a ``PyTorchTrainDataset`` built from ``y``.

        Raises
        ------
        NotImplementedError
            If ``self.custom_dataset_train`` is set but has no callable
            ``build_dataset``.
        """
        from torch.utils.data import DataLoader

        if not self.custom_dataset_train:
            dataset = PyTorchTrainDataset(
                y=y,
                seq_len=self.network.seq_len,
                fh=self._fh._values[-1],
            )
            return DataLoader(dataset, self.batch_size)

        build_dataset = getattr(self.custom_dataset_train, "build_dataset", None)
        if not callable(build_dataset):
            raise NotImplementedError(
                "Custom Dataset `build_dataset` method is not available. Please "
                f"refer to the {self.__class__.__name__}.build_dataset "
                "documentation."
            )

        build_dataset(y)
        return DataLoader(self.custom_dataset_train, self.batch_size)

    def build_pytorch_pred_dataloader(self, y, fh=None):
        """Build PyTorch DataLoader for prediction.

        Parameters
        ----------
        y : pd.DataFrame
            Series to predict from; the default dataset only uses its last
            ``self.network.seq_len`` rows.
        fh : optional (default=None)
            Accepted for interface compatibility; not used by this
            implementation.

        Returns
        -------
        torch.utils.data.DataLoader
            Loader over ``self.custom_dataset_pred`` if that is set,
            otherwise over a ``PyTorchPredDataset`` built from ``y``.

        Raises
        ------
        NotImplementedError
            If ``self.custom_dataset_pred`` is set but has no callable
            ``build_dataset``.
        """
        from torch.utils.data import DataLoader

        if not self.custom_dataset_pred:
            dataset = PyTorchPredDataset(
                y=y[-self.network.seq_len :],
                seq_len=self.network.seq_len,
            )
            return DataLoader(dataset, self.batch_size)

        build_dataset = getattr(self.custom_dataset_pred, "build_dataset", None)
        if not callable(build_dataset):
            raise NotImplementedError(
                "Custom Dataset `build_dataset` method is not available. Please "
                f"refer to the {self.__class__.__name__}.build_dataset "
                "documentation."
            )

        # build and return the *prediction* dataset; the training dataset
        # must not be rebuilt or returned from here
        build_dataset(y)
        return DataLoader(self.custom_dataset_pred, self.batch_size)

    def get_y_true(self, y):
        """Get y_true values for validation.

        Parameters
        ----------
        y : pd.DataFrame
            Series passed on to ``build_pytorch_pred_dataloader``.

        Returns
        -------
        np.ndarray
            Flattened targets of all batches of the prediction dataloader,
            concatenated along axis 0.
        """
        # pass the stored horizon explicitly, so that overrides of
        # build_pytorch_pred_dataloader which require ``fh`` keep working
        fh = getattr(self, "_fh", None)
        dataloader = self.build_pytorch_pred_dataloader(y, fh)
        y_true = [target.flatten().numpy() for _, target in dataloader]
        return np.concatenate(y_true, axis=0)


class PyTorchTrainDataset:
    """Sliding-window training dataset for sktime deep learning forecasters.

    Item ``i`` is the pair of float tensors built from
    ``y[i : i + seq_len]`` (network input) and
    ``y[i + seq_len : i + seq_len + fh]`` (training target).

    Parameters
    ----------
    y : object with a ``values`` attribute, e.g., pd.DataFrame
        Series to cut windows from; ``y.values`` is stored.
    seq_len : int
        Number of rows in each input window.
    fh : int
        Number of rows in each target window.
    """

    def __init__(self, y, seq_len, fh):
        self.y = y.values
        self.seq_len = seq_len
        self.fh = fh

    def _n_windows(self):
        """Return the number of complete windows, negative if y is too short."""
        return len(self.y) - self.seq_len - self.fh + 1

    def __len__(self):
        """Return length of dataset.

        Raises
        ------
        ValueError
            If the series has fewer than ``seq_len + fh - 1`` rows. ``len``
            would raise ValueError for the negative count anyway; this gives
            the reason.
        """
        n_windows = self._n_windows()
        if n_windows < 0:
            raise ValueError(
                f"y has {len(self.y)} rows, which is too short for "
                f"seq_len={self.seq_len} and fh={self.fh}."
            )
        return n_windows

    def __getitem__(self, i):
        """Return data point.

        Raises
        ------
        IndexError
            If ``i`` is not in ``range(len(self))``. This terminates plain
            iteration over the dataset and prevents silently truncated
            windows for out-of-range indices.
        """
        from torch import from_numpy, tensor

        if not 0 <= i < self._n_windows():
            raise IndexError(f"index {i} is out of range for {type(self).__name__}")

        split = i + self.seq_len
        return (
            tensor(self.y[i:split]).float(),
            from_numpy(self.y[split : split + self.fh]).float(),
        )


class PyTorchPredDataset:
    """Single-window prediction dataset for sktime deep learning forecasters.

    The dataset has exactly one item: the float tensor built from
    ``y[0 : seq_len]`` as network input, paired with an empty target tensor.

    Parameters
    ----------
    y : object with a ``values`` attribute, e.g., pd.DataFrame
        Series to predict from; ``y.values`` is stored.
    seq_len : int
        Number of rows in the input window.
    """

    def __init__(self, y, seq_len):
        self.y = y.values
        self.seq_len = seq_len

    def __len__(self):
        """Return length of dataset."""
        return 1

    def __getitem__(self, i):
        """Return data point.

        Raises
        ------
        IndexError
            If ``i`` is not 0, the only valid index. This terminates plain
            iteration over the dataset.
        """
        from torch import from_numpy, tensor

        if i != 0:
            raise IndexError(f"index {i} is out of range for {type(self).__name__}")

        split = i + self.seq_len
        return (
            tensor(self.y[i:split]).float(),
            # deliberately empty slice: there is no target at prediction time
            from_numpy(self.y[split:split]).float(),
        )
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"EmpiricalCoverage",
"ConstraintViolation",
"CRPS",
"IntervalWidth",
"LogLoss",
"SquaredDistrLoss",
]
Expand All @@ -25,6 +26,7 @@
CRPS,
ConstraintViolation,
EmpiricalCoverage,
IntervalWidth,
LogLoss,
PinballLoss,
SquaredDistrLoss,
Expand Down
Loading

0 comments on commit 030678d

Please sign in to comment.