
Fix/ loading metrics and loss in load_from_checkpoint #1759

Merged · 10 commits · May 23, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -11,6 +11,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

**Improved**
- Added support for `PathLike` to the `save()` and `load()` functions of `ForecastingModel`. [#1754](https://github.com/unit8co/darts/pull/1754) by [Simon Sudrich](https://github.com/sudrich).
+- Fixed an issue with `TorchForecastingModel.load_from_checkpoint()` not properly loading the loss function and metrics. [#1749](https://github.com/unit8co/darts/pull/1749) by [Antoine Madrona](https://github.com/madtoinou).

**Fixed**
- Fixed an issue not considering original component names for `TimeSeries.plot()` when providing a label prefix. [#1783](https://github.com/unit8co/darts/pull/1783) by [Simon Sudrich](https://github.com/sudrich).
38 changes: 23 additions & 15 deletions darts/models/forecasting/pl_forecasting_module.py
@@ -88,8 +88,7 @@ def __init__(
        super().__init__()

        # save hyper parameters for saving/loading
-        # do not save type nn.Module params
-        self.save_hyperparameters(ignore=["loss_fn", "torch_metrics"])
+        self.save_hyperparameters()

        raise_if(
            input_chunk_length is None or output_chunk_length is None,

@@ -116,19 +115,8 @@ def __init__(
            dict() if lr_scheduler_kwargs is None else lr_scheduler_kwargs
        )

-        if torch_metrics is None:
-            torch_metrics = torchmetrics.MetricCollection([])
-        elif isinstance(torch_metrics, torchmetrics.Metric):
-            torch_metrics = torchmetrics.MetricCollection([torch_metrics])
-        elif isinstance(torch_metrics, torchmetrics.MetricCollection):
-            pass
-        else:
-            raise_log(
-                AttributeError(
-                    "`torch_metrics` only accepts type torchmetrics.Metric or torchmetrics.MetricCollection"
-                ),
-                logger,
-            )
+        # convert torch_metrics to torchmetrics.MetricCollection
+        torch_metrics = self.configure_torch_metrics(torch_metrics)
        self.train_metrics = torch_metrics.clone(prefix="train_")
        self.val_metrics = torch_metrics.clone(prefix="val_")

@@ -425,6 +413,26 @@ def epochs_trained(self):

        return current_epoch

+    @staticmethod
+    def configure_torch_metrics(
+        torch_metrics: Union[torchmetrics.Metric, torchmetrics.MetricCollection]
+    ) -> torchmetrics.MetricCollection:
+        """process the torch_metrics parameter."""
+        if torch_metrics is None:
+            torch_metrics = torchmetrics.MetricCollection([])
+        elif isinstance(torch_metrics, torchmetrics.Metric):
+            torch_metrics = torchmetrics.MetricCollection([torch_metrics])
+        elif isinstance(torch_metrics, torchmetrics.MetricCollection):
+            pass
+        else:
+            raise_log(
+                AttributeError(
+                    "`torch_metrics` only accepts type torchmetrics.Metric or torchmetrics.MetricCollection"
+                ),
+                logger,
+            )
+        return torch_metrics


class PLPastCovariatesModule(PLForecastingModule, ABC):
    def _produce_train_output(self, input_batch: Tuple):
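Why this works: Lightning's `load_from_checkpoint()` re-instantiates the module by passing the checkpointed hyperparameters back to `__init__`, so any argument listed in `ignore` is absent from the checkpoint and falls back to its default on load. A minimal standalone sketch of this mechanism (plain PyTorch Lightning, not darts code; `TinyModule` is illustrative):

```python
# Minimal sketch, assuming torch and pytorch_lightning are installed.
# TinyModule is illustrative and not part of darts.
import torch
import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    def __init__(self, loss_fn=None):
        super().__init__()
        # Without an `ignore` list, `loss_fn` is pickled into
        # checkpoint["hyper_parameters"] and handed back to __init__ by
        # load_from_checkpoint(); ignoring it would silently reinstate
        # the default criterion on load instead.
        self.save_hyperparameters()
        self.criterion = loss_fn if loss_fn is not None else torch.nn.MSELoss()
        self.layer = torch.nn.Linear(1, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return self.criterion(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)
```

After `trainer.save_checkpoint("tiny.ckpt")`, `TinyModule.load_from_checkpoint("tiny.ckpt")` rebuilds the module with the original `loss_fn`, which is the behavior the `save_hyperparameters()` change above restores for darts modules.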
1 change: 1 addition & 0 deletions darts/models/forecasting/torch_forecasting_model.py
@@ -1682,6 +1682,7 @@ def load_from_checkpoint(
        logger.info(f"loading {file_name}")

        model.model = model._load_from_checkpoint(file_path, **kwargs)
+
        # restore _fit_called attribute, set to False in load() if no .ckpt is found/provided
        model._fit_called = True
        model.load_ckpt_path = file_path
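At the user level, the round trip this fixes looks roughly as follows (a hedged sketch: the synthetic series and the name "ckpt_demo" are illustrative, while the API calls mirror the tests below):

```python
# Sketch only: "ckpt_demo" and the synthetic series are illustrative.
import numpy as np
import torch

from darts import TimeSeries
from darts.models import RNNModel

series = TimeSeries.from_values(np.sin(np.linspace(0, 10, 100)))

model = RNNModel(
    12,  # input_chunk_length
    "RNN",
    n_epochs=1,
    model_name="ckpt_demo",
    save_checkpoints=True,
    force_reset=True,
    loss_fn=torch.nn.L1Loss(),
)
model.fit(series)

# Before this fix, the reloaded model silently fell back to the default MSELoss;
# with loss_fn saved in the hyperparameters it is restored from the checkpoint.
loaded = RNNModel.load_from_checkpoint("ckpt_demo", best=False)
assert isinstance(loaded.model.criterion, torch.nn.L1Loss)
```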
89 changes: 86 additions & 3 deletions darts/tests/models/forecasting/test_torch_forecasting_model.py
@@ -16,6 +16,7 @@

try:
    import torch
+    from pytorch_lightning.loggers.logger import DummyLogger
    from pytorch_lightning.tuner.lr_finder import _LRFinder
    from torchmetrics import (
        MeanAbsoluteError,
@@ -471,6 +472,63 @@ def test_load_weights(self):
f"respectively {retrained_mape} and {original_mape}",
)

+    def test_load_from_checkpoint_w_custom_loss(self):
+        model_name = "pretraining_custom_loss"
+        # model with a custom loss
+        model = RNNModel(
+            12,
+            "RNN",
+            5,
+            1,
+            n_epochs=1,
+            work_dir=self.temp_work_dir,
+            model_name=model_name,
+            save_checkpoints=True,
+            force_reset=True,
+            loss_fn=torch.nn.L1Loss(),
+        )
+        model.fit(self.series)
+
+        loaded_model = RNNModel.load_from_checkpoint(
+            model_name, self.temp_work_dir, best=False
+        )
+        # custom loss function should be properly restored from ckpt
+        self.assertTrue(isinstance(loaded_model.model.criterion, torch.nn.L1Loss))
+
+        loaded_model.fit(self.series, epochs=2)
+        # calling fit() should not impact the loss function
+        self.assertTrue(isinstance(loaded_model.model.criterion, torch.nn.L1Loss))

+    def test_load_from_checkpoint_w_metrics(self):
+        model_name = "pretraining_metrics"
+        # model with one torch metric
+        model = RNNModel(
+            12,
+            "RNN",
+            5,
+            1,
+            n_epochs=1,
+            work_dir=self.temp_work_dir,
+            model_name=model_name,
+            save_checkpoints=True,
+            force_reset=True,
+            torch_metrics=MeanAbsolutePercentageError(),
+            pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
+        )
+        model.fit(self.series)
+        # check train_metrics before loading
+        self.assertTrue(isinstance(model.model.train_metrics, MetricCollection))
+        self.assertEqual(len(model.model.train_metrics), 1)
+
+        loaded_model = RNNModel.load_from_checkpoint(
+            model_name, self.temp_work_dir, best=False
+        )
+        # torch metrics should be properly restored from ckpt
+        self.assertTrue(
+            isinstance(loaded_model.model.train_metrics, MetricCollection)
+        )
+        self.assertEqual(len(loaded_model.model.train_metrics), 1)

    def test_optimizers(self):

        optimizers = [
@@ -531,17 +589,39 @@ def test_metrics(self):
        )

        # test single metric
-        model = RNNModel(12, "RNN", 10, 10, n_epochs=1, torch_metrics=metric)
+        model = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            n_epochs=1,
+            torch_metrics=metric,
+            pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
+        )
        model.fit(self.series)

        # test metric collection
        model = RNNModel(
-            12, "RNN", 10, 10, n_epochs=1, torch_metrics=metric_collection
+            12,
+            "RNN",
+            10,
+            10,
+            n_epochs=1,
+            torch_metrics=metric_collection,
+            pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
        )
        model.fit(self.series)

        # test multivariate series
-        model = RNNModel(12, "RNN", 10, 10, n_epochs=1, torch_metrics=metric)
+        model = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            n_epochs=1,
+            torch_metrics=metric,
+            pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
+        )
        model.fit(self.multivariate_series)

    def test_metrics_w_likelihood(self):
@@ -559,6 +639,7 @@
            n_epochs=1,
            likelihood=GaussianLikelihood(),
            torch_metrics=metric,
+            pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
        )
        model.fit(self.series)

@@ -571,6 +652,7 @@
            n_epochs=1,
            likelihood=GaussianLikelihood(),
            torch_metrics=metric_collection,
+            pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
        )
        model.fit(self.series)

@@ -583,6 +665,7 @@
            n_epochs=1,
            likelihood=GaussianLikelihood(),
            torch_metrics=metric_collection,
+            pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
        )
        model.fit(self.multivariate_series)

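Two notes on the test changes: the `DummyLogger` passed in `pl_trainer_kwargs` keeps the metric-logging path exercised without writing logger files to disk, and the `len(...)` assertions rely on `torchmetrics.MetricCollection` being dict-like. The normalization performed by `configure_torch_metrics` can be sketched standalone with plain torchmetrics (no darts needed):

```python
# Standalone sketch of the torch_metrics normalization (plain torchmetrics).
from torchmetrics import (
    MeanAbsoluteError,
    MeanAbsolutePercentageError,
    MetricCollection,
)

# a bare Metric is wrapped into a one-element MetricCollection
single = MetricCollection([MeanAbsolutePercentageError()])
assert len(single) == 1

# a MetricCollection passes through unchanged
collection = MetricCollection([MeanAbsoluteError(), MeanAbsolutePercentageError()])
assert len(collection) == 2

# None becomes an empty MetricCollection, and any other type raises an
# AttributeError inside darts (see configure_torch_metrics above).
```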