Skip to content

Commit

Permalink
calculate & log AP for training
Browse files Browse the repository at this point in the history
  • Loading branch information
tlpss committed Aug 30, 2023
1 parent d73dbaf commit 65be243
Showing 1 changed file with 30 additions and 6 deletions.
36 changes: 30 additions & 6 deletions keypoint_detection/models/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ def __init__(
]
self.maximal_gt_keypoint_pixel_distances = maximal_gt_keypoint_pixel_distances

self.ap_training_metrics = [
KeypointAPMetrics(self.maximal_gt_keypoint_pixel_distances) for _ in self.keypoint_channel_configuration
]
self.ap_validation_metrics = [
KeypointAPMetrics(self.maximal_gt_keypoint_pixel_distances) for _ in self.keypoint_channel_configuration
]
Expand Down Expand Up @@ -251,8 +254,17 @@ def shared_step(self, batch, batch_idx, include_visualization_data_in_result_dic

def training_step(self, train_batch, batch_idx):
log_images = batch_idx == 0 and self.current_epoch > 0
should_log_ap = (
self.is_ap_epoch()
) # and batch_idx < 20 # limit AP calculation to first 20 batches to save time
include_vis_data = log_images or should_log_ap

result_dict = self.shared_step(
train_batch, batch_idx, include_visualization_data_in_result_dict=include_vis_data
)

result_dict = self.shared_step(train_batch, batch_idx, include_visualization_data_in_result_dict=log_images)
if should_log_ap:
self.update_ap_metrics(result_dict, self.ap_training_metrics)

if log_images:
image_grids = self.visualize_predictions_channels(result_dict)
def log_and_reset_mean_ap(self, mode: str):
    """Compute, log, and reset the AP metrics for the given mode.

    For each keypoint channel, logs the AP at every configured maximal
    ground-truth pixel distance, then logs the mAP over all channels per
    threshold, and finally the overall mAP over all channels and thresholds.

    Args:
        mode: one of "train", "test" or "validation"-style modes; used both to
            select the metric collection and as the logging prefix.
    """
    mean_ap_per_threshold = torch.zeros(len(self.maximal_gt_keypoint_pixel_distances))

    # BUG FIX: "train" previously fell through to the validation metrics, so the
    # AP accumulated in self.ap_training_metrics during training_step was never
    # read (and the validation metrics were reset mid-cycle). Select the
    # collection that actually matches the requested mode.
    if mode == "train":
        metrics = self.ap_training_metrics
    elif mode == "test":
        metrics = self.ap_test_metrics
    else:
        metrics = self.ap_validation_metrics

    # calculate APs for each channel and each threshold distance, and log them
    print(f" # {mode} metrics:")
    for channel_idx, channel_name in enumerate(self.keypoint_channel_configuration):
        channel_aps = self.compute_and_log_metrics_for_channel(metrics[channel_idx], channel_name, mode)
        mean_ap_per_threshold += torch.tensor(channel_aps)

    # calculate the mAP over all channels for each threshold distance, and log them
    for i, maximal_distance in enumerate(self.maximal_gt_keypoint_pixel_distances):
        self.log(
            f"{mode}/meanAP/d={float(maximal_distance):.1f}",
            mean_ap_per_threshold[i] / len(self.keypoint_channel_configuration),
        )

    # calculate the mAP over all channels and all threshold distances, and log it
    mean_ap = mean_ap_per_threshold.mean() / len(self.keypoint_channel_configuration)
    self.log(f"{mode}/meanAP", mean_ap)
    self.log(f"{mode}/meanAP/meanAP", mean_ap)

def training_epoch_end(self, outputs):
    """Lightning hook invoked after every training epoch.

    On epochs selected for AP evaluation, computes, logs and resets the
    accumulated training AP metrics; otherwise does nothing.
    """
    if not self.is_ap_epoch():
        return
    self.log_and_reset_mean_ap("train")

def validation_epoch_end(self, outputs):
"""
Called on the end of a validation epoch.
def compute_and_log_metrics_for_channel(
    self, metrics: KeypointAPMetrics, channel: str, training_mode: str
) -> List[float]:
    """Log the AP of one channel's predictions for every threshold distance.

    Prints a rounded per-threshold summary, logs each AP and the channel's
    mean AP (top level, for the wandb hyperparameter chart), resets the
    metric object, and returns the per-threshold AP values.
    """
    # NOTE: the name `ap_metrics` is part of the printed output below
    # (f-string `=` debug specifier), so it must not be renamed.
    ap_metrics = metrics.compute()
    print(f"{ap_metrics=}")

    # human-readable console summary with rounded values
    rounded = {threshold: round(ap, 3) for threshold, ap in ap_metrics.items()}
    print(f"{channel} : {rounded}")

    for threshold, ap in ap_metrics.items():
        self.log(f"{training_mode}/{channel}_ap/d={float(threshold):.1f}", ap)

    ap_values = list(ap_metrics.values())
    # log top level for wandb hyperparam chart.
    self.log(f"{training_mode}/{channel}_ap/meanAP", sum(ap_values) / len(ap_values))

    metrics.reset()
    return ap_values

Expand Down

0 comments on commit 65be243

Please sign in to comment.