Skip to content

Commit

Permalink
visualize all keypoints on single image
Browse files Browse the repository at this point in the history
  • Loading branch information
tlpss committed Oct 10, 2023
1 parent d0cbb05 commit 9ffa9a2
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 34 deletions.
56 changes: 35 additions & 21 deletions keypoint_detection/models/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
from keypoint_detection.models.backbones.base_backbone import Backbone
from keypoint_detection.models.metrics import DetectedKeypoint, Keypoint, KeypointAPMetrics
from keypoint_detection.utils.heatmap import BCE_loss, create_heatmap_batch, get_keypoints_from_heatmap_batch_maxpool
from keypoint_detection.utils.visualization import visualize_predicted_heatmaps
from keypoint_detection.utils.visualization import (
get_logging_label_from_channel_configuration,
visualize_predicted_heatmaps,
visualize_predicted_keypoints,
)


class KeypointDetector(pl.LightningModule):
Expand Down Expand Up @@ -266,7 +270,7 @@ def training_step(self, train_batch, batch_idx):

if log_images:
image_grids = self.visualize_predictions_channels(result_dict)
self.log_image_grids(image_grids, mode="train")
self.log_channel_predictions_grids(image_grids, mode="train")

for channel_name in self.keypoint_channel_configuration:
self.log(f"train/{channel_name}", result_dict[f"{channel_name}_loss"])
Expand Down Expand Up @@ -299,26 +303,29 @@ def visualize_predictions_channels(self, result_dict):
image_grids.append(grid)
return image_grids

@staticmethod
def logging_label(channel_configuration, mode: str) -> str:
channel_name = channel_configuration

if isinstance(channel_configuration, list):
if len(channel_configuration) == 1:
channel_name = channel_configuration[0]
else:
channel_name = f"{channel_configuration[0]}+{channel_configuration[1]}+..."

channel_name_short = (channel_name[:40] + "...") if len(channel_name) > 40 else channel_name
label = f"{channel_name_short}_{mode}_keypoints"
return label

def log_image_grids(self, image_grids, mode: str):
def log_channel_predictions_grids(self, image_grids, mode: str):
for channel_configuration, grid in zip(self.keypoint_channel_configuration, image_grids):
label = KeypointDetector.logging_label(channel_configuration, mode)
label = get_logging_label_from_channel_configuration(channel_configuration, mode)
image_caption = "top: predicted heatmaps, bottom: gt heatmaps"
self.logger.experiment.log({label: wandb.Image(grid, caption=image_caption)})

def visualize_predicted_keypoints(self, result_dict):
images = result_dict["input_images"]
predicted_heatmaps = result_dict["predicted_heatmaps"]
# get the keypoints from the heatmaps
predicted_heatmaps = predicted_heatmaps.detach().float()
predicted_keypoints = get_keypoints_from_heatmap_batch_maxpool(
predicted_heatmaps, self.max_keypoints, self.minimal_keypoint_pixel_distance, abs_max_threshold=0.4
)
# overlay the images with the keypoints
grid = visualize_predicted_keypoints(images, predicted_keypoints, self.keypoint_channel_configuration)
return grid

def log_predicted_keypoints(self, grid, mode=str):
label = f"predicted_keypoints_{mode}"
image_caption = "predicted keypoints"
self.logger.experiment.log({label: wandb.Image(grid, caption=image_caption)})

def validation_step(self, val_batch, batch_idx):
# no need to switch model to eval mode, this is handled by pytorch lightning
result_dict = self.shared_step(val_batch, batch_idx, include_visualization_data_in_result_dict=True)
Expand All @@ -328,8 +335,11 @@ def validation_step(self, val_batch, batch_idx):

log_images = batch_idx == 0 and self.current_epoch > 0 and self.is_ap_epoch()
if log_images and isinstance(self.logger, pl.loggers.wandb.WandbLogger):
image_grids = self.visualize_predictions_channels(result_dict)
self.log_image_grids(image_grids, mode="validation")
channel_grids = self.visualize_predictions_channels(result_dict)
self.log_channel_predictions_grids(channel_grids, mode="validation")

