diff --git a/keypoint_detection/models/detector.py b/keypoint_detection/models/detector.py
index d557bb4..6994bdb 100644
--- a/keypoint_detection/models/detector.py
+++ b/keypoint_detection/models/detector.py
@@ -10,7 +10,11 @@
 from keypoint_detection.models.backbones.base_backbone import Backbone
 from keypoint_detection.models.metrics import DetectedKeypoint, Keypoint, KeypointAPMetrics
 from keypoint_detection.utils.heatmap import BCE_loss, create_heatmap_batch, get_keypoints_from_heatmap_batch_maxpool
-from keypoint_detection.utils.visualization import visualize_predicted_heatmaps
+from keypoint_detection.utils.visualization import (
+    get_logging_label_from_channel_configuration,
+    visualize_predicted_heatmaps,
+    visualize_predicted_keypoints,
+)
 
 
 class KeypointDetector(pl.LightningModule):
@@ -266,7 +270,7 @@ def training_step(self, train_batch, batch_idx):
 
         if log_images:
             image_grids = self.visualize_predictions_channels(result_dict)
-            self.log_image_grids(image_grids, mode="train")
+            self.log_channel_predictions_grids(image_grids, mode="train")
 
         for channel_name in self.keypoint_channel_configuration:
             self.log(f"train/{channel_name}", result_dict[f"{channel_name}_loss"])
@@ -299,26 +303,29 @@ def visualize_predictions_channels(self, result_dict):
             image_grids.append(grid)
         return image_grids
 
-    @staticmethod
-    def logging_label(channel_configuration, mode: str) -> str:
-        channel_name = channel_configuration
-
-        if isinstance(channel_configuration, list):
-            if len(channel_configuration) == 1:
-                channel_name = channel_configuration[0]
-            else:
-                channel_name = f"{channel_configuration[0]}+{channel_configuration[1]}+..."
-
-        channel_name_short = (channel_name[:40] + "...") if len(channel_name) > 40 else channel_name
-        label = f"{channel_name_short}_{mode}_keypoints"
-        return label
-
-    def log_image_grids(self, image_grids, mode: str):
+    def log_channel_predictions_grids(self, image_grids, mode: str):
         for channel_configuration, grid in zip(self.keypoint_channel_configuration, image_grids):
-            label = KeypointDetector.logging_label(channel_configuration, mode)
+            label = get_logging_label_from_channel_configuration(channel_configuration, mode)
             image_caption = "top: predicted heatmaps, bottom: gt heatmaps"
             self.logger.experiment.log({label: wandb.Image(grid, caption=image_caption)})
 
+    def visualize_predicted_keypoints(self, result_dict):
+        images = result_dict["input_images"]
+        predicted_heatmaps = result_dict["predicted_heatmaps"]
+        # get the keypoints from the heatmaps
+        predicted_heatmaps = predicted_heatmaps.detach().float()
+        predicted_keypoints = get_keypoints_from_heatmap_batch_maxpool(
+            predicted_heatmaps, self.max_keypoints, self.minimal_keypoint_pixel_distance, abs_max_threshold=0.4
+        )
+        # overlay the images with the keypoints
+        grid = visualize_predicted_keypoints(images, predicted_keypoints, self.keypoint_channel_configuration)
+        return grid
+
+    def log_predicted_keypoints(self, grid, mode: str):
+        label = f"predicted_keypoints_{mode}"
+        image_caption = "predicted keypoints"
+        self.logger.experiment.log({label: wandb.Image(grid, caption=image_caption)})
+
     def validation_step(self, val_batch, batch_idx):
         # no need to switch model to eval mode, this is handled by pytorch lightning
         result_dict = self.shared_step(val_batch, batch_idx, include_visualization_data_in_result_dict=True)
@@ -328,8 +335,11 @@ def validation_step(self, val_batch, batch_idx):
 
         log_images = batch_idx == 0 and self.current_epoch > 0 and self.is_ap_epoch()
         if log_images and isinstance(self.logger, pl.loggers.wandb.WandbLogger):
-            image_grids = self.visualize_predictions_channels(result_dict)
-            self.log_image_grids(image_grids, mode="validation")
+            channel_grids = self.visualize_predictions_channels(result_dict)
+            self.log_channel_predictions_grids(channel_grids, mode="validation")
+
+            keypoint_grids = self.visualize_predicted_keypoints(result_dict)
+            self.log_predicted_keypoints(keypoint_grids, mode="validation")
 
         ## log (defaults to on_epoch, which aggregates the logged values over entire validation set)
         self.log("validation/epoch_loss", result_dict["loss"])
@@ -342,7 +352,11 @@ def test_step(self, test_batch, batch_idx):
         # only log first 10 batches to reduce storage space
         if batch_idx < 10 and isinstance(self.logger, pl.loggers.wandb.WandbLogger):
             image_grids = self.visualize_predictions_channels(result_dict)
-            self.log_image_grids(image_grids, mode="test")
+            self.log_channel_predictions_grids(image_grids, mode="test")
+
+            keypoint_grids = self.visualize_predicted_keypoints(result_dict)
+            self.log_predicted_keypoints(keypoint_grids, mode="test")
+
         self.log("test/epoch_loss", result_dict["loss"])
         self.log("test/gt_loss", result_dict["gt_loss"])
diff --git a/keypoint_detection/tasks/inference.py b/keypoint_detection/tasks/inference.py
new file mode 100644
index 0000000..b6bae63
--- /dev/null
+++ b/keypoint_detection/tasks/inference.py
@@ -0,0 +1,41 @@
+""" run inference on a provided image and save the result to a file """
+
+import numpy as np
+import torch
+from PIL import Image
+
+from keypoint_detection.models.detector import KeypointDetector
+from keypoint_detection.utils.heatmap import get_keypoints_from_heatmap_batch_maxpool
+from keypoint_detection.utils.load_checkpoints import get_model_from_wandb_checkpoint
+from keypoint_detection.utils.visualization import draw_keypoints_on_image
+
+
+def run_inference(model: KeypointDetector, image, confidence_threshold: float = 0.1) -> Image:
+    model.eval()
+    tensored_image = torch.from_numpy(np.array(image)).float()
+    tensored_image = tensored_image / 255.0
+    tensored_image = tensored_image.permute(2, 0, 1)
+    tensored_image = tensored_image.unsqueeze(0)
+    with torch.no_grad():
+        heatmaps = model(tensored_image)
+
+    keypoints = get_keypoints_from_heatmap_batch_maxpool(heatmaps, abs_max_threshold=confidence_threshold)
+    image_keypoints = keypoints[0]
+    for keypoints, channel_config in zip(image_keypoints, model.keypoint_channel_configuration):
+        print(f"Keypoints for {channel_config}: {keypoints}")
+    image = draw_keypoints_on_image(image, image_keypoints, model.keypoint_channel_configuration)
+    return image
+
+
+if __name__ == "__main__":
+    wandb_checkpoint = "tlips/synthetic-lego-battery-keypoints/model-tbzd50z8:v0"
+    image_path = "/home/tlips/Downloads/Lego-battery-real/0.jpg"
+    # image_path = "/home/tlips/Documents/synthetic-cloth-data/synthetic-cloth-data/data/datasets/LEGO-battery/01/images/0.jpg"
+    image_size = (256, 256)
+
+    image = Image.open(image_path)
+    image = image.resize(image_size)
+
+    model = get_model_from_wandb_checkpoint(wandb_checkpoint)
+    image = run_inference(model, image)
+    image.save("inference_result.png")
diff --git a/keypoint_detection/utils/visualization.py b/keypoint_detection/utils/visualization.py
index ea17a1d..0bca3af 100644
--- a/keypoint_detection/utils/visualization.py
+++ b/keypoint_detection/utils/visualization.py
@@ -1,13 +1,29 @@
 from argparse import ArgumentParser
-from typing import List
+from typing import List, Tuple
 
+import numpy as np
 import torch
 import torchvision
 from matplotlib import cm
+from PIL import Image, ImageDraw, ImageFont
 
 from keypoint_detection.utils.heatmap import generate_channel_heatmap
 
 
+def get_logging_label_from_channel_configuration(channel_configuration: List[List[str]], mode: str) -> str:
+    channel_name = channel_configuration
+
+    if isinstance(channel_configuration, list):
+        if len(channel_configuration) == 1:
+            channel_name = channel_configuration[0]
+        else:
+            channel_name = f"{channel_configuration[0]}+{channel_configuration[1]}+..."
+
+    channel_name_short = (channel_name[:40] + "...") if len(channel_name) > 40 else channel_name
+    label = f"{channel_name_short}_{mode}"
+    return label
+
+
 def overlay_image_with_heatmap(images: torch.Tensor, heatmaps: torch.Tensor, alpha=0.5) -> torch.Tensor:
     """ """
     viridis = cm.get_cmap("viridis")
@@ -19,7 +35,22 @@ def overlay_image_with_heatmap(images: torch.Tensor, heatmaps: torch.Tensor, alp
     return overlayed_images
 
 
-def overlay_image_with_keypoints(images: torch.Tensor, keypoints: List[torch.Tensor], sigma: float) -> torch.Tensor:
+def visualize_predicted_heatmaps(
+    imgs: torch.Tensor,
+    predicted_heatmaps: torch.Tensor,
+    gt_heatmaps: torch.Tensor,
+):
+    num_images = min(predicted_heatmaps.shape[0], 6)
+
+    predicted_heatmap_overlays = overlay_image_with_heatmap(imgs[:num_images], predicted_heatmaps[:num_images])
+    gt_heatmap_overlays = overlay_image_with_heatmap(imgs[:num_images], gt_heatmaps[:num_images])
+
+    images = torch.cat([predicted_heatmap_overlays, gt_heatmap_overlays])
+    grid = torchvision.utils.make_grid(images, nrow=num_images)
+    return grid
+
+
+def overlay_images_with_keypoints(images: torch.Tensor, keypoints: List[torch.Tensor], sigma: float) -> torch.Tensor:
     """
     images N x 3 x H x W
     keypoints list of size N with Tensors C x 2
@@ -49,18 +80,58 @@ def overlay_image_with_keypoints(images: torch.Tensor, keypoints: List[torch.Ten
     return overlayed_images
 
 
-def visualize_predicted_heatmaps(
-    imgs: torch.Tensor,
-    predicted_heatmaps: torch.Tensor,
-    gt_heatmaps: torch.Tensor,
+def draw_keypoints_on_image(
+    image: Image, image_keypoints: List[List[Tuple[int, int]]], channel_configuration: List[List[str]]
+) -> Image:
+    """adds all keypoints to the PIL image, with different colors for each channel."""
+    color_pool = [
+        "#FF00FF",  # Neon Purple
+        "#00FF00",  # Electric Green
+        "#FFFF00",  # Cyber Yellow
+        "#0000FF",  # Laser Blue
+        "#FF0000",  # Radioactive Red
+        "#00FFFF",  # Galactic Teal
+        "#FF00AA",  # Quantum Pink
+        "#C0C0C0",  # Holographic Silver
+        "#000000",  # Abyssal Black
+        "#FFA500",  # Cosmic Orange
+    ]
+    image_size = image.size
+    min_size = min(image_size)
+    scale = 1 + (min_size // 256)
+
+    draw = ImageDraw.Draw(image)
+    for channel_idx, channel_keypoints in enumerate(image_keypoints):
+        for keypoint_idx, keypoint in enumerate(channel_keypoints):
+            u, v = keypoint
+            draw.ellipse((u - scale, v - scale, u + scale, v + scale), fill=color_pool[channel_idx])
+
+        draw.text(
+            (10, channel_idx * 10 * scale),
+            get_logging_label_from_channel_configuration(channel_configuration[channel_idx], "").split("_")[0],
+            fill=color_pool[channel_idx],
+            font=ImageFont.truetype("FreeMono.ttf", size=10 * scale),
+        )
+
+    return image
+
+
+def visualize_predicted_keypoints(
+    images: torch.Tensor, keypoints: List[List[List[List[int]]]], channel_configuration: List[List[str]]
 ):
-    num_images = min(predicted_heatmaps.shape[0], 6)
-
-    predicted_heatmap_overlays = overlay_image_with_heatmap(imgs[:num_images], predicted_heatmaps[:num_images])
-    gt_heatmap_overlays = overlay_image_with_heatmap(imgs[:num_images], gt_heatmaps[:num_images])
-
-    images = torch.cat([predicted_heatmap_overlays, gt_heatmap_overlays])
-    grid = torchvision.utils.make_grid(images, nrow=num_images)
+    drawn_images = []
+    num_images = min(images.shape[0], 6)
+    for i in range(num_images):
+        # PIL expects uint8 images
+        image = images[i].permute(1, 2, 0).numpy() * 255
+        image = image.astype(np.uint8)
+        image = Image.fromarray(image)
+        image = draw_keypoints_on_image(image, keypoints[i], channel_configuration)
+        drawn_images.append(image)
+
+    drawn_images = torch.stack([torch.from_numpy(np.array(image)).permute(2, 0, 1) / 255 for image in drawn_images])
+
+    grid = torchvision.utils.make_grid(drawn_images, nrow=num_images)
     return grid
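
A minimal usage sketch of the new inference helper in keypoint_detection/tasks/inference.py, assuming the wandb artifact reference, image path, and input resolution are placeholders you substitute with your own; run_inference expects an RGB PIL image resized to the resolution the model was trained on.

from PIL import Image

from keypoint_detection.tasks.inference import run_inference
from keypoint_detection.utils.load_checkpoints import get_model_from_wandb_checkpoint

# placeholder wandb artifact reference of the form "<entity>/<project>/model-<run_id>:v0"
model = get_model_from_wandb_checkpoint("<entity>/<project>/model-<run_id>:v0")

# placeholder image path; resize to the training resolution (256x256 assumed here)
image = Image.open("example.jpg").resize((256, 256))

# returns a PIL image with the detected keypoints drawn per channel
result = run_inference(model, image, confidence_threshold=0.1)
result.save("inference_result.png")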