
Commit efa955a
Fix/ptl1.6.0 (#888)
* fix epochs trained count

* save PTL module and trainer using PTL checkpointing

* dynamically compute right number of epochs trained

* test checkpoint file existence

* restore model saving in tests
hrzn committed Apr 5, 2022
1 parent fb5a59e commit efa955a
Showing 3 changed files with 34 additions and 3 deletions.
10 changes: 8 additions & 2 deletions darts/models/forecasting/pl_forecasting_module.py
@@ -17,6 +17,10 @@

logger = get_logger(__name__)

# Check whether we are running pytorch-lightning >= 1.6.0 or not:
tokens = pl.__version__.split(".")
pl_160_or_above = int(tokens[0]) > 1 or (int(tokens[0]) == 1 and int(tokens[1]) >= 6)


class PLForecastingModule(pl.LightningModule, ABC):
@abstractmethod
@@ -324,10 +328,12 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:

@property
def epochs_trained(self):
    current_epoch = self.current_epoch

    # For PTL < 1.6.0, current_epoch lags the number of completed epochs
    # by one whenever any training has run, so we adjust:
    if not pl_160_or_above and (self.current_epoch or self.global_step):
        current_epoch += 1

    return current_epoch


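As an aside, the version check added above relies on integer comparisons of split tokens; a common alternative is to compare parsed version objects. A minimal sketch, assuming the packaging package is importable (it is not imported by this commit):

from packaging.version import Version

import pytorch_lightning as pl

# Parsed comparison handles multi-digit components and orders
# pre-release tags such as "1.6.0rc1" correctly.
pl_160_or_above = Version(pl.__version__) >= Version("1.6.0")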
17 changes: 17 additions & 0 deletions darts/models/forecasting/torch_forecasting_model.py
@@ -1299,16 +1299,25 @@ def save_model(self, path: str) -> None:
path
Path under which to save the model at its current state.
"""
# TODO: the parameters are currently saved twice: once in the complete
# object, and once through PTL checkpointing.

raise_if_not(
path.endswith(".pth.tar"),
"The given path should end with '.pth.tar'.",
logger,
)

# We save the whole object to keep track of everything
with open(path, "wb") as f_out:
torch.save(self, f_out)

# In addition, we need to use PTL save_checkpoint() to properly save the trainer and model
if self.trainer is not None:
    # strip the ".pth.tar" suffix to build the checkpoint path
    base_path = path[:-len(".pth.tar")]
    path_ptl_ckpt = base_path + "_ptl-ckpt.pth.tar"
    self.trainer.save_checkpoint(path_ptl_ckpt)
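For illustration, a hedged usage sketch of the new saving path (the series and model settings below are invented for this example):

import numpy as np

from darts import TimeSeries
from darts.models import RNNModel

series = TimeSeries.from_values(np.sin(np.linspace(0, 20, 120)))
model = RNNModel(input_chunk_length=12, n_epochs=2)
model.fit(series)

# Saves the complete object; because fit() created a trainer, the sibling
# PTL checkpoint "rnn_run0_ptl-ckpt.pth.tar" is written as well.
model.save_model("rnn_run0.pth.tar")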

@staticmethod
def load_model(path: str) -> "TorchForecastingModel":
"""loads a model from a given file path. The file name should end with '.pth.tar'
@@ -1337,6 +1346,14 @@ def load_model(path: str) -> "TorchForecastingModel":

with open(path, "rb") as fin:
model = torch.load(fin)

# If a PTL checkpoint was saved, we also need to load it:
base_path = path[:-len(".pth.tar")]
path_ptl_ckpt = base_path + "_ptl-ckpt.pth.tar"
if os.path.exists(path_ptl_ckpt):
model.model = model.model.__class__.load_from_checkpoint(path_ptl_ckpt)
model.trainer = model.model.trainer

return model

@staticmethod
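Continuing the hypothetical example: on the loading side, the sibling checkpoint is picked up automatically when it exists next to the saved file.

from darts.models import RNNModel

# Restores the full object; since "rnn_run0_ptl-ckpt.pth.tar" is present,
# the inner PTL module (and its trainer) are reloaded from the checkpoint.
loaded = RNNModel.load_model("rnn_run0.pth.tar")
forecast = loaded.predict(n=6)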
10 changes: 9 additions & 1 deletion darts/tests/models/forecasting/test_torch_forecasting_model.py
@@ -123,14 +123,22 @@ def test_manual_save_and_load(self):
checkpoint_path_manual = os.path.join(model_dir, manual_name)
os.mkdir(checkpoint_path_manual)

checkpoint_file_name = "checkpoint_0.pth.tar"
model_path_manual = os.path.join(
checkpoint_path_manual, checkpoint_file_name
)
checkpoint_file_name_ckpt = "checkpoint_0_ptl-ckpt.pth.tar"
model_path_manual_ckpt = os.path.join(
    checkpoint_path_manual, checkpoint_file_name_ckpt
)

# manually save the model
model_manual_save.save_model(model_path_manual)
self.assertTrue(os.path.exists(model_path_manual))

# check that the PTL checkpoint path is also there
self.assertTrue(os.path.exists(model_path_manual_ckpt))

# load the manually saved model and compare with the automatic model's results
model_manual_save = RNNModel.load_model(model_path_manual)
self.assertEqual(