keypoint_grids = self.visualize_predicted_keypoints(result_dict)
self.log_predicted_keypoints(keypoint_grids, mode="validation")

## log (defaults to on_epoch, which aggregates the logged values over entire validation set)
self.log("validation/epoch_loss", result_dict["loss"])
Expand All @@ -342,7 +352,11 @@ def test_step(self, test_batch, batch_idx):
# only log first 10 batches to reduce storage space
if batch_idx < 10 and isinstance(self.logger, pl.loggers.wandb.WandbLogger):
image_grids = self.visualize_predictions_channels(result_dict)
self.log_image_grids(image_grids, mode="test")
self.log_channel_predictions_grids(image_grids, mode="test")

keypoint_grids = self.visualize_predicted_keypoints(result_dict)
self.log_predicted_keypoints(keypoint_grids, mode="validation")

self.log("test/epoch_loss", result_dict["loss"])
self.log("test/gt_loss", result_dict["gt_loss"])

Expand Down
41 changes: 41 additions & 0 deletions keypoint_detection/tasks/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
""" run inference on a provided image and save the result to a file """

import numpy as np
import torch
from PIL import Image

from keypoint_detection.models.detector import KeypointDetector
from keypoint_detection.utils.heatmap import get_keypoints_from_heatmap_batch_maxpool
from keypoint_detection.utils.load_checkpoints import get_model_from_wandb_checkpoint
from keypoint_detection.utils.visualization import draw_keypoints_on_image


def run_inference(model: KeypointDetector, image, confidence_threshold: float = 0.1) -> Image:
model.eval()
tensored_image = torch.from_numpy(np.array(image)).float()
tensored_image = tensored_image / 255.0
tensored_image = tensored_image.permute(2, 0, 1)
tensored_image = tensored_image.unsqueeze(0)
with torch.no_grad():
heatmaps = model(tensored_image)

keypoints = get_keypoints_from_heatmap_batch_maxpool(heatmaps, abs_max_threshold=confidence_threshold)
image_keypoints = keypoints[0]
for keypoints, channel_config in zip(image_keypoints, model.keypoint_channel_configuration):
print(f"Keypoints for {channel_config}: {keypoints}")
image = draw_keypoints_on_image(image, image_keypoints, model.keypoint_channel_configuration)
return image


if __name__ == "__main__":
wandb_checkpoint = "tlips/synthetic-lego-battery-keypoints/model-tbzd50z8:v0"
image_path = "/home/tlips/Downloads/Lego-battery-real/0.jpg"
# image_path = "/home/tlips/Documents/synthetic-cloth-data/synthetic-cloth-data/data/datasets/LEGO-battery/01/images/0.jpg"
image_size = (256, 256)

image = Image.open(image_path)
image = image.resize(image_size)

model = get_model_from_wandb_checkpoint(wandb_checkpoint)
image = run_inference(model, image)
image.save("inference_result.png")
97 changes: 84 additions & 13 deletions keypoint_detection/utils/visualization.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,29 @@
from argparse import ArgumentParser
from typing import List
from typing import List, Tuple

import numpy as np
import torch
import torchvision
from matplotlib import cm
from PIL import Image, ImageDraw, ImageFont

from keypoint_detection.utils.heatmap import generate_channel_heatmap


def get_logging_label_from_channel_configuration(channel_configuration: List[List[str]], mode: str) -> str:
channel_name = channel_configuration

if isinstance(channel_configuration, list):
if len(channel_configuration) == 1:
channel_name = channel_configuration[0]
else:
channel_name = f"{channel_configuration[0]}+{channel_configuration[1]}+..."

channel_name_short = (channel_name[:40] + "...") if len(channel_name) > 40 else channel_name
label = f"{channel_name_short}_{mode}"
return label


