
Feat/improved training from ckpt #1501

Merged: 47 commits from feat/improved-training-from-ckpt into master on Feb 21, 2023
Conversation

@madtoinou (Collaborator) commented Jan 20, 2023

Fixes #1109, #1090, #1495 and #1471.

Summary

Implements setup_finetuning, which wraps the loading of a checkpoint and the setup of the various elements related to the training of a model (trainer, optimizer, lr_scheduler), and returns a model instance that can then be trained directly with the fit method. For scenarios where training is resumed from a checkpoint after an error/crash, load_from_checkpoint remains the most efficient approach.

Other Information

The number of additional training epochs can be provided either directly through the additional_epochs argument or via the trainer_params dict (if both are provided, a sanity check is performed).
The fine-tuned model checkpoints can be saved either in the same folder as the "original" one or in a different folder (recommended, to avoid overwriting the loaded checkpoint).
Fine-tuning can be chained (the _model.ckpt.tar file is created when calling setup_finetuning, after updating the trainer parameters) to give more granularity to the user and avoid unexpected behavior.

A big thank you to @solalatus for providing a gist with all the attributes to update.

I tried to find methods that would allow directly updating the model.model_params attribute, but this seems to be handled by PyTorch Lightning.

Example of use:

import numpy as np
import pandas as pd
from darts import TimeSeries
from darts.models import NHiTSModel

train_ts = TimeSeries.from_series(pd.Series(np.random.random(1000)), freq="1H")

epochs = 5
additional_epochs = 5

# first training
model_name = "original_model"
model = NHiTSModel(2, 2, n_epochs=epochs, save_checkpoints=True, model_name=model_name, force_reset=True)
model.fit(train_ts)
print("first training", model.epochs_trained)

# load the last checkpoint of the original model
model_finetune = NHiTSModel.setup_finetuning(
    additional_epochs=additional_epochs,
    old_model_name=model_name,
    new_model_name=model_name + "_ft",
    trainer_params={"log_every_n_steps": 1, "enable_model_summary": False},
)
# fine-tune the model for 5 more epochs (model_finetune.epochs_trained = 10)
model_finetune.fit(train_ts)
print("first finetuning", model_finetune.epochs_trained)

# fine-tune the last checkpoint of the fine-tuned model
model_finetune = NHiTSModel.setup_finetuning(
    additional_epochs=5,
    old_model_name=model_name + "_ft",
    new_model_name=model_name + "_ft2",
    trainer_params={"log_every_n_steps": 1, "enable_model_summary": False},
)
# fine-tune the model for another 5 epochs (model_finetune.epochs_trained = 15)
model_finetune.fit(train_ts)
print("second finetuning", model_finetune.epochs_trained)

# load the last checkpoint of the original model and save the checkpoints in place
# (overwriting the last-epoch checkpoint); the new total number of epochs is
# provided through trainer_params
model_finetune = NHiTSModel.setup_finetuning(
    old_model_name=model_name,
    save_inplace=True,
    trainer_params={"log_every_n_steps": 1, "enable_model_summary": False, "max_epochs": 30},
)
# fine-tune the model
model_finetune.fit(train_ts)
print("inplace finetuning", model_finetune.epochs_trained)

@solalatus (Contributor)

Supercool! Thanks @madtoinou for making the effort! It looks way better than the "raw" version.
I will be able to test it instead of the "raw" one in a couple of days. Very cool indeed!

@codecov-commenter commented Jan 20, 2023

Codecov Report

Base: 94.06% // Head: 94.03% // Project coverage decreases by 0.04% ⚠️

Coverage data is based on head (b60c9f2) compared to base (955e2b5).
Patch coverage: 95.55% of modified lines in pull request are covered.


Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1501      +/-   ##
==========================================
- Coverage   94.06%   94.03%   -0.04%     
==========================================
  Files         125      125              
  Lines       11095    11123      +28     
==========================================
+ Hits        10437    10459      +22     
- Misses        658      664       +6     
Impacted Files Coverage Δ
darts/models/forecasting/block_rnn_model.py 98.24% <ø> (-0.04%) ⬇️
darts/models/forecasting/dlinear.py 100.00% <ø> (ø)
darts/models/forecasting/nbeats.py 98.11% <ø> (ø)
darts/models/forecasting/nhits.py 99.27% <ø> (-0.01%) ⬇️
darts/models/forecasting/nlinear.py 92.68% <ø> (ø)
darts/models/forecasting/rnn_model.py 97.64% <ø> (ø)
darts/models/forecasting/tcn_model.py 96.96% <ø> (ø)
darts/models/forecasting/tft_model.py 97.54% <ø> (ø)
darts/models/forecasting/transformer_model.py 100.00% <ø> (ø)
...arts/models/forecasting/torch_forecasting_model.py 89.88% <95.00%> (+0.32%) ⬆️
... and 9 more


@hrzn (Contributor) left a comment

That looks quite good, but I'm a bit hesitant about introducing a new method. Could you maybe simply improve load_from_checkpoint() and fix the epoch issue in fit()?

def setup_finetuning(
    old_model_name: str,
    new_model_name: str = None,
    additional_epochs: int = 0,
Contributor

Do we need this parameter, given that fit() already accepts epochs ?
I would find it cleaner to rely exclusively on fit()'s parameter. If there's a problem with it, could we maybe fix it there (i.e. handle the trainer correctly in fit() to handle epoch)?

@madtoinou (Collaborator, Author)

There is an issue with the fit() parameter (#1495); I think @rijkvandermeulen is already working on a fix. I will remove the epochs argument from this method and wait for the patch to be merged.

darts/models/forecasting/torch_forecasting_model.py (outdated; resolved)
darts/models/forecasting/torch_forecasting_model.py (outdated; resolved)
@madtoinou (Collaborator, Author)

Took your comments into account:

  • all the functionality has been moved to load_from_checkpoint.
  • the additional_epochs argument has been removed; the number of training epochs can still be modified via the max_epochs value of pl_trainer_kwargs, and it is overwritten by the value provided to fit() (see the sketch below).
  • added the docstring.
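
A minimal sketch of this revised workflow, reusing NHiTSModel and train_ts from the example above; argument names such as pl_trainer_kwargs reflect this intermediate state of the PR and may differ from the final API:

# Hypothetical usage based on the comment above: resume/extend training via
# load_from_checkpoint and control the epoch budget through max_epochs.
model_finetune = NHiTSModel.load_from_checkpoint(
    model_name="original_model",
    best=False,  # load the last checkpoint rather than the best-performing one
    pl_trainer_kwargs={"max_epochs": 30},
)
model_finetune.fit(train_ts)  # an epochs value passed to fit() takes precedence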

@solalatus (Contributor)

@madtoinou when I try to use your branch and load a model, I get a FileNotFoundError.

The context: I trained the model on a machine with a different home folder name, and now I want to load it on a new machine with a new folder name (which I pass in as the work_dir parameter). It still looks for the model at the old location irrespective of work_dir, but only in this one place.

I tried to trace through, and one thing is suspicious:

Are you sure that this line here should be model.work_dir instead of just work_dir?

Am I messing something up?

Please advise!

@solalatus (Contributor)

For me, this helped:

diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py
index 0cd9602..5c57a95 100644
--- a/darts/models/forecasting/torch_forecasting_model.py
+++ b/darts/models/forecasting/torch_forecasting_model.py
@@ -1565,6 +1565,9 @@ class TorchForecastingModel(GlobalForecastingModel, ABC):
         model.pl_module_params["lr_scheduler_cls"] = model.model.lr_scheduler_cls
         model.pl_module_params["lr_scheduler_kwargs"] = model.model.lr_scheduler_kwargs
 
+        if work_dir:
+            model.work_dir=work_dir
+
         # save the initialized TorchForecastingModel to allow re-training of already re-trained model
         model.save(
             os.path.join(

Don't know if this is too crude, though!

@solalatus (Contributor)

And also - and not wanting to be horrible here - what about this line?

Shouldn't it be more like

            if "callbacks" in pl_trainer_kwargs.keys() and len(
                pl_trainer_kwargs["callbacks"]) > 0:

I mean, the placement of the parenthesis does not make sense to me in this case.

But again, I might be missing the point here...

@solalatus (Contributor)

Ok, to be more specific:

diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py
index 0cd9602..86146fe 100644
--- a/darts/models/forecasting/torch_forecasting_model.py
+++ b/darts/models/forecasting/torch_forecasting_model.py
@@ -1523,8 +1523,9 @@ class TorchForecastingModel(GlobalForecastingModel, ABC):
             model.trainer_params.update(pl_trainer_kwargs)
 
             # special parameters handling
-            if "callbacks" in pl_trainer_kwargs.keys() and len(
-                pl_trainer_kwargs["callbacks"] > 0
+            if (
+                "callbacks" in pl_trainer_kwargs.keys()
+                and len(pl_trainer_kwargs["callbacks"]) > 0
             ):
                 model.trainer_params["callbacks"] = [
                     checkpoint_callback

@solalatus (Contributor)

Ok, I am starting to become annoying. Sorry!

A suggested refinement:

As it is implemented now, if one does not give any new lr_scheduler_cls to load_from_checkpoint, the loaded scheduler does not get changed. So far so good.

But in my use case I want to fine-tune my model, so I definitely want to get rid of the scheduler it was loaded with. Passing an explicit None does not help.

Suggestion: add the ability to pass something like the string "delete", so the scheduler can be removed altogether.

What do you think, @madtoinou? Any better solutions?

@solalatus (Contributor)

Additional observation: if I used RAdam during training and want to switch to SGD after reloading, I get a KeyError about momentum. Cause: currently the optimizer's param dict is not overwritten but appended to, so whatever I put in load_from_checkpoint, I cannot get rid of the momentum param. This is the same type of issue as in my previous comment. Maybe plainly overwriting things by default would be a better idea?

@madtoinou (Collaborator, Author)

After discussing offline with some other contributors, I decided to refactor the feature: instead of reloading all the attributes of the model present in the checkpoint (which are relevant only when resuming training, not really for retraining/inference), the user will have to instantiate a new model and then load only the weights from the checkpoint (the method takes care of running some sanity checks and of initializing the model without having to call fit_from_dataset). This should prevent the issue of "incoherent/non-updated" attributes (looking at you, learning rate scheduler and optimizer parameters), and the user will have full control over what is in the model.

It's almost done; I am currently writing the tests and making sure that I am not overlooking a corner case. I'll try to update the PR in the next few days.
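
A hedged sketch of the refactored workflow described above, reusing NHiTSModel and train_ts from the earlier example; the method name load_weights_from_checkpoint follows the later discussion in this PR, and the exact signature may differ:

import torch

# Instantiate a fresh model with freely chosen training settings; nothing is
# inherited from the checkpoint except the network weights loaded below.
model_finetune = NHiTSModel(
    input_chunk_length=2,
    output_chunk_length=2,
    n_epochs=10,
    optimizer_cls=torch.optim.SGD,                   # e.g. switch RAdam -> SGD
    optimizer_kwargs={"lr": 1e-3, "momentum": 0.9},
    lr_scheduler_cls=None,                           # drop the original scheduler
)
# transfer only the network weights from the saved checkpoint
model_finetune.load_weights_from_checkpoint(model_name="original_model", best=False)
model_finetune.fit(train_ts)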

@solalatus
Copy link
Contributor

Cool! 🎉

I tested the branch a lot, and I think yours is the right way to go. Too many combinations of crazy stuff can happen.

Looking forward to the new version! I will test it as soon as I can!

@solalatus (Contributor)

One remark: This may inform your design decisions.

There would be some benefit in a "setup for fit"-like functionality over and beyond just fit itself.

@madtoinou (Collaborator, Author)

Since it is also related to the original purpose of this PR, I also included a fix for #1561.

Now, if a model is saved directly after loading it from a checkpoint, the original ".ckpt" checkpoint is duplicated so that this model can easily be loaded from the path given to save().

I also added a bit of logic in the load() method: if no ".ckpt" can be found, the attribute _fit_called is set to False, so that if the user tries to call predict, the correct exception is raised (instead of some errors about set_predict_parameters).
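
A hedged illustration of the two behaviors described above; model and file names are placeholders, and NHiTSModel is reused from the earlier example:

# Save the model manually right after loading it from a checkpoint; the
# original .ckpt is copied next to the new .pt file so the model can be
# reloaded directly from the path given to save().
model = NHiTSModel.load_from_checkpoint(model_name="original_model", best=False)
model.save("finetuned_model.pt")                   # also copies the .ckpt alongside
model_reloaded = NHiTSModel.load("finetuned_model.pt")
# If no ".ckpt" is found next to the saved file, _fit_called is set to False,
# so calling predict() raises the usual "model has not been fit" error.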

@dennisbader (Collaborator) left a comment

Thanks again for all the updates and for fixing the other issues 👍
I had some last minor suggestions, and a fix for loading the models with the correct dtype, so that identical weights are loaded and the original and loaded models produce identical forecasts.

darts/models/forecasting/torch_forecasting_model.py (outdated; resolved)
darts/models/forecasting/forecasting_model.py (resolved)
darts/models/forecasting/torch_forecasting_model.py (outdated; resolved)
ckpt_hyper_params = ckpt["hyper_parameters"]

# verify that the arguments passed to the constructor match those of the checkpoint
for param_key, param_value in self.model_params.items():
Collaborator

True, but on the other hand this requires all TorchForecastingModels and their corresponding PLForecastingModules to share the same model parameter names, which is not the case as you mention (and might be difficult to enforce in some cases).

So the torch error can still be raised, or maybe I'm missing something :)

@dennisbader (Collaborator) left a comment

Really great work @madtoinou, thanks for that!

Just one last change (that I missed earlier, sorry) and then it's ready to merge!

darts/models/forecasting/torch_forecasting_model.py (outdated; resolved)

Parameters
----------
path
Collaborator

Sorry, I just saw this now. I think we should use the path that the user gave when manually saving the model, i.e. model.save("my_model.pt"), rather than the .ckpt path.

Then we just replace ".pt" with ".pt.ckpt" and get the checkpoint from there. Check there that the ckpt exists, similar to how it is done now in TorchForecastingModel.load().

@madtoinou (Collaborator, Author) commented Feb 19, 2023

Good catch, it makes the interface much more consistent and intuitive.

load_weights() now expects the .pt path and the .ckpt suffix is added afterward, inside the function.
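
A small hedged example of the convention agreed on here; file names are placeholders, and model/NHiTSModel refer to the earlier example:

model.save("my_model.pt")                  # writes my_model.pt and my_model.pt.ckpt
new_model = NHiTSModel(input_chunk_length=2, output_chunk_length=2)
new_model.load_weights("my_model.pt")      # looks for my_model.pt.ckpt internally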

@dennisbader (Collaborator) left a comment

Thanks for this great PR, everything looks good now 🚀
Feel free to merge!

@madtoinou madtoinou moved this from In progress to Done in darts Feb 21, 2023
@madtoinou madtoinou merged commit f0ea7e5 into master Feb 21, 2023
@madtoinou madtoinou deleted the feat/improved-training-from-ckpt branch February 21, 2023 14:40
@madtoinou madtoinou mentioned this pull request Mar 23, 2023
@dennisbader dennisbader moved this from Done to Released in darts May 23, 2023
alexcolpitts96 pushed a commit to alexcolpitts96/darts that referenced this pull request May 31, 2023
* feat: new function fit_from_checkpoint that loads one checkpoint from the model, allows the user to change the optimizer, scheduler or trainer, and exports the ckpt of this fine-tuned model into another folder. Fine-tuning cannot be chained using this method (the original model ckpt must be reloaded)

* fix: improved the model saving to allow chaining of fine-tuning, better control over the logger, made the function static

* feat: allow to save the checkpoint in the same folder (loaded checkpoint is likely to be overwritten if the model is trained with default parameters)

* fix: ordered arguments in a more intuitive way

* fix: saving the model after updating all the parameters, to facilitate chained fine-tuning

* feat: support for load_from_checkpoint kwargs, support for force_reset argument

* feat: adding test for setup_finetuning

* fix: fused the setup_finetuning and load_from_checkpoint methods, added docstring, updated tests

* fix: changed the API/approach; instead of trying to overwrite attributes of an existing model, rather load the weights into a new model (but not the other attributes such as the optimizer, trainer, ...)

* fix: conversion of hyper-parameters to list when checking compatibility between checkpoint and instantiated model

* fix: skip the None attribute during the hp check

* fix: removed unnecessary attribute initialization

* feat: pl_forecasting_module also saves the train_sample in the checkpoints

* fix: saving only shape instead of the sample itself

* fix: restore the self.train_sample in TorchForecastingModel

* fix: update fit_called attribute to enable inference without retraining

* fix: the mock train_sample must be converted to tuple

* fix: tweaked model parameters to improve convergence

* fix: increased number of epochs to improve convergence/test stability

* fix: addressing review comments; added load_weights method and corresponding tests, updated documentation

* fix: changed default checkpoint path name for compatibility with Windows OS

* feat: raise error if the checkpoint being loaded does not contain the train_sample_shape entry, to make the break more transparent to users

* fix: saving the model manually directly after loading it from a checkpoint will retrieve and copy the original .ckpt file to avoid unexpected behaviors

* fix: use random_state to fix randomness in tests

* fix: restore newlines

* fix: casting dtype of PLModule before loading the weights

* doc: model_name docstring and code were not consistent

* doc: improve phrasing

* Apply suggestions from code review

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>

* fix: removed warning in saving about trainer/ckpt not being found, warning will be raised in the load() call if no weights can be loaded

* fix: made the filename convention uniform, using '_' to separate hours, minutes and seconds; updated the doc accordingly

* fix: removed typo

* Update darts/models/forecasting/torch_forecasting_model.py

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>

* fix: more consistent use of the path argument during save and load

---------

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
Labels: none
Projects: darts (Released)
5 participants