From 3cbca831b972269b95113cb0fad0e79bf391acff Mon Sep 17 00:00:00 2001 From: tlpss Date: Fri, 8 Mar 2024 10:36:34 +0100 Subject: [PATCH] determine best checkpoint using mAP instead of val_loss --- keypoint_detection/models/detector.py | 6 ++++++ keypoint_detection/tasks/train_utils.py | 6 +++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/keypoint_detection/models/detector.py b/keypoint_detection/models/detector.py index 9dce633..4e5b4a1 100644 --- a/keypoint_detection/models/detector.py +++ b/keypoint_detection/models/detector.py @@ -159,6 +159,8 @@ def __init__( # this is for later reference (e.g. checkpoint loading) and consistency. self.save_hyperparameters(ignore=["**kwargs", "backbone"]) + self._most_recent_val_mean_ap = 0.0 # used to store the most recent validation mean AP and log it in each epoch, so that checkpoint can be chosen based on this one. + def forward(self, x: torch.Tensor): """ x shape must be of shape (N,3,H,W) """ @@ -386,6 +388,9 @@ def log_and_reset_mean_ap(self, mode: str): self.log(f"{mode}/meanAP", mean_ap) self.log(f"{mode}/meanAP/meanAP", mean_ap) + if mode == "validation": + self._most_recent_val_mean_ap = mean_ap + def training_epoch_end(self, outputs): """ Called on the end of a training epoch. 
@@ -401,6 +406,7 @@ def validation_epoch_end(self, outputs): """ if self.is_ap_epoch(): self.log_and_reset_mean_ap("validation") + self.log("checkpointing_metrics/valmeanAP", self._most_recent_val_mean_ap) def test_epoch_end(self, outputs): """ diff --git a/keypoint_detection/tasks/train_utils.py b/keypoint_detection/tasks/train_utils.py index 94d8027..c869fc2 100644 --- a/keypoint_detection/tasks/train_utils.py +++ b/keypoint_detection/tasks/train_utils.py @@ -82,13 +82,13 @@ def create_pl_trainer(hparams: dict, wandb_logger: WandbLogger) -> Trainer: ) # cf https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.loggers.wandb.html - # would be better to use mAP metric for checkpointing, but this is not calculated every epoch because it is rather expensive - # epoch_loss still correlates rather well though + # would be better to use mAP metric for checkpointing, but this is not calculated every epoch + # so I manually log the last known value to make the callback happy. # only store the best checkpoint and only the weights # so cannot be used to resume training but only for inference # saves storage though and training the detector is usually cheap enough to retrain it from scratch if you need specific weights etc. checkpoint_callback = ModelCheckpoint( - monitor="validation/epoch_loss", mode="min", save_weights_only=True, save_top_k=1 + monitor="checkpointing_metrics/valmeanAP", mode="max", save_weights_only=True, save_top_k=1 ) trainer = pl.Trainer(**trainer_kwargs, callbacks=[early_stopping, checkpoint_callback])