Description
I'm using ignite 0.2.1, similarly to the transfer-learning-conv-ai repo by Hugging Face. In these lines, you can see that (sketched below for reference):
- the checkpoint is being saved for every epoch
- just the last three saved checkpoints are being retained on disk
- the last checkpoint (due to `_saved[-1]`) is being renamed to be the final trained model
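As far as I can tell, the relevant part of that repo boils down to something like this (`log_dir` and the final filename are placeholders I've abbreviated):

```python
import os
from ignite.engine import Events
from ignite.handlers import ModelCheckpoint

# save a checkpoint every epoch, keeping only the last 3 on disk
checkpoint_handler = ModelCheckpoint(log_dir, 'checkpoint', save_interval=1, n_saved=3)
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                          {'mymodel': getattr(model, 'module', model)})

# after training, the newest checkpoint (_saved[-1]) is renamed to be the
# final model; each _saved entry is a (priority, [file paths]) tuple in 0.2.1
os.rename(checkpoint_handler._saved[-1][1][-1], os.path.join(log_dir, 'pytorch_model.bin'))
```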
In my code, I'm additionally using the `EarlyStopping` class with a configurable patience, like this:
```python
from ignite.engine import Events
from ignite.handlers import EarlyStopping

valid_es_handler = EarlyStopping(patience=args.patience,
                                 score_function=early_stopping_score_function, trainer=trainer)
validator.add_event_handler(Events.COMPLETED, valid_es_handler)
```
Now, what I want to accomplish is this: identify and rename the best trained model (in terms of validation-set score) from the window of stored checkpoints.
I think the first change needed is `n_saved=args.patience` instead of `n_saved=3`, so that the window of saved checkpoints is equal to the patience used for early stopping. Consequently, it looks like I need to provide the same `early_stopping_score_function` to `ModelCheckpoint` as well, via its `score_function` arg, which would create a score-based priority queue of saved checkpoints.
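To make the question concrete, here's a sketch of the change I have in mind (the directory and filename prefix are placeholders, and I'm assuming `score_function` takes the place of `save_interval`, since, as far as I can tell, passing both isn't allowed):

```python
from ignite.handlers import ModelCheckpoint

# keep the args.patience best checkpoints, ranked by the same score
# that drives early stopping (higher score = better)
checkpoint_handler = ModelCheckpoint('./checkpoints', 'checkpoint',
                                     score_function=early_stopping_score_function,
                                     n_saved=args.patience)
```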
And with those changes, it looks like `_saved[-1]` would still point to the "best" model checkpoint in the window. Is my understanding of the changes correct?
Also, I haven't looked at the newer versions of ignite after 0.2.1, but could you please share what the breaking changes are (using the code linked above as an example)? I might consider upgrading to the latest ignite if the changes needed are minimal.
The other thing I don't understand is this: the score function would be called on the `engine`, but for our use case, this `engine` should be the `validator` (for both `EarlyStopping` and `ModelCheckpoint`), right?
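For reference, my score function looks roughly like this (the metric name is just an example from my setup):

```python
def early_stopping_score_function(engine):
    # `engine` needs to be the validator here, since the validation
    # metric lives in the validator's state, not the trainer's
    val_loss = engine.state.metrics['nll']  # example metric name
    return -val_loss  # higher return value = better, for both handlers
```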
But this line in the transfer-learning-conv-ai repo:
```python
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                          {'mymodel': getattr(model, 'module', model)})  # "getattr" takes care of distributed encapsulation
```
will end up making the score function call on the `trainer` Engine, if I understand correctly. How do I ensure that the `validator` is used for the score function in the `checkpoint_handler`?
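For instance, would it be enough to attach the checkpoint handler to the `validator` instead, like this (just a guess on my part)?

```python
# hypothetical: attach to the validator so that its engine is the one
# passed to score_function when the checkpoint handler fires
validator.add_event_handler(Events.COMPLETED, checkpoint_handler,
                            {'mymodel': getattr(model, 'module', model)})
```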