diff --git a/keypoint_detection/train/train.py b/keypoint_detection/train/train.py
index 2448859..495a60a 100644
--- a/keypoint_detection/train/train.py
+++ b/keypoint_detection/train/train.py
@@ -90,7 +90,15 @@ def main(hparams: dict) -> Tuple[KeypointDetector, pl.Trainer]:
     trainer.fit(model, data_module)
 
     if "json_test_dataset_path" in hparams:
-        trainer.test(model, data_module)
+        # check whether a best checkpoint exists; if not, fall back to the current weights but log a warning.
+        # evaluating on the best checkpoint (i.e. the weights with the best validation score) makes more sense:
+        # evaluating on the final weights is noisier and would also yield lower test scores if the model starts
+        # overfitting when training longer, even with perfectly i.i.d. test/val sets. This is not desired.
+
+        ckpt_path = trainer.checkpoint_callback.best_model_path
+        if ckpt_path == "":
+            print("No best checkpoint found, using current weights for test set evaluation")
+        trainer.test(model, data_module, ckpt_path="best")
 
     return model, trainer
 
diff --git a/keypoint_detection/train/utils.py b/keypoint_detection/train/utils.py
index 22ea20c..5e62718 100644
--- a/keypoint_detection/train/utils.py
+++ b/keypoint_detection/train/utils.py
@@ -82,6 +82,11 @@ def create_pl_trainer(hparams: dict, wandb_logger: WandbLogger) -> Trainer:
     )
     # cf https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.loggers.wandb.html
 
+    # it would be better to monitor the mAP metric for checkpointing, but it is not calculated every epoch
+    # because it is rather expensive (mainly due to the keypoint extraction from the heatmaps).
+    # TODO: make this extraction faster by doing it on the GPU?
+
+    # the validation epoch_loss still correlates rather well with mAP though.
     checkpoint_callback = ModelCheckpoint(monitor="validation/epoch_loss", mode="min")
 
     trainer = pl.Trainer(**trainer_kwargs, callbacks=[early_stopping, checkpoint_callback])
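
For context, a minimal standalone sketch of the PyTorch Lightning pattern this patch relies on: a ModelCheckpoint callback tracks the lowest "validation/epoch_loss", best_model_path is an empty string when no checkpoint was ever saved, and trainer.test(..., ckpt_path="best") reloads the best checkpoint before evaluating. The ToyRegressor module and make_loader helper below are hypothetical stand-ins for the repo's KeypointDetector and its datamodule, and the snippet assumes a reasonably recent pytorch_lightning release; it is an illustration, not code from this repository.

# Minimal sketch; ToyRegressor and make_loader are hypothetical stand-ins, not repo code.
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, TensorDataset


class ToyRegressor(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def _loss(self, batch):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def training_step(self, batch, batch_idx):
        return self._loss(batch)

    def validation_step(self, batch, batch_idx):
        # logs the metric that the ModelCheckpoint callback monitors
        self.log("validation/epoch_loss", self._loss(batch), on_epoch=True)

    def test_step(self, batch, batch_idx):
        self.log("test/epoch_loss", self._loss(batch), on_epoch=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-2)


def make_loader():
    x = torch.randn(64, 8)
    y = torch.randn(64, 1)
    return DataLoader(TensorDataset(x, y), batch_size=16)


if __name__ == "__main__":
    # same monitor/mode as the patched create_pl_trainer
    checkpoint_callback = ModelCheckpoint(monitor="validation/epoch_loss", mode="min")
    trainer = pl.Trainer(max_epochs=3, callbacks=[checkpoint_callback], logger=False)

    model = ToyRegressor()
    trainer.fit(model, train_dataloaders=make_loader(), val_dataloaders=make_loader())

    # best_model_path is an empty string when no checkpoint was ever saved
    if trainer.checkpoint_callback.best_model_path == "":
        print("No best checkpoint found, using current weights for test set evaluation")
    # ckpt_path="best" restores the weights with the lowest validation/epoch_loss before testing
    trainer.test(model, dataloaders=make_loader(), ckpt_path="best")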