def overlay_image_with_heatmap(images: torch.Tensor, heatmaps: torch.Tensor, alpha=0.5) -> torch.Tensor:
""" """
viridis = cm.get_cmap("viridis")
Expand All @@ -19,7 +35,22 @@ def overlay_image_with_heatmap(images: torch.Tensor, heatmaps: torch.Tensor, alp
return overlayed_images


def overlay_image_with_keypoints(images: torch.Tensor, keypoints: List[torch.Tensor], sigma: float) -> torch.Tensor:
def visualize_predicted_heatmaps(
imgs: torch.Tensor,
predicted_heatmaps: torch.Tensor,
gt_heatmaps: torch.Tensor,
):
num_images = min(predicted_heatmaps.shape[0], 6)

predicted_heatmap_overlays = overlay_image_with_heatmap(imgs[:num_images], predicted_heatmaps[:num_images])
gt_heatmap_overlays = overlay_image_with_heatmap(imgs[:num_images], gt_heatmaps[:num_images])

images = torch.cat([predicted_heatmap_overlays, gt_heatmap_overlays])
grid = torchvision.utils.make_grid(images, nrow=num_images)
return grid


def overlay_images_with_keypoints(images: torch.Tensor, keypoints: List[torch.Tensor], sigma: float) -> torch.Tensor:
"""
images N x 3 x H x W
keypoints list of size N with Tensors C x 2
Expand Down Expand Up @@ -49,18 +80,58 @@ def overlay_image_with_keypoints(images: torch.Tensor, keypoints: List[torch.Ten
return overlayed_images


def visualize_predicted_heatmaps(
imgs: torch.Tensor,
predicted_heatmaps: torch.Tensor,
gt_heatmaps: torch.Tensor,
def draw_keypoints_on_image(
image: Image, image_keypoints: List[List[Tuple[int, int]]], channel_configuration: List[List[str]]
) -> Image:
"""adds all keypoints to the PIL image, with different colors for each channel."""
color_pool = [
"#FF00FF", # Neon Purple
"#00FF00", # Electric Green
"#FFFF00", # Cyber Yellow
"#0000FF", # Laser Blue
"#FF0000", # Radioactive Red
"#00FFFF", # Galactic Teal
"#FF00AA", # Quantum Pink
"#C0C0C0", # Holographic Silver
"#000000", # Abyssal Black
"#FFA500", # Cosmic Orange
]
image_size = image.size
min_size = min(image_size)
scale = 1 + (min_size // 256)

draw = ImageDraw.Draw(image)
for channel_idx, channel_keypoints in enumerate(image_keypoints):
for keypoint_idx, keypoint in enumerate(channel_keypoints):
u, v = keypoint
draw.ellipse((u - scale, v - scale, u + scale, v + scale), fill=color_pool[channel_idx])

draw.text(
(10, channel_idx * 10 * scale),
get_logging_label_from_channel_configuration(channel_configuration[channel_idx], "").split("_")[0],
fill=color_pool[channel_idx],
font=ImageFont.truetype("FreeMono.ttf", size=10 * scale),
)

return image


def visualize_predicted_keypoints(
images: torch.Tensor, keypoints: List[List[List[List[int]]]], channel_configuration: List[List[str]]
):
num_images = min(predicted_heatmaps.shape[0], 6)

predicted_heatmap_overlays = overlay_image_with_heatmap(imgs[:num_images], predicted_heatmaps[:num_images])
gt_heatmap_overlays = overlay_image_with_heatmap(imgs[:num_images], gt_heatmaps[:num_images])

images = torch.cat([predicted_heatmap_overlays, gt_heatmap_overlays])
grid = torchvision.utils.make_grid(images, nrow=num_images)
drawn_images = []
num_images = min(images.shape[0], 6)
for i in range(num_images):
# PIL expects uint8 images
image = images[i].permute(1, 2, 0).numpy() * 255
image = image.astype(np.uint8)
image = Image.fromarray(image)
image = draw_keypoints_on_image(image, keypoints[i], channel_configuration)
drawn_images.append(image)

drawn_images = torch.stack([torch.from_numpy(np.array(image)).permute(2, 0, 1) / 255 for image in drawn_images])

grid = torchvision.utils.make_grid(drawn_images, nrow=num_images)
return grid


Expand Down

0 comments on commit 9ffa9a2

Please sign in to comment.