diff --git a/docs/source/api_reference/forecasting.rst b/docs/source/api_reference/forecasting.rst index d89ea3104bb..bcd8b1bbcfd 100644 --- a/docs/source/api_reference/forecasting.rst +++ b/docs/source/api_reference/forecasting.rst @@ -369,7 +369,7 @@ Deep learning based forecasters :toctree: auto_generated/ :template: class.rst - CINNForecaster + cINNForecaster .. currentmodule:: sktime.forecasting.neuralforecast diff --git a/sktime/forecasting/conditional_invertible_neural_network.py b/sktime/forecasting/conditional_invertible_neural_network.py index a6279c7614e..cf66baaed34 100644 --- a/sktime/forecasting/conditional_invertible_neural_network.py +++ b/sktime/forecasting/conditional_invertible_neural_network.py @@ -17,6 +17,7 @@ from sktime.transformations.merger import Merger from sktime.transformations.series.fourier import FourierFeatures from sktime.transformations.series.summarize import WindowSummarizer +from sktime.utils.warnings import warn if _check_soft_dependencies("torch", severity="none"): import torch @@ -41,6 +42,7 @@ def default_sine(x, amplitude, phase, offset, amplitude2, amplitude3, phase2): return sbase + s1 + s2 +# TODO 0.29.0: rename the class cINNForecaster to CINNForecaster class cINNForecaster(BaseDeepNetworkPyTorch): """ Conditional Invertible Neural Network (cINN) Forecaster. @@ -177,6 +179,17 @@ def __init__( self.val_split = val_split super().__init__(num_epochs, batch_size, lr=lr) + warn( + "cINNForecaster will be renamed to CINNForecaster in sktime 0.29.0, " + "The estimator is available under the future name at its " + "current location, and will be available under its deprecated name " + "until 0.30.0. " + "To prepare for the name change, " + "replace cINNForecaster with CINNForecaster", + DeprecationWarning, + obj=self, + ) + def _fit(self, y, fh, X=None): """Fit forecaster to training data. @@ -607,3 +620,8 @@ def early_stop(self, validation_loss, model): if self.counter >= self.patience: return True return False + + +# TODO 0.29.0: switch the line to cINNForecaster = CINNForecaster +# TODO 0.30.0: remove this alias altogether +CINNForecaster = cINNForecaster diff --git a/sktime/networks/cinn.py b/sktime/networks/cinn.py index 2b6be944f84..c2438faf07d 100644 --- a/sktime/networks/cinn.py +++ b/sktime/networks/cinn.py @@ -4,6 +4,8 @@ import numpy as np from skbase.utils.dependencies import _check_soft_dependencies +from sktime.utils.warnings import warn + if _check_soft_dependencies("torch", severity="none"): import torch import torch.nn as nn @@ -22,6 +24,7 @@ class NNModule: import FrEIA.modules as Fm +# TODO 0.29.0: rename the class cINNNetwork to CINNNetwork class cINNNetwork: """ Conditional Invertible Neural Network. @@ -193,6 +196,17 @@ def __init__( self.hidden_dim_size = hidden_dim_size self.activation = activation if activation is not None else nn.ReLU + warn( + "cINNNetwork will be renamed to CINNNetwork in sktime 0.29.0, " + "The estimator is available under the future name at its " + "current location, and will be available under its deprecated name " + "until 0.30.0. " + "To prepare for the name change, " + "replace cINNNetwork with CINNNetwork", + DeprecationWarning, + obj=self, + ) + def build(self): """Build the cINN.""" return self._cINNNetwork( @@ -203,3 +217,8 @@ def build(self): self.hidden_dim_size, self.activation, ) + + +# TODO 0.29.0: switch the line to cINNNetwork = CINNNetwork +# TODO 0.30.0: remove this alias altogether +CINNNetwork = cINNNetwork