Fix/ loading metrics and loss in load_from_checkpoint #1759

Merged
merged 10 commits on May 23, 2023
56 changes: 43 additions & 13 deletions darts/models/forecasting/pl_forecasting_module.py
@@ -116,19 +116,8 @@ def __init__(
             dict() if lr_scheduler_kwargs is None else lr_scheduler_kwargs
         )
 
-        if torch_metrics is None:
-            torch_metrics = torchmetrics.MetricCollection([])
-        elif isinstance(torch_metrics, torchmetrics.Metric):
-            torch_metrics = torchmetrics.MetricCollection([torch_metrics])
-        elif isinstance(torch_metrics, torchmetrics.MetricCollection):
-            pass
-        else:
-            raise_log(
-                AttributeError(
-                    "`torch_metrics` only accepts type torchmetrics.Metric or torchmetrics.MetricCollection"
-                ),
-                logger,
-            )
+        # convert torch_metrics to torchmetrics.MetricCollection
+        torch_metrics = self.configure_torch_metrics(torch_metrics)
         self.train_metrics = torch_metrics.clone(prefix="train_")
         self.val_metrics = torch_metrics.clone(prefix="val_")

@@ -392,13 +381,34 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         checkpoint["model_dtype"] = self.dtype
         # we must save the shape of the input to be able to instantiate the model without calling fit_from_dataset
         checkpoint["train_sample_shape"] = self.train_sample_shape
+        # we must save the loss to properly restore it when resuming training
+        checkpoint["loss_fn"] = self.criterion
+        # we must save the metrics to continue outputting them when resuming training
+        checkpoint["torch_metrics_train"] = self.train_metrics
+        checkpoint["torch_metrics_val"] = self.val_metrics

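With these additions in place, the extra entries can be inspected directly in a saved checkpoint file. A minimal sketch, assuming a model was trained with save_checkpoints=True; the checkpoint path is hypothetical and must be adapted to your own work_dir/model_name:

import torch

# hypothetical path to a darts-generated checkpoint; adapt to your setup
# (on newer torch versions you may need weights_only=False)
ckpt = torch.load("darts_logs/my_model/checkpoints/last-epoch=0.ckpt", map_location="cpu")
print(ckpt["loss_fn"])              # e.g. L1Loss() if a custom loss was used
print(ckpt["torch_metrics_train"])  # MetricCollection cloned with prefix "train_"
print(ckpt["torch_metrics_val"])    # MetricCollection cloned with prefix "val_"
print(ckpt["train_sample_shape"])   # already saved before this PR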
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         # by default our models are initialized as float32. For other dtypes, we need to cast to the correct precision
         # before parameters are loaded by PyTorch-Lightning
         dtype = checkpoint["model_dtype"]
         self.to_dtype(dtype)

+        # restore the attributes necessary to properly resume training
Collaborator:

btw I just saw that we don't load the "train_sample_shape" from checkpoint. I think we should add this here as well, right?

Collaborator (author):

I checked, it's already loaded when calling load_weights_from_checkpoint(). My guess is that since it's one of the constructor arguments and does not require any processing, the deserialization of the checkpoint by PyTorch Lightning does the job.
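The underlying mechanism can be illustrated outside darts. A minimal PyTorch Lightning sketch (not darts internals; the module and shape below are hypothetical) of why a plain constructor argument survives the checkpoint round-trip without explicit handling:

import pytorch_lightning as pl
import torch


class TinyModule(pl.LightningModule):
    def __init__(self, train_sample_shape=(12, 1)):
        super().__init__()
        # records the init arguments under checkpoint["hyper_parameters"]
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(1, 1)


# after trainer.save_checkpoint("tiny.ckpt"), the argument is re-injected:
# restored = TinyModule.load_from_checkpoint("tiny.ckpt")
# assert restored.hparams.train_sample_shape == (12, 1)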

+        if (
+            "loss_fn" in checkpoint.keys()
+            and "torch_metrics_train" in checkpoint.keys()
+        ):
+            self.criterion = checkpoint["loss_fn"]
+            self.train_metrics = checkpoint["torch_metrics_train"]
+            self.val_metrics = checkpoint["torch_metrics_val"]
+        else:
+            # explicitly warn the user that the loss and metrics cannot be
+            # restored from checkpoints generated before this fix
+            logger.warning(
+                "This checkpoint was generated with darts <= 0.24.0; if a custom loss "
+                "was used to train the model, it won't be properly loaded. Similarly, "
+                "the torch metrics won't be restored from the checkpoint."
+            )

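The user-facing effect, exercised by the new tests further down: a custom loss now survives both loading from a checkpoint and a subsequent call to fit(). A minimal sketch using darts' public API; the model name, work directory, and series are placeholders:

import torch
from darts.models import RNNModel

model = RNNModel(
    12, "RNN", 5, 1,
    n_epochs=1,
    work_dir="some_work_dir",        # placeholder
    model_name="custom_loss_model",  # placeholder
    save_checkpoints=True,
    force_reset=True,
    loss_fn=torch.nn.L1Loss(),
)
model.fit(series)  # `series` is a placeholder TimeSeries

loaded = RNNModel.load_from_checkpoint("custom_loss_model", "some_work_dir", best=False)
assert isinstance(loaded.model.criterion, torch.nn.L1Loss)  # restored from the checkpoint
loaded.fit(series, epochs=2)  # resuming training keeps the custom loss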
     def to_dtype(self, dtype):
         """Cast module precision (float32 by default) to another precision."""
         if dtype == torch.float16:
@@ -425,6 +435,26 @@ def epochs_trained(self):

         return current_epoch

+    @staticmethod
+    def configure_torch_metrics(
+        torch_metrics: Union[torchmetrics.Metric, torchmetrics.MetricCollection]
+    ) -> torchmetrics.MetricCollection:
+        """Process the `torch_metrics` parameter, converting it to a torchmetrics.MetricCollection."""
+        if torch_metrics is None:
+            torch_metrics = torchmetrics.MetricCollection([])
+        elif isinstance(torch_metrics, torchmetrics.Metric):
+            torch_metrics = torchmetrics.MetricCollection([torch_metrics])
+        elif isinstance(torch_metrics, torchmetrics.MetricCollection):
+            pass
+        else:
+            raise_log(
+                AttributeError(
+                    "`torch_metrics` only accepts type torchmetrics.Metric or torchmetrics.MetricCollection"
+                ),
+                logger,
+            )
+        return torch_metrics

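For reference, a quick sketch of the normalization this helper performs for each accepted input (None, a single Metric, a MetricCollection). Since it is a staticmethod, it can be called on the class directly:

from torchmetrics import MeanAbsoluteError, MeanAbsolutePercentageError, MetricCollection

from darts.models.forecasting.pl_forecasting_module import PLForecastingModule

# None -> empty MetricCollection
m0 = PLForecastingModule.configure_torch_metrics(None)
# a single Metric -> wrapped into a MetricCollection
m1 = PLForecastingModule.configure_torch_metrics(MeanAbsoluteError())
# a MetricCollection -> passed through unchanged
m2 = PLForecastingModule.configure_torch_metrics(
    MetricCollection([MeanAbsoluteError(), MeanAbsolutePercentageError()])
)
# any other type raises an AttributeError via raise_log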

 class PLPastCovariatesModule(PLForecastingModule, ABC):
     def _produce_train_output(self, input_batch: Tuple):
1 change: 1 addition & 0 deletions darts/models/forecasting/torch_forecasting_model.py
@@ -1682,6 +1682,7 @@ def load_from_checkpoint(
logger.info(f"loading {file_name}")

model.model = model._load_from_checkpoint(file_path, **kwargs)

# restore _fit_called attribute, set to False in load() if no .ckpt is found/provided
model._fit_called = True
model.load_ckpt_path = file_path
89 changes: 86 additions & 3 deletions darts/tests/models/forecasting/test_torch_forecasting_model.py
@@ -16,6 +16,7 @@
 
 try:
     import torch
+    from pytorch_lightning.loggers.logger import DummyLogger
     from pytorch_lightning.tuner.lr_finder import _LRFinder
     from torchmetrics import (
         MeanAbsoluteError,
@@ -471,6 +472,63 @@ def test_load_weights(self):
             f"respectively {retrained_mape} and {original_mape}",
         )

+    def test_load_from_checkpoint_w_custom_loss(self):
+        model_name = "pretraining_custom_loss"
+        # model with a custom loss
+        model = RNNModel(
+            12,
+            "RNN",
+            5,
+            1,
+            n_epochs=1,
+            work_dir=self.temp_work_dir,
+            model_name=model_name,
+            save_checkpoints=True,
+            force_reset=True,
+            loss_fn=torch.nn.L1Loss(),
+        )
+        model.fit(self.series)
+
+        loaded_model = RNNModel.load_from_checkpoint(
+            model_name, self.temp_work_dir, best=False
+        )
+        # custom loss function should be properly restored from ckpt
+        self.assertTrue(isinstance(loaded_model.model.criterion, torch.nn.L1Loss))
+
+        loaded_model.fit(self.series, epochs=2)
+        # calling fit() should not impact the loss function
+        self.assertTrue(isinstance(loaded_model.model.criterion, torch.nn.L1Loss))
+
+    def test_load_from_checkpoint_w_metrics(self):
+        model_name = "pretraining_metrics"
+        # model with one torch metric
+        model = RNNModel(
+            12,
+            "RNN",
+            5,
+            1,
+            n_epochs=1,
+            work_dir=self.temp_work_dir,
+            model_name=model_name,
+            save_checkpoints=True,
+            force_reset=True,
+            torch_metrics=MeanAbsolutePercentageError(),
+            pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
+        )
+        model.fit(self.series)
+        # check train_metrics before loading
+        self.assertTrue(isinstance(model.model.train_metrics, MetricCollection))
+        self.assertEqual(len(model.model.train_metrics), 1)
+
+        loaded_model = RNNModel.load_from_checkpoint(
+            model_name, self.temp_work_dir, best=False
+        )
+        # torch metrics should be properly restored from the checkpoint
+        self.assertTrue(
+            isinstance(loaded_model.model.train_metrics, MetricCollection)
+        )
+        self.assertEqual(len(loaded_model.model.train_metrics), 1)

     def test_optimizers(self):
 
         optimizers = [
@@ -531,17 +589,39 @@ def test_metrics(self):
         )
 
         # test single metric
-        model = RNNModel(12, "RNN", 10, 10, n_epochs=1, torch_metrics=metric)
+        model = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            n_epochs=1,
+            torch_metrics=metric,
+            pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
+        )
         model.fit(self.series)
 
         # test metric collection
         model = RNNModel(
-            12, "RNN", 10, 10, n_epochs=1, torch_metrics=metric_collection
+            12,
+            "RNN",
+            10,
+            10,
+            n_epochs=1,
+            torch_metrics=metric_collection,
+            pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
         )
         model.fit(self.series)
 
         # test multivariate series
-        model = RNNModel(12, "RNN", 10, 10, n_epochs=1, torch_metrics=metric)
+        model = RNNModel(
+            12,
+            "RNN",
+            10,
+            10,
+            n_epochs=1,
+            torch_metrics=metric,
+            pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
+        )
         model.fit(self.multivariate_series)

def test_metrics_w_likelihood(self):
Expand All @@ -559,6 +639,7 @@ def test_metrics_w_likelihood(self):
n_epochs=1,
likelihood=GaussianLikelihood(),
torch_metrics=metric,
pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
)
model.fit(self.series)

@@ -571,6 +652,7 @@
             n_epochs=1,
             likelihood=GaussianLikelihood(),
             torch_metrics=metric_collection,
+            pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
         )
         model.fit(self.series)

@@ -583,6 +665,7 @@
             n_epochs=1,
             likelihood=GaussianLikelihood(),
             torch_metrics=metric_collection,
+            pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
        )
         model.fit(self.multivariate_series)
