Commit: reduce amount of logged data

tlpss committed Aug 25, 2023
1 parent 20c3f51 commit 2c95d65
Showing 2 changed files with 19 additions and 10 deletions.
keypoint_detection/models/detector.py: 25 changes (16 additions, 9 deletions)
@@ -256,6 +256,7 @@ def shared_step(self, batch, batch_idx, include_visualization_data_in_result_dict
 
     def training_step(self, train_batch, batch_idx):
         log_images = batch_idx == 0 and self.current_epoch > 0
+
         result_dict = self.shared_step(train_batch, batch_idx, include_visualization_data_in_result_dict=log_images)
 
         if log_images:
@@ -321,10 +322,10 @@ def validation_step(self, val_batch, batch_idx):
         if self.is_ap_epoch():
             self.update_ap_metrics(result_dict, self.ap_validation_metrics)
 
-        log_images = batch_idx == 0 and self.current_epoch > 0
-        if log_images:
-            image_grids = self.visualize_predictions_channels(result_dict)
-            self.log_image_grids(image_grids, mode="validation")
+            log_images = batch_idx == 0 and self.current_epoch > 0
+            if log_images:
+                image_grids = self.visualize_predictions_channels(result_dict)
+                self.log_image_grids(image_grids, mode="validation")
 
         ## log (defaults to on_epoch, which aggregates the logged values over entire validation set)
         self.log("validation/epoch_loss", result_dict["loss"])
@@ -334,8 +335,10 @@ def test_step(self, test_batch, batch_idx):
         # no need to switch model to eval mode, this is handled by pytorch lightning
         result_dict = self.shared_step(test_batch, batch_idx, include_visualization_data_in_result_dict=True)
         self.update_ap_metrics(result_dict, self.ap_test_metrics)
-        image_grids = self.visualize_predictions_channels(result_dict)
-        self.log_image_grids(image_grids, mode="test")
+        # only log first 10 batches to reduce storage space
+        if batch_idx < 10:
+            image_grids = self.visualize_predictions_channels(result_dict)
+            self.log_image_grids(image_grids, mode="test")
         self.log("test/epoch_loss", result_dict["loss"])
         self.log("test/gt_loss", result_dict["gt_loss"])
 
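To gauge the saving from the batch_idx < 10 cap, a back-of-the-envelope sketch (dataset and batch sizes are made up):

num_test_images, batch_size = 1000, 32           # hypothetical test set
num_batches = -(-num_test_images // batch_size)  # ceiling division: 32 batches
grids_before = num_batches                       # previously: one image grid logged per batch
grids_after = min(num_batches, 10)               # now: capped at the first 10 batches
print(grids_before, grids_after)                 # 32 10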
@@ -405,9 +408,13 @@ def compute_and_log_metrics_for_channel(
 
     def is_ap_epoch(self) -> bool:
         """Returns True if the AP should be calculated in this epoch."""
-        return (
-            self.ap_epoch_start <= self.current_epoch and self.current_epoch % self.ap_epoch_freq == 0
-        ) or self.current_epoch == self.trainer.max_epochs - 1
+        is_ap_epoch = self.ap_epoch_start <= self.current_epoch and self.current_epoch % self.ap_epoch_freq == 0
+        # always compute the AP in the last epoch
+        is_ap_epoch = is_ap_epoch or self.current_epoch == self.trainer.max_epochs - 1
+
+        # if the user manually specified a validation frequency (> 1), compute the AP on every validated epoch after the first
+        is_ap_epoch = is_ap_epoch or (self.current_epoch > 0 and self.trainer.check_val_every_n_epoch > 1)
+        return is_ap_epoch
 
     def extract_detected_keypoints_from_heatmap(self, heatmap: torch.Tensor) -> List[DetectedKeypoint]:
         """
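The reworked schedule can be simulated in isolation; a minimal sketch with made-up hyperparameter values:

# standalone simulation of the is_ap_epoch() logic above (all values hypothetical)
ap_epoch_start, ap_epoch_freq = 10, 5
max_epochs, check_val_every_n_epoch = 31, 1

def is_ap_epoch(epoch: int) -> bool:
    result = ap_epoch_start <= epoch and epoch % ap_epoch_freq == 0
    result = result or epoch == max_epochs - 1                      # always the last epoch
    result = result or (epoch > 0 and check_val_every_n_epoch > 1)  # every epoch if validation is infrequent
    return result

print([e for e in range(max_epochs) if is_ap_epoch(e)])  # [10, 15, 20, 25, 30]

With check_val_every_n_epoch > 1, every epoch after the first qualifies; since Lightning then only runs validation every n epochs, the AP is effectively computed at each validation run.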
keypoint_detection/train/train.py: 4 changes (3 additions, 1 deletion)
@@ -84,7 +84,9 @@ def main(hparams: dict) -> Tuple[KeypointDetector, pl.Trainer]:
         project=hparams["wandb_project"],
         entity=hparams["wandb_entity"],
         save_dir=get_wandb_log_dir_path(),
-        log_model="all",  # log all checkpoints made by PL, see create_trainer for callback
+        log_model=True,  # only log checkpoints at the end of training, i.e. only the best checkpoint
+        # not suitable for expensive training runs where you might want to restart from a checkpoint,
+        # but it saves storage, and keypoint detector training runs are usually not that expensive anyway
     )
     trainer = create_pl_trainer(hparams, wandb_logger)
     trainer.fit(model, data_module)
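For context on the log_model switch: in Lightning's WandbLogger, log_model="all" uploads every checkpoint the ModelCheckpoint callback writes, log_model=True uploads checkpoints once at the end of training, and log_model=False uploads none. A minimal usage sketch (project name is illustrative):

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

# True -> checkpoints go to W&B only when training finishes,
# trading mid-run restartability for less artifact storage
wandb_logger = WandbLogger(project="keypoint-detection", log_model=True)
trainer = pl.Trainer(max_epochs=10, logger=wandb_logger)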
