Commit 31da6d3
Fix/ loading metrics and loss in load_from_checkpoint (#1759)
* fix: loss_fn and torch_metrics are properly restored when calling load_from_checkpoint()

* fix: moved the fix to the PL on_save/on_load methods instead of load_from_checkpoint()

* fix: address reviewer comments, loss and metrics objects are saved in the constructor

* update changelog

---------

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
madtoinou and dennisbader committed May 23, 2023
1 parent 5b68b69 commit 31da6d3
Showing 4 changed files with 111 additions and 18 deletions.
CHANGELOG.md (1 change: 1 addition & 0 deletions)

@@ -11,6 +11,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
 
 **Improved**
 - Added support for `PathLike` to the `save()` and `load()` functions of `ForecastingModel`. [#1754](https://github.com/unit8co/darts/pull/1754) by [Simon Sudrich](https://github.com/sudrich).
+- Fixed an issue with `TorchForecastingModel.load_from_checkpoint()` not properly loading the loss function and metrics. [#1749](https://github.com/unit8co/darts/pull/1749) by [Antoine Madrona](https://github.com/madtoinou).
 
 **Fixed**
 - Fixed an issue not considering original component names for `TimeSeries.plot()` when providing a label prefix. [#1783](https://github.com/unit8co/darts/pull/1783) by [Simon Sudrich](https://github.com/sudrich).
darts/models/forecasting/pl_forecasting_module.py (38 changes: 23 additions & 15 deletions)

@@ -88,8 +88,7 @@ def __init__(
         super().__init__()
 
         # save hyper parameters for saving/loading
-        # do not save type nn.Module params
-        self.save_hyperparameters(ignore=["loss_fn", "torch_metrics"])
+        self.save_hyperparameters()
 
         raise_if(
             input_chunk_length is None or output_chunk_length is None,
@@ -116,19 +115,8 @@ def __init__(
             dict() if lr_scheduler_kwargs is None else lr_scheduler_kwargs
         )
 
-        if torch_metrics is None:
-            torch_metrics = torchmetrics.MetricCollection([])
-        elif isinstance(torch_metrics, torchmetrics.Metric):
-            torch_metrics = torchmetrics.MetricCollection([torch_metrics])
-        elif isinstance(torch_metrics, torchmetrics.MetricCollection):
-            pass
-        else:
-            raise_log(
-                AttributeError(
-                    "`torch_metrics` only accepts type torchmetrics.Metric or torchmetrics.MetricCollection"
-                ),
-                logger,
-            )
+        # convert torch_metrics to torchmetrics.MetricCollection
+        torch_metrics = self.configure_torch_metrics(torch_metrics)
         self.train_metrics = torch_metrics.clone(prefix="train_")
         self.val_metrics = torch_metrics.clone(prefix="val_")
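
`clone(prefix=...)` on a `torchmetrics.MetricCollection` yields an independent copy whose metrics log under prefixed names, which is how the module keeps separate train and validation states, e.g.:

```python
import torchmetrics

metrics = torchmetrics.MetricCollection([torchmetrics.MeanAbsoluteError()])
train_metrics = metrics.clone(prefix="train_")  # logs as "train_MeanAbsoluteError"
val_metrics = metrics.clone(prefix="val_")      # logs as "val_MeanAbsoluteError"
```
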
@@ -425,6 +413,26 @@ def epochs_trained(self):
 
         return current_epoch
 
+    @staticmethod
+    def configure_torch_metrics(
+        torch_metrics: Union[torchmetrics.Metric, torchmetrics.MetricCollection]
+    ) -> torchmetrics.MetricCollection:
+        """process the torch_metrics parameter."""
+        if torch_metrics is None:
+            torch_metrics = torchmetrics.MetricCollection([])
+        elif isinstance(torch_metrics, torchmetrics.Metric):
+            torch_metrics = torchmetrics.MetricCollection([torch_metrics])
+        elif isinstance(torch_metrics, torchmetrics.MetricCollection):
+            pass
+        else:
+            raise_log(
+                AttributeError(
+                    "`torch_metrics` only accepts type torchmetrics.Metric or torchmetrics.MetricCollection"
+                ),
+                logger,
+            )
+        return torch_metrics
+
 
 class PLPastCovariatesModule(PLForecastingModule, ABC):
     def _produce_train_output(self, input_batch: Tuple):
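
A quick usage sketch of the new helper (the assertions follow directly from the branches above; this is not taken from the repo's test suite):

```python
import torchmetrics

from darts.models.forecasting.pl_forecasting_module import PLForecastingModule

# None becomes an empty collection, a single Metric gets wrapped, and an
# existing MetricCollection passes through unchanged.
empty = PLForecastingModule.configure_torch_metrics(None)
single = PLForecastingModule.configure_torch_metrics(torchmetrics.MeanAbsoluteError())
assert isinstance(empty, torchmetrics.MetricCollection) and len(empty) == 0
assert isinstance(single, torchmetrics.MetricCollection) and len(single) == 1
```
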
darts/models/forecasting/torch_forecasting_model.py (1 change: 1 addition & 0 deletions)

@@ -1682,6 +1682,7 @@ def load_from_checkpoint(
         logger.info(f"loading {file_name}")
 
         model.model = model._load_from_checkpoint(file_path, **kwargs)
+
         # restore _fit_called attribute, set to False in load() if no .ckpt is found/provided
         model._fit_called = True
         model.load_ckpt_path = file_path
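
The user-facing round trip this enables, mirroring the new tests below (a minimal sketch; `series` and `work_dir` are placeholders, and `custom_loss_model` is an illustrative model name):

```python
import torch

from darts.models import RNNModel

# `series` and `work_dir` are placeholders defined elsewhere.
model = RNNModel(
    12,
    "RNN",
    5,
    1,
    n_epochs=1,
    work_dir=work_dir,
    model_name="custom_loss_model",
    save_checkpoints=True,
    force_reset=True,
    loss_fn=torch.nn.L1Loss(),
)
model.fit(series)

loaded = RNNModel.load_from_checkpoint("custom_loss_model", work_dir, best=False)
# before this fix, the custom loss was silently replaced by the default (MSELoss)
assert isinstance(loaded.model.criterion, torch.nn.L1Loss)
```
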
darts/tests/models/forecasting/test_torch_forecasting_model.py (89 changes: 86 additions & 3 deletions)

@@ -16,6 +16,7 @@
 
 try:
     import torch
+    from pytorch_lightning.loggers.logger import DummyLogger
     from pytorch_lightning.tuner.lr_finder import _LRFinder
     from torchmetrics import (
         MeanAbsoluteError,
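
`DummyLogger` is Lightning's no-op logger; presumably the tests attach it (together with `log_every_n_steps=1`) so that metric logging via `self.log()` is actually exercised on these very short runs without writing any event files:

```python
from pytorch_lightning.loggers.logger import DummyLogger

# passed to the models below via pl_trainer_kwargs
trainer_kwargs = {"logger": DummyLogger(), "log_every_n_steps": 1}
```
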
@@ -471,6 +472,63 @@ def test_load_weights(self):
             f"respectively {retrained_mape} and {original_mape}",
         )
 
+    def test_load_from_checkpoint_w_custom_loss(self):
+        model_name = "pretraining_custom_loss"
+        # model with a custom loss
+        model = RNNModel(
+            12,
+            "RNN",
+            5,
+            1,
+            n_epochs=1,
+            work_dir=self.temp_work_dir,
+            model_name=model_name,
+            save_checkpoints=True,
+            force_reset=True,
+            loss_fn=torch.nn.L1Loss(),
+        )
+        model.fit(self.series)
+
+        loaded_model = RNNModel.load_from_checkpoint(
+            model_name, self.temp_work_dir, best=False
+        )
+        # custom loss function should be properly restored from ckpt
+        self.assertTrue(isinstance(loaded_model.model.criterion, torch.nn.L1Loss))
+
+        loaded_model.fit(self.series, epochs=2)
+        # calling fit() should not impact the loss function
+        self.assertTrue(isinstance(loaded_model.model.criterion, torch.nn.L1Loss))
+
+    def test_load_from_checkpoint_w_metrics(self):
+        model_name = "pretraining_metrics"
+        # model with one torch metric
+        model = RNNModel(
+            12,
+            "RNN",
+            5,
+            1,
+            n_epochs=1,
+            work_dir=self.temp_work_dir,
+            model_name=model_name,
+            save_checkpoints=True,
+            force_reset=True,
+            torch_metrics=MeanAbsolutePercentageError(),
+            pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
+        )
+        model.fit(self.series)
+        # check train_metrics before loading
+        self.assertTrue(isinstance(model.model.train_metrics, MetricCollection))
+        self.assertEqual(len(model.model.train_metrics), 1)
+
+        loaded_model = RNNModel.load_from_checkpoint(
+            model_name, self.temp_work_dir, best=False
+        )
+        # the torch metrics should be properly restored from ckpt
+        self.assertTrue(
+            isinstance(loaded_model.model.train_metrics, MetricCollection)
+        )
+        self.assertEqual(len(loaded_model.model.train_metrics), 1)
+
     def test_optimizers(self):
 
         optimizers = [
@@ -531,17 +589,39 @@ def test_metrics(self):
         )
 
         # test single metric
-        model = RNNModel(12, "RNN", 10, 10, n_epochs=1, torch_metrics=metric)
+        model = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            n_epochs=1,
+            torch_metrics=metric,
+            pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
+        )
         model.fit(self.series)
 
         # test metric collection
         model = RNNModel(
-            12, "RNN", 10, 10, n_epochs=1, torch_metrics=metric_collection
+            12,
+            "RNN",
+            10,
+            10,
+            n_epochs=1,
+            torch_metrics=metric_collection,
+            pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
         )
         model.fit(self.series)
 
         # test multivariate series
-        model = RNNModel(12, "RNN", 10, 10, n_epochs=1, torch_metrics=metric)
+        model = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            n_epochs=1,
+            torch_metrics=metric,
+            pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
+        )
         model.fit(self.multivariate_series)
@@ -559,6 +639,7 @@ def test_metrics_w_likelihood(self):
             n_epochs=1,
             likelihood=GaussianLikelihood(),
             torch_metrics=metric,
+            pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
         )
         model.fit(self.series)
 
@@
             n_epochs=1,
             likelihood=GaussianLikelihood(),
             torch_metrics=metric_collection,
+            pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
         )
         model.fit(self.series)
 
@@
             n_epochs=1,
             likelihood=GaussianLikelihood(),
             torch_metrics=metric_collection,
+            pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
         )
         model.fit(self.multivariate_series)
