diff --git a/keypoint_detection/models/detector.py b/keypoint_detection/models/detector.py index 8e3ed59..f62b62a 100644 --- a/keypoint_detection/models/detector.py +++ b/keypoint_detection/models/detector.py @@ -315,7 +315,7 @@ def visualize_predicted_keypoints(self, result_dict): # get the keypoints from the heatmaps predicted_heatmaps = predicted_heatmaps.detach().float() predicted_keypoints = get_keypoints_from_heatmap_batch_maxpool( - predicted_heatmaps, self.max_keypoints, self.minimal_keypoint_pixel_distance, abs_max_threshold=0.2 + predicted_heatmaps, self.max_keypoints, self.minimal_keypoint_pixel_distance, abs_max_threshold=0.1 ) # overlay the images with the keypoints grid = visualize_predicted_keypoints(images, predicted_keypoints, self.keypoint_channel_configuration) diff --git a/keypoint_detection/utils/visualization.py b/keypoint_detection/utils/visualization.py index 0bca3af..b547513 100644 --- a/keypoint_detection/utils/visualization.py +++ b/keypoint_detection/utils/visualization.py @@ -20,7 +20,10 @@ def get_logging_label_from_channel_configuration(channel_configuration: List[Lis channel_name = f"{channel_configuration[0]}+{channel_configuration[1]}+..." channel_name_short = (channel_name[:40] + "...") if len(channel_name) > 40 else channel_name - label = f"{channel_name_short}_{mode}" + if mode != "": + label = f"{channel_name_short}_{mode}" + else: + label = channel_name_short return label @@ -108,7 +111,7 @@ def draw_keypoints_on_image( draw.text( (10, channel_idx * 10 * scale), - get_logging_label_from_channel_configuration(channel_configuration[channel_idx], "").split("_")[0], + get_logging_label_from_channel_configuration(channel_configuration[channel_idx], ""), fill=color_pool[channel_idx], font=ImageFont.truetype("FreeMono.ttf", size=10 * scale), )