Commit

determine best checkpoint using mAP instead of val_loss
tlpss committed Mar 8, 2024
1 parent 21bf1e7 commit 3cbca83
Showing 2 changed files with 9 additions and 3 deletions.
keypoint_detection/models/detector.py: 6 additions & 0 deletions
@@ -159,6 +159,8 @@ def __init__(
         # this is for later reference (e.g. checkpoint loading) and consistency.
         self.save_hyperparameters(ignore=["**kwargs", "backbone"])
 
+        self._most_recent_val_mean_ap = 0.0  # used to store the most recent validation mean AP and log it every epoch, so that the checkpoint can be chosen based on it.
+
     def forward(self, x: torch.Tensor):
         """
         x shape must be of shape (N,3,H,W)
@@ -386,6 +388,9 @@ def log_and_reset_mean_ap(self, mode: str):
         self.log(f"{mode}/meanAP", mean_ap)
         self.log(f"{mode}/meanAP/meanAP", mean_ap)
 
+        if mode == "validation":
+            self._most_recent_val_mean_ap = mean_ap
+
     def training_epoch_end(self, outputs):
         """
         Called on the end of a training epoch.
@@ -401,6 +406,7 @@ def validation_epoch_end(self, outputs):
         """
         if self.is_ap_epoch():
             self.log_and_reset_mean_ap("validation")
+        self.log("checkpointing_metrics/valmeanAP", self._most_recent_val_mean_ap)
 
     def test_epoch_end(self, outputs):
         """
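For readers skimming the diff above: the detector.py change caches the most recent validation mean AP (which is only computed on selected epochs because it is expensive) and re-logs it on every validation epoch under checkpointing_metrics/valmeanAP, so the checkpoint callback always has a value to monitor. Below is a minimal sketch of that caching pattern; it is not the repository's code, and the class name, ap_epoch_frequency and compute_mean_ap are illustrative placeholders (the pre-2.0 Lightning validation_epoch_end hook is assumed, matching the code above).

# Minimal sketch of the "cache the expensive metric, re-log it every epoch" pattern.
# Names other than is_ap_epoch / validation_epoch_end are illustrative, not repository code.
import pytorch_lightning as pl
import torch


class CachedMeanAPModule(pl.LightningModule):
    def __init__(self, ap_epoch_frequency: int = 10):
        super().__init__()
        self.ap_epoch_frequency = ap_epoch_frequency
        self._most_recent_val_mean_ap = 0.0  # cache of the last computed validation mean AP

    def is_ap_epoch(self) -> bool:
        # only compute the expensive metric every ap_epoch_frequency epochs
        return self.current_epoch % self.ap_epoch_frequency == 0

    def compute_mean_ap(self) -> float:
        # placeholder for the real (expensive) mean AP computation
        return torch.rand(1).item()

    def validation_epoch_end(self, outputs):
        # (training_step, validation_step and configure_optimizers omitted for brevity)
        if self.is_ap_epoch():
            self._most_recent_val_mean_ap = self.compute_mean_ap()
            self.log("validation/meanAP", self._most_recent_val_mean_ap)
        # always log the cached value so ModelCheckpoint has something to monitor every epoch
        self.log("checkpointing_metrics/valmeanAP", self._most_recent_val_mean_ap)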
keypoint_detection/tasks/train_utils.py: 3 additions & 3 deletions
@@ -82,13 +82,13 @@ def create_pl_trainer(hparams: dict, wandb_logger: WandbLogger) -> Trainer:
     )
     # cf https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.loggers.wandb.html
 
-    # would be better to use mAP metric for checkpointing, but this is not calculated every epoch because it is rather expensive
-    # epoch_loss still correlates rather well though
+    # would be better to use mAP metric for checkpointing, but this is not calculated every epoch
+    # so I manually log the last known value to make the callback happy.
     # only store the best checkpoint and only the weights
     # so cannot be used to resume training but only for inference
     # saves storage though and training the detector is usually cheap enough to retrain it from scratch if you need specific weights etc.
     checkpoint_callback = ModelCheckpoint(
-        monitor="validation/epoch_loss", mode="min", save_weights_only=True, save_top_k=1
+        monitor="checkpointing_metrics/valmeanAP", mode="max", save_weights_only=True, save_top_k=1
     )
 
     trainer = pl.Trainer(**trainer_kwargs, callbacks=[early_stopping, checkpoint_callback])
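With the metric now logged on every validation epoch, the ModelCheckpoint callback in train_utils.py can monitor it directly with mode="max" instead of minimizing validation/epoch_loss. Below is a rough sketch of how the callback is wired into a Trainer and how the resulting weights-only checkpoint would later be used; the model class name, Trainer arguments and inference lines are assumptions, and because save_weights_only=True the checkpoint can only be loaded for inference, not to resume training.

# Sketch of the checkpointing setup and of loading the best checkpoint afterwards.
# The callback arguments mirror the diff above; KeypointDetector and the Trainer
# arguments are assumptions for illustration.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    monitor="checkpointing_metrics/valmeanAP",  # the value re-logged on every validation epoch
    mode="max",  # higher mean AP is better (validation/epoch_loss used mode="min")
    save_weights_only=True,  # weights only: smaller checkpoints, inference-only
    save_top_k=1,  # keep only the single best checkpoint
)

trainer = pl.Trainer(max_epochs=50, callbacks=[checkpoint_callback])
# trainer.fit(model, datamodule=datamodule)

# after training, load the best weights for inference (assumed class name):
# best_model = KeypointDetector.load_from_checkpoint(checkpoint_callback.best_model_path)
# best_model.eval()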
