Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MNT] initialize change cycle (0.28.0) for renaming cINNForecaster to CINNForecaster #6121

Merged
merged 4 commits into from Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/api_reference/forecasting.rst
Expand Up @@ -369,7 +369,7 @@ Deep learning based forecasters
:toctree: auto_generated/
:template: class.rst

CINNForecaster
cINNForecaster

.. currentmodule:: sktime.forecasting.neuralforecast

Expand Down
18 changes: 18 additions & 0 deletions sktime/forecasting/conditional_invertible_neural_network.py
Expand Up @@ -17,6 +17,7 @@
from sktime.transformations.merger import Merger
from sktime.transformations.series.fourier import FourierFeatures
from sktime.transformations.series.summarize import WindowSummarizer
from sktime.utils.warnings import warn

if _check_soft_dependencies("torch", severity="none"):
import torch
Expand All @@ -41,6 +42,7 @@ def default_sine(x, amplitude, phase, offset, amplitude2, amplitude3, phase2):
return sbase + s1 + s2


# TODO 0.29.0: rename the class cINNForecaster to CINNForecaster
class cINNForecaster(BaseDeepNetworkPyTorch):
"""
Conditional Invertible Neural Network (cINN) Forecaster.
Expand Down Expand Up @@ -177,6 +179,17 @@ def __init__(
self.val_split = val_split
super().__init__(num_epochs, batch_size, lr=lr)

warn(
"cINNForecaster will be renamed to CINNForecaster in sktime 0.29.0, "
"The estimator is available under the future name at its "
"current location, and will be available under its deprecated name "
"until 0.30.0. "
"To prepare for the name change, "
"replace cINNForecaster with CINNForecaster",
DeprecationWarning,
obj=self,
)

def _fit(self, y, fh, X=None):
"""Fit forecaster to training data.

Expand Down Expand Up @@ -607,3 +620,8 @@ def early_stop(self, validation_loss, model):
if self.counter >= self.patience:
return True
return False


# TODO 0.29.0: switch the line to cINNForecaster = CINNForecaster
# TODO 0.30.0: remove this alias altogether
CINNForecaster = cINNForecaster
19 changes: 19 additions & 0 deletions sktime/networks/cinn.py
Expand Up @@ -4,6 +4,8 @@
import numpy as np
from skbase.utils.dependencies import _check_soft_dependencies

from sktime.utils.warnings import warn

if _check_soft_dependencies("torch", severity="none"):
import torch
import torch.nn as nn
Expand All @@ -22,6 +24,7 @@ class NNModule:
import FrEIA.modules as Fm


# TODO 0.29.0: rename the class cINNNetwork to CINNNetwork
class cINNNetwork:
"""
Conditional Invertible Neural Network.
Expand Down Expand Up @@ -193,6 +196,17 @@ def __init__(
self.hidden_dim_size = hidden_dim_size
self.activation = activation if activation is not None else nn.ReLU

warn(
"cINNNetwork will be renamed to CINNNetwork in sktime 0.29.0, "
"The estimator is available under the future name at its "
"current location, and will be available under its deprecated name "
"until 0.30.0. "
"To prepare for the name change, "
"replace cINNNetwork with CINNNetwork",
DeprecationWarning,
obj=self,
)

def build(self):
"""Build the cINN."""
return self._cINNNetwork(
Expand All @@ -203,3 +217,8 @@ def build(self):
self.hidden_dim_size,
self.activation,
)


# TODO 0.29.0: switch the line to cINNNetwork = CINNNetwork
# TODO 0.30.0: remove this alias altogether
CINNNetwork = cINNNetwork