Fix/ptl1.6.0 #888

Merged 7 commits on Apr 5, 2022
10 changes: 8 additions & 2 deletions darts/models/forecasting/pl_forecasting_module.py
@@ -17,6 +17,10 @@

 logger = get_logger(__name__)

+# Check whether we are running pytorch-lightning >= 1.6.0 or not:
+tokens = pl.__version__.split(".")
+pl_160_or_above = (int(tokens[0]), int(tokens[1])) >= (1, 6)
+

 class PLForecastingModule(pl.LightningModule, ABC):
     @abstractmethod
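
A note on the version gate: comparing (major, minor) as a tuple keeps the check correct for a later major release, where a per-component test like `major >= 1 and minor >= 6` would wrongly report False for 2.0. An equivalent check can be sketched with the third-party packaging distribution (an assumption on my part; the diff itself parses pl.__version__ by hand):

    # Sketch only: assumes the "packaging" distribution is installed. It also
    # parses pre-release tags like "1.6.0rc1", which a naive split(".") cannot.
    import pytorch_lightning as pl
    from packaging.version import Version

    pl_160_or_above = Version(pl.__version__) >= Version("1.6.0")
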
@@ -324,10 +328,12 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:

     @property
     def epochs_trained(self):
-        # trained epochs are only 0 when global step and current epoch are 0, else current epoch + 1
         current_epoch = self.current_epoch
-        if self.current_epoch or self.global_step:
+
+        # For PTL < 1.6.0 we have to adjust:
+        if not pl_160_or_above and (self.current_epoch or self.global_step):
             current_epoch += 1

         return current_epoch
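
The branch encodes the behavior change this PR works around: the diff implies that under PTL < 1.6.0, current_epoch still held the index of the last epoch once any training had happened, so one had to be added to get the number of completed epochs, while from 1.6.0 on the trainer already reports a count. A standalone sketch of that mapping (free function and example numbers are hypothetical, distilled from the property above):

    def epochs_trained(current_epoch, global_step, pl_160_or_above):
        # Only a never-trained model (epoch 0 and step 0) truly reports 0.
        if not pl_160_or_above and (current_epoch or global_step):
            return current_epoch + 1
        return current_epoch

    assert epochs_trained(0, 0, pl_160_or_above=False) == 0    # untrained
    assert epochs_trained(2, 300, pl_160_or_above=False) == 3  # old PTL: index 2 -> 3 epochs
    assert epochs_trained(3, 300, pl_160_or_above=True) == 3   # new PTL: already a count
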
17 changes: 17 additions & 0 deletions darts/models/forecasting/torch_forecasting_model.py
@@ -1299,16 +1299,25 @@ def save_model(self, path: str) -> None:

         path
             Path under which to save the model at its current state.
         """
+        # TODO: the parameters are currently saved twice, once with the complete
+        # object, and once with PTL checkpointing.

         raise_if_not(
             path.endswith(".pth.tar"),
             "The given path should end with '.pth.tar'.",
             logger,
         )

+        # We save the whole object to keep track of everything
         with open(path, "wb") as f_out:
             torch.save(self, f_out)
+
+        # In addition, we need to use PTL save_checkpoint() to properly save the trainer and model
+        if self.trainer is not None:
+            base_path = path[:-8]  # strip the 8-character ".pth.tar" suffix
+            path_ptl_ckpt = base_path + "_ptl-ckpt.pth.tar"
+            self.trainer.save_checkpoint(path_ptl_ckpt)

     @staticmethod
     def load_model(path: str) -> "TorchForecastingModel":
         """Loads a model from a given file path. The file name should end with '.pth.tar'.
@@ -1337,6 +1346,14 @@ def load_model(path: str) -> "TorchForecastingModel":

         with open(path, "rb") as fin:
             model = torch.load(fin)
+
+        # If a PTL checkpoint was saved, we also need to load it:
+        base_path = path[:-8]
+        path_ptl_ckpt = base_path + "_ptl-ckpt.pth.tar"
+        if os.path.exists(path_ptl_ckpt):
+            model.model = model.model.__class__.load_from_checkpoint(path_ptl_ckpt)
+            model.trainer = model.model.trainer
+
         return model

     @staticmethod
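
Loading mirrors the save: the pickled wrapper is restored first, and if the sibling "_ptl-ckpt" file exists, the inner PL module is re-created from it via load_from_checkpoint() and the trainer reference is re-attached. Continuing the hypothetical sketch from above:

    loaded = RNNModel.load_model("runs/my_model.pth.tar")
    # The inner PL module came back from the checkpoint, so the epoch count
    # exposed by the epochs_trained property survives the round trip:
    print(loaded.model.epochs_trained)
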
@@ -123,14 +123,22 @@ def test_manual_save_and_load(self):

         checkpoint_path_manual = os.path.join(model_dir, manual_name)
         os.mkdir(checkpoint_path_manual)

-        # save manually saved model
         checkpoint_file_name = "checkpoint_0.pth.tar"
         model_path_manual = os.path.join(
             checkpoint_path_manual, checkpoint_file_name
         )
+        checkpoint_file_name_ckpt = "checkpoint_0_ptl-ckpt.pth.tar"
+        model_path_manual_ckpt = os.path.join(
+            checkpoint_path_manual, checkpoint_file_name_ckpt
+        )
+
+        # save manually saved model
         model_manual_save.save_model(model_path_manual)
         self.assertTrue(os.path.exists(model_path_manual))

+        # check that the PTL checkpoint path is also there
+        self.assertTrue(os.path.exists(model_path_manual_ckpt))
+
         # load manual save model and compare with automatic model results
         model_manual_save = RNNModel.load_model(model_path_manual)
         self.assertEqual(
Expand Down