Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/ptl version 200 #1651

Merged
merged 3 commits into from
Mar 20, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion darts/models/forecasting/pl_forecasting_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

# Check whether we are running pytorch-lightning >= 1.6.0 or not:
tokens = pl.__version__.split(".")
pl_160_or_above = int(tokens[0]) >= 1 and int(tokens[1]) >= 6
pl_160_or_above = int(tokens[0]) > 1 or int(tokens[0]) == 1 and int(tokens[1]) >= 6


class PLForecastingModule(pl.LightningModule, ABC):
Expand Down
41 changes: 35 additions & 6 deletions darts/models/forecasting/torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import datetime
import inspect
import os
import re
import shutil
import sys
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -85,6 +86,10 @@

logger = get_logger(__name__)

# Check whether we are running pytorch-lightning >= 2.0.0 or not:
tokens = pl.__version__.split(".")
pl_200_or_above = int(tokens[0]) >= 2


def _get_checkpoint_folder(work_dir, model_name):
return os.path.join(work_dir, model_name, CHECKPOINTS_FOLDER)
Expand Down Expand Up @@ -427,25 +432,49 @@ def _init_model(self, trainer: Optional[pl.Trainer] = None) -> None:
dtype = self.train_sample[0].dtype
if np.issubdtype(dtype, np.float32):
logger.info("Time series values are 32-bits; casting model to float32.")
precision = 32
precision = "32" if not pl_200_or_above else "32-true"
elif np.issubdtype(dtype, np.float64):
logger.info("Time series values are 64-bits; casting model to float64.")
precision = 64
precision = "64" if not pl_200_or_above else "64-true"
else:
raise_log(
ValueError(
f"Invalid time series data type `{dtype}`. Cast your data to `np.float32` "
f"or `np.float64`, e.g. with `TimeSeries.astype(np.float32)`."
),
logger,
)
precision_int = int(re.findall(r"\d+", str(precision))[0])

precision_user = (
self.trainer_params.get("precision", None)
if trainer is None
else trainer.precision
)
if precision_user is not None:
# currently, we only support float 64 and 32
valid_precisions = (
["64", "32"] if not pl_200_or_above else ["64-true", "32-true"]
)
if str(precision_user) not in valid_precisions:
raise_log(
ValueError(
f"Invalid user-defined trainer_kwarg `precision={precision_user}`. "
f"Use one of ({valid_precisions})"
),
logger,
)
precision_user_int = int(re.findall(r"\d+", str(precision_user))[0])
else:
precision_user_int = None

raise_if(
precision_user is not None and int(precision_user) != precision,
f"User-defined trainer_kwarg `precision={precision_user}` does not match dtype: `{dtype}` of the "
precision_user is not None and precision_user_int != precision_int,
f"User-defined trainer_kwarg `precision='{precision_user}'` does not match dtype: `{dtype}` of the "
f"underlying TimeSeries. Set `precision` to `{precision}` or cast your data to `{precision_user}"
f"` with `TimeSeries.astype(np.float{precision_user})`.",
f"` with `TimeSeries.astype(np.float{precision_user_int})`.",
logger,
)

self.trainer_params["precision"] = precision

# we need to save the initialized TorchForecastingModel as PyTorch-Lightning only saves module checkpoints
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@
{
"input_chunk_length": 10,
"output_chunk_length": 5,
"n_epochs": 5,
"n_epochs": 10,
"random_state": 0,
"likelihood": GaussianLikelihood(),
},
Expand Down Expand Up @@ -168,6 +168,7 @@ def test_fit_predict_determinism(self):

def test_probabilistic_forecast_accuracy(self):
for model_cls, model_kwargs, err in models_cls_kwargs_errs:
print(model_cls)
dennisbader marked this conversation as resolved.
Show resolved Hide resolved
self.helper_test_probabilistic_forecast_accuracy(
model_cls, model_kwargs, err, self.constant_ts, self.constant_noisy_ts
)
Expand Down
62 changes: 46 additions & 16 deletions darts/tests/models/forecasting/test_ptl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,9 @@ def test_custom_trainer_setup(self):
self.assertEqual(trainer.max_epochs, model.epochs_trained)

def test_builtin_extended_trainer(self):
invalid_trainer_kwarg = {"precisionn": 32}

# error will be raised at training time
# wrong precision parameter name
with self.assertRaises(TypeError):
invalid_trainer_kwarg = {"precisionn": "32-true"}
model = RNNModel(
12,
"RNN",
Expand All @@ -113,20 +112,51 @@ def test_builtin_extended_trainer(self):
)
model.fit(self.series, epochs=1)

valid_trainer_kwargs = {
"precision": 32,
}
# flaot 16 not supported
with self.assertRaises(ValueError):
invalid_trainer_kwarg = {"precision": "16-mixed"}
model = RNNModel(
12,
"RNN",
10,
10,
random_state=42,
pl_trainer_kwargs=invalid_trainer_kwarg,
)
model.fit(self.series.astype(np.float16), epochs=1)

# valid parameters shouldn't raise error
model = RNNModel(
12,
"RNN",
10,
10,
random_state=42,
pl_trainer_kwargs=valid_trainer_kwargs,
)
model.fit(self.series, epochs=1)
# precision value doesn't match `series` dtype
with self.assertRaises(ValueError):
invalid_trainer_kwarg = {"precision": "64-true"}
model = RNNModel(
12,
"RNN",
10,
10,
random_state=42,
pl_trainer_kwargs=invalid_trainer_kwarg,
)
model.fit(self.series.astype(np.float32), epochs=1)

for precision, precision_int in zip(["64-true", "32-true"], [64, 32]):
valid_trainer_kwargs = {
"precision": precision,
}

# valid parameters shouldn't raise error
model = RNNModel(
12,
"RNN",
10,
10,
random_state=42,
pl_trainer_kwargs=valid_trainer_kwargs,
)
ts_dtype = getattr(np, f"float{precision_int}")
model.fit(self.series.astype(ts_dtype), epochs=1)
preds = model.predict(n=3)
assert model.trainer.precision == precision
assert preds.dtype == ts_dtype

def test_custom_callback(self):
class CounterCallback(pl.callbacks.Callback):
Expand Down