Commit 74e3ff8: Update get_test_params

pranavvp16 committed Apr 13, 2024
1 parent 757b9c7 · commit 74e3ff8

Showing 3 changed files with 13 additions and 59 deletions.
sktime/forecasting/base/adapters/_neuralforecast.py (8 changes: 6 additions & 2 deletions)

@@ -177,7 +177,9 @@ def _get_valid_parameters(self: "_NeuralForecastAdapter") -> dict:
                f" not found in the __init__ method "
                f"from {Trainer}. "
                f"Check your pytorch_lightning version "
-               f"to find out the right API parameters."
+               f"to find out the right API parameters.",
+               obj=self,
+               stacklevel=2,
            )
            filter_params["trainer_kwargs"].pop(invalid_param)

@@ -188,7 +190,9 @@ def _get_valid_parameters(self: "_NeuralForecastAdapter") -> dict:
                f" not found in the __init__ method "
                f"from {self.algorithm_class}. "
                f"Check your neuralforecast version "
-               f"to find out the right API parameters."
+               f"to find out the right API parameters.",
+               obj=self,
+               stacklevel=2,
            )
            filter_params.pop(unsupported_param)
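
For context, the two warnings above now pass obj and stacklevel. A minimal sketch of that pattern, assuming a warn helper of roughly this shape (the helper body and the "warnings" config key are illustrative assumptions, not the sktime source):

    import warnings

    def warn(msg, category=UserWarning, obj=None, stacklevel=1):
        # Passing obj lets per-estimator warning suppression be honored;
        # stacklevel=2 makes the warning point at user code rather than
        # at the adapter internals. The config key below is assumed.
        if obj is not None and callable(getattr(obj, "get_config", None)):
            if obj.get_config().get("warnings") == "off":
                return
        warnings.warn(msg, category, stacklevel=stacklevel + 1)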
sktime/forecasting/neuralforecast.py (7 changes: 7 additions & 0 deletions)

@@ -358,6 +358,7 @@ def get_test_params(cls, parameter_set="default"):
            ]
        else:
            from neuralforecast.losses.pytorch import SMAPE, QuantileLoss
+           from torch.optim import Adam

            params = [
                {
@@ -378,6 +379,8 @@ def get_test_params(cls, parameter_set="default"):
                    "max_steps": 4,
                    "val_check_steps": 2,
                    "trainer_kwargs": {"logger": False},
+                   "optimizer": Adam,
+                   "optimizer_kwargs": {"lr": 0.001},
                },
            ]

@@ -696,6 +699,7 @@ def get_test_params(cls, parameter_set="default"):

        try:
            _check_soft_dependencies("neuralforecast", severity="error")
+           _check_soft_dependencies("torch", severity="error")
        except ModuleNotFoundError:
            params = [
                {
@@ -718,6 +722,7 @@ def get_test_params(cls, parameter_set="default"):
            ]
        else:
            from neuralforecast.losses.pytorch import SMAPE, QuantileLoss
+           from torch.optim import Adam

            params = [
                {
@@ -738,6 +743,8 @@ def get_test_params(cls, parameter_set="default"):
                    "max_steps": 4,
                    "val_check_steps": 2,
                    "trainer_kwargs": {"logger": False},
+                   "optimizer": Adam,
+                   "optimizer_kwargs": {"lr": 0.001},
                },
            ]
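
The contract behind these changes: every dict returned by get_test_params must be a valid set of __init__ kwargs, so the new optimizer entries are exercised whenever the test framework instantiates the estimator. A minimal sketch of that round trip (illustrative; the real harness is sktime's estimator check suite):

    from sktime.forecasting.neuralforecast import NeuralForecastRNN

    # each parameter set must construct a valid estimator instance
    for params in NeuralForecastRNN.get_test_params():
        forecaster = NeuralForecastRNN(**params)
        print(type(forecaster).__name__, "constructed with", sorted(params))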
sktime/forecasting/tests/test_neuralforecast.py (57 changes: 0 additions & 57 deletions)

@@ -322,60 +322,3 @@ def test_neural_forecast_with_auto_freq_on_missing_date_like(
        ValueError, match="(could not interpret freq).*(use a valid offset in index)"
    ):
        model.fit(y, fh=[1, 2, 3])
-
-
-@pytest.mark.parametrize("model_class", [NeuralForecastLSTM, NeuralForecastRNN])
-@pytest.mark.skipif(
-    not run_test_for_class([NeuralForecastLSTM, NeuralForecastRNN]),
-    reason="run test only if softdeps are present and incrementally (if requested)",
-)
-def test_neural_forecast_with_non_default_optimizer(model_class) -> None:
-    """Test with user defined optimizer."""
-    # import non-default pytorch optimizer
-    from torch.optim import Adam
-
-    # define model
-    model = model_class(
-        freq="A-DEC",
-        max_steps=5,
-        optimizer=Adam,
-        trainer_kwargs={"logger": False},
-    )
-
-    # train model
-    model.fit(X_train, fh=[1, 2, 3, 4])
-
-    # predict with trained model
-    X_pred = model.predict()
-
-    # check prediction index
-    pandas.testing.assert_index_equal(X_pred.index, X_test.index, check_names=False)
-
-
-@pytest.mark.parametrize("model_class", [NeuralForecastLSTM, NeuralForecastRNN])
-@pytest.mark.skipif(
-    not run_test_for_class([NeuralForecastLSTM, NeuralForecastRNN]),
-    reason="run test only if softdeps are present and incrementally (if requested)",
-)
-def test_neural_forecast_with_non_default_optimizer_with_kwargs(model_class) -> None:
-    """Test with user defined optimizer and optimizer_kwargs."""
-    # import non-default pytorch optimizer
-    from torch.optim import Adagrad
-
-    # define model
-    model = model_class(
-        freq="A-DEC",
-        optimizer=Adagrad,
-        optimizer_kwargs={"lr": 0.1},
-        max_steps=5,
-        trainer_kwargs={"logger": False},
-    )
-
-    # train model
-    model.fit(X_train, fh=[1, 2, 3, 4])
-
-    # predict with trained model
-    X_pred = model.predict()
-
-    # check prediction index
-    pandas.testing.assert_index_equal(X_pred.index, X_test.index, check_names=False)

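The deleted tests covered the user-facing optimizer arguments directly. Reconstructed from the removed code, the usage they exercised looks like this; a hedged sketch, assuming a yearly PeriodIndex series (the toy data below is illustrative, not from the commit):

    import pandas as pd
    from torch.optim import Adagrad
    from sktime.forecasting.neuralforecast import NeuralForecastRNN

    # toy yearly series -- illustrative data, not from the commit
    y = pd.Series(
        [float(i) for i in range(24)],
        index=pd.period_range("2000", periods=24, freq="A-DEC"),
    )

    # optimizer is a torch optimizer class; optimizer_kwargs are
    # forwarded when it is instantiated against the model's parameters
    model = NeuralForecastRNN(
        freq="A-DEC",
        optimizer=Adagrad,
        optimizer_kwargs={"lr": 0.1},
        max_steps=5,
        trainer_kwargs={"logger": False},
    )
    model.fit(y, fh=[1, 2, 3])
    y_pred = model.predict()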