Skip to content

Commit

Permalink
[MNT] initialize change cycle (0.28.0) for renaming cINNForecaster
Browse files Browse the repository at this point in the history
…to `CINNForecaster` (#6121)

#### Reference Issues/PRs
This PR prepares the change cycle for renaming `cINNForecaster` to
`CINNForecaster` completing the steps for release v0.28.0 in #6120

#### What does this implement/fix? Explain your changes.
- todos on top of class definition
- warning message in class initialization
- alias line added with a todo in the bottom of the file
  • Loading branch information
geetu040 committed Mar 22, 2024
1 parent b667111 commit d96272d
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 1 deletion.
2 changes: 1 addition & 1 deletion docs/source/api_reference/forecasting.rst
Expand Up @@ -369,7 +369,7 @@ Deep learning based forecasters
:toctree: auto_generated/
:template: class.rst

CINNForecaster
cINNForecaster

.. currentmodule:: sktime.forecasting.neuralforecast

Expand Down
18 changes: 18 additions & 0 deletions sktime/forecasting/conditional_invertible_neural_network.py
Expand Up @@ -17,6 +17,7 @@
from sktime.transformations.merger import Merger
from sktime.transformations.series.fourier import FourierFeatures
from sktime.transformations.series.summarize import WindowSummarizer
from sktime.utils.warnings import warn

if _check_soft_dependencies("torch", severity="none"):
import torch
Expand All @@ -41,6 +42,7 @@ def default_sine(x, amplitude, phase, offset, amplitude2, amplitude3, phase2):
return sbase + s1 + s2


# TODO 0.29.0: rename the class cINNForecaster to CINNForecaster
class cINNForecaster(BaseDeepNetworkPyTorch):
"""
Conditional Invertible Neural Network (cINN) Forecaster.
Expand Down Expand Up @@ -177,6 +179,17 @@ def __init__(
self.val_split = val_split
super().__init__(num_epochs, batch_size, lr=lr)

warn(
"cINNForecaster will be renamed to CINNForecaster in sktime 0.29.0, "
"The estimator is available under the future name at its "
"current location, and will be available under its deprecated name "
"until 0.30.0. "
"To prepare for the name change, "
"replace cINNForecaster with CINNForecaster",
DeprecationWarning,
obj=self,
)

def _fit(self, y, fh, X=None):
"""Fit forecaster to training data.
Expand Down Expand Up @@ -607,3 +620,8 @@ def early_stop(self, validation_loss, model):
if self.counter >= self.patience:
return True
return False


# TODO 0.29.0: switch the line to cINNForecaster = CINNForecaster
# TODO 0.30.0: remove this alias altogether
CINNForecaster = cINNForecaster
19 changes: 19 additions & 0 deletions sktime/networks/cinn.py
Expand Up @@ -4,6 +4,8 @@
import numpy as np
from skbase.utils.dependencies import _check_soft_dependencies

from sktime.utils.warnings import warn

if _check_soft_dependencies("torch", severity="none"):
import torch
import torch.nn as nn
Expand All @@ -22,6 +24,7 @@ class NNModule:
import FrEIA.modules as Fm


# TODO 0.29.0: rename the class cINNNetwork to CINNNetwork
class cINNNetwork:
"""
Conditional Invertible Neural Network.
Expand Down Expand Up @@ -193,6 +196,17 @@ def __init__(
self.hidden_dim_size = hidden_dim_size
self.activation = activation if activation is not None else nn.ReLU

warn(
"cINNNetwork will be renamed to CINNNetwork in sktime 0.29.0, "
"The estimator is available under the future name at its "
"current location, and will be available under its deprecated name "
"until 0.30.0. "
"To prepare for the name change, "
"replace cINNNetwork with CINNNetwork",
DeprecationWarning,
obj=self,
)

def build(self):
"""Build the cINN."""
return self._cINNNetwork(
Expand All @@ -203,3 +217,8 @@ def build(self):
self.hidden_dim_size,
self.activation,
)


# TODO 0.29.0: switch the line to cINNNetwork = CINNNetwork
# TODO 0.30.0: remove this alias altogether
CINNNetwork = cINNNetwork

0 comments on commit d96272d

Please sign in to comment.