This script aims to visualize all results obtained from other notebooks

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
import os
import json
import numpy as np
from PIL import Image
import cv2
from scipy.ndimage import map_coordinates
import math
import random
import albumentations as A
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

Let's start by implementing the key functions from other notebooks. You can keep this cell collapsed.

In [None]:
class KeypointDataset(Dataset):
    """COCO-keypoint dataset producing an image tensor and two Gaussian heatmaps.

    Each item is ``(image, (heatmap_first, heatmap_second))`` where ``image``
    is a float tensor of shape ``(3, H, W)`` scaled to ``[0, 1]`` and both
    heatmaps are float tensors of shape ``(1, H, W)``.

    * With ``unified_cells=False`` the first heatmap holds keypoints whose
      visibility flag is 2 ("black") and the second those whose flag is 1
      ("white").
    * With ``unified_cells=True`` every labelled keypoint goes into the first
      heatmap and the second one is all zeros.

    Args:
        coco_json: Path to the COCO annotation JSON file.
        img_dir: Directory containing the image files.
        img_list: Optional collection of file names to keep; falsy keeps
            every annotated image.
        sigma: Standard deviation (in pixels) of the Gaussian drawn per keypoint.
        target_size: Output size as ``(height, width)``.
        augment: Whether to build and apply the Albumentations pipelines.
        augment_prob: Per-sample probability of applying augmentation.
        unified_cells: Merge all keypoints into a single heatmap.
    """

    def __init__(self, coco_json, img_dir, img_list=None, sigma=2, target_size=(512, 512),
                 augment=False, augment_prob=0.8, unified_cells=False):
        with open(coco_json, 'r') as f:
            self.coco_data = json.load(f)

        # Must run before the lookup tables below are built: they keep
        # references to the (cleaned) keypoint lists.
        self._clean_invalid_keypoints()

        self.img_dir = img_dir
        self.sigma = sigma
        self.target_size = target_size
        self.augment = augment
        self.augment_prob = augment_prob
        self.unified_cells = unified_cells

        annotations = self.coco_data['annotations']
        self.id_list = {ann['id'] for ann in annotations}
        self.id_to_image_id = {ann['id']: ann['image_id'] for ann in annotations}
        self.img_id_to_file = {img['id']: img['file_name'] for img in self.coco_data['images']}
        self.id_to_keypoints = {ann['id']: ann['keypoints'] for ann in annotations}

        self.img_to_anns = {}
        for ann in annotations:
            self.img_to_anns.setdefault(ann['image_id'], []).append(ann)

        if img_list:
            # A set makes the membership test O(1); other container types are
            # used as given so their own ``in`` semantics are preserved.
            wanted = set(img_list) if isinstance(img_list, (list, tuple)) else img_list
            self.img_ids = [
                img_id for img_id, file_name in self.img_id_to_file.items()
                if file_name in wanted and img_id in self.img_to_anns
            ]
        else:
            self.img_ids = list(self.img_to_anns.keys())

        self.geometric_transform = None
        self.photometric_transform = None
        if self.augment:
            self._build_transforms()

    def _build_transforms(self):
        """Create the geometric and photometric Albumentations pipelines."""
        self.geometric_transform = A.Compose([
            A.Rotate(limit=15, p=0.85, border_mode=cv2.BORDER_REFLECT),
            A.HorizontalFlip(p=0.15),
            A.VerticalFlip(p=0),
            A.ShiftScaleRotate(
                shift_limit=0.1,
                scale_limit=0.2,
                rotate_limit=15,
                border_mode=cv2.BORDER_REFLECT,
                p=0
            ),
        ], keypoint_params=A.KeypointParams(
            format='xy',
            remove_invisible=True,
            # The ids travel with the keypoints so that the points removed by
            # ``remove_invisible`` can be identified afterwards.
            label_fields=['keypoint_ids'],
        ))

        self.photometric_transform = A.Compose([
            A.RandomBrightnessContrast(
                brightness_limit=0.3,
                contrast_limit=0.2,
                p=0.7
            ),
            A.ColorJitter(
                brightness=0.2,
                contrast=0.2,
                saturation=0.2,
                hue=0.1,
                p=0.5
            ),
            A.RandomGamma(gamma_limit=(80, 120), p=0.4),
            A.HueSaturationValue(
                hue_shift_limit=20,
                sat_shift_limit=30,
                val_shift_limit=20,
                p=0.5
            ),
            A.OneOf([
                A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
                A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5), p=0.3),
            ], p=0.4),
            A.OneOf([
                A.MotionBlur(blur_limit=3, p=0.3),
                A.GaussianBlur(blur_limit=3, p=0.3),
            ], p=0.3),
        ])

    @staticmethod
    def _is_finite_number(value):
        """Return True when ``value`` is a finite scalar (rejects None, NaN, inf)."""
        try:
            return bool(np.isfinite(value))
        except (TypeError, ValueError):
            return False

    def _clean_invalid_keypoints(self):
        """Replace keypoints with missing or non-finite coordinates by ``(0, 0, 0)``.

        Valid triplets are normalised to ``[float, float, int]``. The
        annotation dictionaries are modified in place.
        """
        for ann in self.coco_data['annotations']:
            kpts = ann['keypoints']
            cleaned_kpts = []

            for i in range(0, len(kpts), 3):
                x, y, v = kpts[i:i + 3]
                if self._is_finite_number(x) and self._is_finite_number(y):
                    cleaned_kpts.extend([float(x), float(y), int(v)])
                else:
                    # Visibility 0 makes every later stage ignore this point.
                    cleaned_kpts.extend([0.0, 0.0, 0])

            ann['keypoints'] = cleaned_kpts

    def __len__(self):
        return len(self.img_ids)

    def _extract_keypoints(self, annotations, original_width, original_height):
        """Collect labelled keypoints lying inside the original image.

        Returns:
            ``(keypoints, keypoint_info)``: two lists of equal length, the
            first holding ``(x, y)`` tuples and the second dictionaries with
            the keys ``'visibility'`` and ``'label'``.
        """
        keypoints = []
        keypoint_info = []

        for ann in annotations:
            kpts = ann['keypoints']
            for i in range(0, len(kpts), 3):
                x, y, v = kpts[i:i + 3]

                if not (self._is_finite_number(x) and self._is_finite_number(y)):
                    continue
                if not v > 0:
                    continue
                if not (0 <= x < original_width and 0 <= y < original_height):
                    continue

                label = 1 if self.unified_cells else v
                keypoints.append((float(x), float(y)))
                keypoint_info.append({'visibility': label, 'label': label})

        return keypoints, keypoint_info

    def _apply_geometric_augmentation_with_info(self, image, keypoints, keypoint_info):
        """Apply the geometric pipeline while keeping keypoints and info paired.

        Keypoints removed by the pipeline (or rejected as malformed) are
        removed from ``keypoint_info`` as well, so index ``i`` of both
        returned lists always refers to the same point.

        Returns:
            ``(image, keypoints, keypoint_info)``. If the pipeline fails, a
            warning is issued and the inputs are returned unchanged.
        """
        try:
            pairs = [
                (kp, info) for kp, info in zip(keypoints, keypoint_info)
                if len(kp) >= 2 and np.isfinite(kp[0]) and np.isfinite(kp[1])
            ]
            valid_keypoints = [kp for kp, _ in pairs]
            valid_info = [info for _, info in pairs]

            transformed = self.geometric_transform(
                image=image,
                keypoints=valid_keypoints,
                keypoint_ids=list(range(len(valid_keypoints))),
            )
            new_keypoints = transformed['keypoints']
            surviving_ids = transformed.get('keypoint_ids')

            if surviving_ids is None:
                # A replacement pipeline without the label field: pairing is
                # only trustworthy when nothing was dropped.
                if len(new_keypoints) != len(valid_keypoints):
                    raise ValueError("keypoints were dropped but no keypoint ids were returned")
                new_info = valid_info
            else:
                # Ids may come back as floats/NumPy scalars depending on the
                # Albumentations version.
                new_info = [valid_info[int(round(float(i)))] for i in surviving_ids]
                if len(new_info) != len(new_keypoints):
                    raise ValueError("keypoints and keypoint ids have different lengths")

            return transformed['image'], new_keypoints, new_info
        except Exception as exc:
            import warnings
            warnings.warn(f"KeypointDataset: geometric augmentation failed ({exc}); "
                          "using the un-augmented sample.")
            return image, keypoints, keypoint_info

    def _apply_geometric_augmentation(self, image, keypoints):
        """Apply the geometric pipeline; returns ``(image, keypoints)``.

        Kept for backward compatibility. Use
        ``_apply_geometric_augmentation_with_info`` when per-keypoint
        metadata must stay aligned with the surviving keypoints.
        """
        points = keypoints or []
        new_image, new_keypoints, _ = self._apply_geometric_augmentation_with_info(
            image, points, [None] * len(points)
        )
        if new_image is image and new_keypoints is points:
            # Failure path: hand back exactly what the caller passed in.
            return image, keypoints
        return new_image, new_keypoints

    def _apply_photometric_augmentation(self, image):
        """Apply the photometric pipeline; on failure warn and return ``image``."""
        try:
            return self.photometric_transform(image=image)['image']
        except Exception as exc:
            import warnings
            warnings.warn(f"KeypointDataset: photometric augmentation failed ({exc}); "
                          "using the original pixels.")
            return image

    def _create_heatmaps_from_keypoints(self, keypoints, keypoint_info, width, height):
        """Render the keypoints into two ``(height, width)`` float32 heatmaps."""
        heatmap_first = np.zeros((height, width), dtype=np.float32)
        heatmap_second = np.zeros((height, width), dtype=np.float32)

        for kp, info in zip(keypoints, keypoint_info):
            x, y = kp[0], kp[1]
            if not (np.isfinite(x) and np.isfinite(y)):
                continue

            if self.unified_cells:
                self._add_gaussian(heatmap_first, x, y, self.sigma)
            elif info['visibility'] == 2:
                self._add_gaussian(heatmap_first, x, y, self.sigma)
            elif info['visibility'] == 1:
                self._add_gaussian(heatmap_second, x, y, self.sigma)

        return heatmap_first, heatmap_second

    def _add_gaussian(self, heatmap, x, y, sigma):
        """Draw a unit-peak Gaussian centred on ``(x, y)`` into ``heatmap`` in place.

        The centre is rounded to the nearest pixel and clamped to the heatmap;
        overlapping Gaussians are merged with an element-wise maximum.
        """
        height, width = heatmap.shape

        if not (np.isfinite(x) and np.isfinite(y)):
            return
        if not sigma > 0:
            # A non-positive sigma would only produce NaNs (0 / 0).
            return

        try:
            x, y = int(round(float(x))), int(round(float(y)))
        except (ValueError, OverflowError):
            return

        x = max(0, min(x, width - 1))
        y = max(0, min(y, height - 1))

        # The patch covers +/- 3 sigma around the centre.
        radius = int((6 * sigma + 1) // 2)

        x0, x1 = max(0, x - radius), min(width, x + radius + 1)
        y0, y1 = max(0, y - radius), min(height, y + radius + 1)
        if x1 <= x0 or y1 <= y0:
            return

        xx, yy = np.meshgrid(np.arange(x0, x1), np.arange(y0, y1))
        gaussian = np.exp(-((xx - x) ** 2 + (yy - y) ** 2) / (2 * sigma ** 2))

        heatmap[y0:y1, x0:x1] = np.maximum(heatmap[y0:y1, x0:x1], gaussian)

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        img_file = self.img_id_to_file[img_id]
        img_path = os.path.join(self.img_dir, img_file)
        height, width = self.target_size

        try:
            with Image.open(img_path) as handle:
                image = handle.convert('RGB')
            original_width, original_height = image.size

            # PIL expects (width, height) while target_size is (height, width).
            image = np.array(image.resize((width, height), Image.LANCZOS))

            annotations = self.img_to_anns[img_id]
            keypoints, keypoint_info = self._extract_keypoints(
                annotations, original_width, original_height
            )

            scale_x = width / original_width
            scale_y = height / original_height

            # Filter keypoints and their info together so they stay paired.
            scaled_keypoints = []
            scaled_info = []
            for (x, y), info in zip(keypoints, keypoint_info):
                new_x = x * scale_x
                new_y = y * scale_y
                if np.isfinite(new_x) and np.isfinite(new_y):
                    scaled_keypoints.append((new_x, new_y))
                    scaled_info.append(info)

            if self.augment and random.random() < self.augment_prob:
                image, scaled_keypoints, scaled_info = self._apply_geometric_augmentation_with_info(
                    image, scaled_keypoints, scaled_info
                )
                image = self._apply_photometric_augmentation(image)

            heatmap_first, heatmap_second = self._create_heatmaps_from_keypoints(
                scaled_keypoints, scaled_info, width, height
            )

            heatmap_first = np.nan_to_num(heatmap_first, nan=0.0, posinf=0.0, neginf=0.0)
            heatmap_second = np.nan_to_num(heatmap_second, nan=0.0, posinf=0.0, neginf=0.0)

            # torch.from_numpy needs a contiguous, positively-strided array.
            image = torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1).float() / 255.0
            heatmap_first = torch.from_numpy(heatmap_first).unsqueeze(0).float()
            heatmap_second = torch.from_numpy(heatmap_second).unsqueeze(0).float()

            return image, (heatmap_first, heatmap_second)

        except Exception as exc:
            # Best effort: keep the batch shape intact, but make the failure
            # visible instead of silently training on blank samples.
            import warnings
            warnings.warn(f"KeypointDataset: failed to load sample {img_path!r} ({exc}); "
                          "returning an all-zero sample.")
            empty_image = torch.zeros(3, height, width, dtype=torch.float32)
            return empty_image, (
                torch.zeros(1, height, width, dtype=torch.float32),
                torch.zeros(1, height, width, dtype=torch.float32),
            )

    def set_augmentation(self, enabled, prob=0.8):
        """Enable or disable augmentation and set its per-sample probability."""
        self.augment = enabled
        self.augment_prob = prob
        # The pipelines only exist if augmentation was requested at
        # construction time; create them on demand.
        if enabled and (self.geometric_transform is None or self.photometric_transform is None):
            self._build_transforms()


def separate_keypoints_by_visibility(kp):
    """Split a flat ``[x, y, v, ...]`` keypoint list into black and white coordinates.

    A keypoint whose visibility flag ``v`` equals 2 is "black"; every other
    flag value is treated as "white".

    Returns:
        ``(x_black, y_black, x_white, y_white)`` as four lists.
    """
    black = ([], [])
    white = ([], [])

    for start in range(0, len(kp), 3):
        xs, ys = black if kp[start + 2] == 2 else white
        xs.append(kp[start])
        ys.append(kp[start + 1])

    return black[0], black[1], white[0], white[1]


def apply_transformation(image_path, kp_v, kp_xy, transform):
    """Run ``transform`` on an image file and its keypoints.

    Args:
        image_path: Path of the image to load (converted to RGB).
        kp_v: Visibility flag of each keypoint, in the same order as ``kp_xy``.
        kp_xy: Keypoint coordinates as ``(x, y)`` pairs.
        transform: Callable accepting ``image=`` and ``keypoints=`` and
            returning a mapping with the keys ``'image'`` and ``'keypoints'``.

    Returns:
        ``(transformed_image, flat_keypoints)`` where ``flat_keypoints`` is a
        flat ``[x, y, v, ...]`` list. Output keypoint ``i`` is paired with
        ``kp_v[i]``, so the transform is expected to return one keypoint per
        input keypoint, in order.
    """
    pixels = np.array(Image.open(image_path).convert('RGB'))
    result = transform(image=pixels, keypoints=kp_xy)

    flat_keypoints = []
    for idx, _ in enumerate(kp_xy):
        px, py = result['keypoints'][idx]
        flat_keypoints += [px, py, kp_v[idx]]

    return result['image'], flat_keypoints


def create_transformation_visualization(full_name, keypoints_v, keypoints_xy, transformations,
                                        output_pdf_path, title_prefix="Transformation", main_title=None):
    """Plot every transformation next to the original image and save the figure as PDF.

    One row is drawn per entry of ``transformations``, with three panels:
    the original image, the transformed image, and the transformed image
    with its keypoints overlaid.

    Args:
        full_name: Path of the image file.
        keypoints_v: Visibility flag of each keypoint.
        keypoints_xy: ``(x, y)`` coordinates of each keypoint.
        transformations: Mapping from a display name to a transform accepted
            by ``apply_transformation``.
        output_pdf_path: Destination of the single-page PDF.
        title_prefix: Prefix of the progress line printed for each row.
        main_title: Optional figure-level title; omitted when falsy.
    """
    cols = 3
    rows = len(transformations)

    fig = plt.figure(figsize=(cols * 5, rows * 4.5))

    if main_title:
        fig.suptitle(main_title, fontsize=16, fontweight='bold', y=0.96)

    # The reference image is the same on every row: load it once.
    image_orig = Image.open(full_name).convert('RGB')

    with PdfPages(output_pdf_path) as pdf:
        for i, (name, transform) in enumerate(transformations.items()):
            img_aug, kps_aug = apply_transformation(full_name, keypoints_v, keypoints_xy, transform)

            # imshow needs uint8 (0-255) or float (0-1) data; normalise to uint8.
            if img_aug.dtype != np.uint8:
                if img_aug.max() <= 1.0:
                    img_aug = (img_aug * 255).clip(0, 255).astype(np.uint8)
                else:
                    img_aug = img_aug.clip(0, 255).astype(np.uint8)

            x_black, y_black, x_white, y_white = separate_keypoints_by_visibility(kps_aug)

            print(f"{title_prefix} {i}: {name}")

            ax_original = fig.add_subplot(rows, cols, i * cols + 1)
            ax_original.imshow(image_orig)
            ax_original.set_title("Original Image", fontsize=12, pad=10)
            ax_original.axis('off')

            ax_transformed = fig.add_subplot(rows, cols, i * cols + 2)
            ax_transformed.imshow(img_aug)
            ax_transformed.set_title(f"{name}", fontsize=12, pad=10)
            ax_transformed.axis('off')

            ax_overlay = fig.add_subplot(rows, cols, i * cols + 3)
            ax_overlay.imshow(img_aug)
            ax_overlay.scatter(x_black, y_black, c='r', label='black', s=6)
            ax_overlay.scatter(x_white, y_white, c='b', label='white', s=6)
            ax_overlay.set_title("Image and Annotations Overlay", fontsize=12, pad=10)
            ax_overlay.axis('off')
            ax_overlay.legend(loc='upper right', fontsize=10)

        fig.tight_layout(pad=2.0)
        pdf.savefig(fig, bbox_inches='tight')
        plt.show()
        plt.close(fig)


def apply_curvature_transform(image, curvature_factor=0.3):
    """Bend an ``(H, W, C)`` image horizontally along a one-period sine profile.

    Every output row is sampled from the input row shifted horizontally by
    ``curvature_factor * sin(2 * pi * row / H) * W`` pixels, using bilinear
    interpolation. Samples falling outside the input are filled with 255.

    Returns:
        The warped image as a ``uint8`` array of the same shape.
    """
    n_rows, n_cols = image.shape[:2]
    row_idx, col_idx = np.mgrid[0:n_rows, 0:n_cols]

    # Backward mapping: for each output pixel, where to read in the input.
    sample_cols = col_idx - curvature_factor * np.sin(2 * np.pi * row_idx / n_rows) * n_cols
    sample_rows = row_idx.astype(float)
    sample_coords = [sample_rows, sample_cols]

    warped = np.zeros_like(image)
    for channel in range(image.shape[2]):
        warped[..., channel] = map_coordinates(
            image[..., channel], sample_coords, order=1, mode='constant', cval=255
        )

    return warped.astype(np.uint8)


def deform_keypoints_curvature(keypoints, height, width, curvature_factor=0.3):
    """Shift keypoints horizontally along a one-period sine profile.

    Args:
        keypoints: Flat ``[x, y, v, ...]`` list.
        height: Image height in pixels.
        width: Image width in pixels.
        curvature_factor: Amplitude of the shift as a fraction of ``width``.

    Returns:
        A new flat list. Keypoints with ``v > 0`` are shifted by
        ``curvature_factor * sin(2 * pi * y / height) * width`` along x and
        clipped to the image; all others are copied unchanged.
    """
    deformed = []
    for base in range(0, len(keypoints), 3):
        kx = keypoints[base]
        ky = keypoints[base + 1]
        vis = keypoints[base + 2]

        if not vis > 0:
            deformed.extend([kx, ky, vis])
            continue

        shift = curvature_factor * np.sin(2 * np.pi * ky / height) * width
        deformed.extend([
            np.clip(kx + shift, 0, width - 1),
            np.clip(ky, 0, height - 1),
            vis,
        ])
    return deformed


def generate_wrinkle_points(width, height, num_wrinkles=3, min_distance=80):
    """Draw random wrinkle centres inside the central 60 % of the image.

    Candidates are sampled uniformly in ``[0.2, 0.8]`` of each dimension and
    rejected when closer than ``min_distance`` to an accepted centre. At most
    1000 candidates are tried, so fewer than ``num_wrinkles`` points may be
    returned.

    Returns:
        A list of ``(x, y, intensity, radius)`` tuples, with ``intensity``
        drawn from ``[3, 8]`` and ``radius`` from ``[40, 80]``.
    """
    max_attempts = 1000
    accepted = []

    for _attempt in range(max_attempts):
        if len(accepted) >= num_wrinkles:
            break

        cand_x = random.uniform(width * 0.2, width * 0.8)
        cand_y = random.uniform(height * 0.2, height * 0.8)

        too_close = any(
            np.sqrt((cand_x - other_x) ** 2 + (cand_y - other_y) ** 2) < min_distance
            for other_x, other_y, _intensity, _radius in accepted
        )
        if too_close:
            continue

        # Intensity is drawn before radius.
        strength = random.uniform(3, 8)
        reach = random.uniform(40, 80)
        accepted.append((cand_x, cand_y, strength, reach))

    return accepted


def apply_wrinkle_transform(image, wrinkle_intensity=0.3, num_wrinkles=3):
    """Warp an ``(H, W, C)`` image with randomly placed local wrinkles.

    Each wrinkle centre contributes an oscillating displacement whose
    amplitude decays with a Gaussian of the distance to the centre. The
    image is resampled with bilinear interpolation; samples falling outside
    the input are filled with 255.

    Returns:
        ``(warped, displacement_x, displacement_y)``: the warped image as
        ``uint8`` and, for every output pixel, the offset
        ``source - destination`` of the location it was sampled from.
    """
    n_rows, n_cols = image.shape[:2]
    centres = generate_wrinkle_points(n_cols, n_rows, num_wrinkles)

    grid_y, grid_x = np.mgrid[0:n_rows, 0:n_cols]
    x_source = grid_x.astype(float)
    y_source = grid_y.astype(float)

    for centre_x, centre_y, intensity, radius in centres:
        dist = np.sqrt((grid_x - centre_x) ** 2 + (grid_y - centre_y) ** 2)
        falloff = np.exp(-dist ** 2 / (2 * (radius * 1.5) ** 2))

        angle = random.uniform(0, 2 * np.pi)
        amplitude = intensity * wrinkle_intensity * falloff

        phase = dist / radius
        ripple = np.sin(phase * 2)
        core = np.exp(-dist / (radius * 0.8))

        x_source += amplitude * (
            np.cos(angle + phase) * ripple +
            0.2 * np.sin(2 * angle) * core
        )
        y_source += amplitude * (
            np.sin(angle + phase) * ripple +
            0.2 * np.cos(2 * angle) * core
        )

    warped = np.zeros_like(image)
    sample_coords = [y_source, x_source]
    for channel in range(image.shape[2]):
        warped[..., channel] = map_coordinates(
            image[..., channel], sample_coords, order=1, mode='constant', cval=255
        )

    displacement_x = x_source - grid_x
    displacement_y = y_source - grid_y

    return warped.astype(np.uint8), displacement_x, displacement_y


def deform_keypoints_wrinkle(keypoints, displacement_x, displacement_y, width, height):
    """Move keypoints so that they follow an image warped by a displacement field.

    ``displacement_x`` / ``displacement_y`` are expected in the convention
    returned by ``apply_wrinkle_transform``: for every *output* pixel, the
    offset ``source - destination`` of the location it was sampled from
    (a backward map). Content located at ``p`` in the original image
    therefore ends up near ``p - displacement(p)`` in the warped image, which
    is why the displacement is subtracted here.

    Args:
        keypoints: Flat ``[x, y, v, ...]`` list.
        displacement_x: 2-D array of horizontal offsets, indexed ``[row, col]``.
        displacement_y: 2-D array of vertical offsets, indexed ``[row, col]``.
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        A new flat list. Keypoints with ``v > 0`` are displaced and clipped
        to the image; all others are copied unchanged.
    """
    new_keypoints = []
    for i in range(0, len(keypoints), 3):
        x, y, v = keypoints[i], keypoints[i + 1], keypoints[i + 2]
        if not v > 0:
            new_keypoints.extend([x, y, v])
            continue

        x = np.clip(x, 0, width - 1)
        y = np.clip(y, 0, height - 1)

        # Nearest-lower pixel lookup, guarded against fields smaller than
        # the stated image size.
        col = min(int(x), displacement_x.shape[1] - 1)
        row = min(int(y), displacement_x.shape[0] - 1)

        dx = displacement_x[row, col]
        dy = displacement_y[row, col]

        # First-order inversion of the backward map.
        x_deformed = np.clip(x - dx, 0, width - 1)
        y_deformed = np.clip(y - dy, 0, height - 1)
        new_keypoints.extend([x_deformed, y_deformed, v])
    return new_keypoints


def normalize_heatmap(hm):
    """Min-max scale ``hm`` to the range [0, 1).

    A small epsilon (1e-8) is added to the value range so that a constant
    heatmap maps to zeros instead of dividing by zero.
    """
    lowest = np.min(hm)
    value_range = np.max(hm) - lowest
    return (hm - lowest) / (value_range + 1e-8)


class KeypointDetector(nn.Module):
    """ResNet feature extractor followed by a transposed-convolution heatmap head.

    The ResNet is truncated before its last two child modules and followed
    by five stride-2 transposed convolutions and a final 3x3 convolution
    with a sigmoid, producing two output channels with values in (0, 1).

    Args:
        backbone: One of ``'resnet18'``, ``'resnet34'`` or ``'resnet50'``.
        pretrained: Load the default torchvision weights into the backbone.
            The upsampling head is always freshly initialised.

    Raises:
        ValueError: If ``backbone`` is not supported.
    """

    # backbone name -> (torchvision weights enum name, channels of the last feature map)
    _BACKBONES = {
        'resnet18': ('ResNet18_Weights', 512),
        'resnet34': ('ResNet34_Weights', 512),
        'resnet50': ('ResNet50_Weights', 2048),
    }

    def __init__(self, backbone='resnet18', pretrained=True):
        super().__init__()
        if backbone not in self._BACKBONES:
            raise ValueError(f"Backbone '{backbone}' not supported.")

        weights_enum_name, backbone_features = self._BACKBONES[backbone]
        weights = getattr(models, weights_enum_name).DEFAULT if pretrained else None
        resnet = getattr(models, backbone)(weights=weights)

        # Drop the last two children of the ResNet and keep the rest as the
        # feature extractor.
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])

        # Kept as one flat Sequential so that state_dict keys stay stable.
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(backbone_features, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),

            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),

            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),

            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),

            nn.Conv2d(16, 2, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return the two-channel heatmap tensor predicted for the batch ``x``."""
        features = self.backbone(x)
        heatmaps = self.upsample(features)
        return heatmaps


Please fill in the following variables with your own values:
coco = path to your COCO annotation JSON file,
file = path to the directory containing your dataset images,
id = identifier (the `id` field) of the annotation whose image you want to use for all the visualizations defined in this script

In [None]:
# TODO: replace the three placeholder values below before running this cell.
coco = "path/to/annotations.json"  # path to the COCO annotation JSON file
file = "path/to/images"            # directory containing the dataset images
id = 0                             # `id` of the annotation to visualize

Z = KeypointDataset(coco_json=coco, img_dir=file)

# `id` is looked up among the annotation ids, not the image ids.
if id not in Z.id_to_image_id:
    raise KeyError(f"Annotation id {id!r} not found in '{coco}'.")

id_image = Z.id_to_image_id[id]
filename = Z.img_id_to_file[id_image]
full_name = os.path.join(file, filename)
keypoints = Z.id_to_keypoints[id]

x_black_cell, y_black_cell, x_white_cell, y_white_cell = separate_keypoints_by_visibility(keypoints)

img = Image.open(full_name).convert("RGB")
img_np = np.array(img)

First, let's visualize the overlay of annotations on the image. This will serve as a foundation for building our ground truth.

In [None]:
# Overlay the annotated cell centres on the selected image and save it as PDF.
fig, ax = plt.subplots(figsize=(12, 8))
output_pdf_path1 = 'Visualization_annotations.pdf'

with PdfPages(output_pdf_path1) as pdf:
    img = mpimg.imread(full_name)
    # "gray" is the colormap name accepted by every Matplotlib version;
    # the "grey" spelling is only recognised by recent releases.
    ax.imshow(img, cmap="gray")
    ax.set_title('Visualization of Annotations with Cell Polarity Distinctions', fontsize=14, pad=20)

    ax.scatter(x_black_cell, y_black_cell, c='red', s=8, label='Black Cells', alpha=0.8)
    ax.scatter(x_white_cell, y_white_cell, c='blue', s=8, label='White Cells', alpha=0.8)

    ax.legend(loc='upper right', fontsize=12, frameon=True, fancybox=True, shadow=True)
    ax.axis('off')

    fig.tight_layout(pad=2.0)
    pdf.savefig(fig, bbox_inches='tight')
    plt.show()
    plt.close(fig)

Next, depending on your dataset size, you may have performed offline transformations using Albumentations, expanding your dataset size without duplicating the original images. Here is the visualization of this:

In [None]:
# Split the flat [x, y, v, ...] list into coordinates and visibility flags.
keypoints_x = keypoints[::3]
keypoints_y = keypoints[1::3]
keypoints_v = keypoints[2::3]
keypoints_xy = [(x, y) for x, y in zip(keypoints_x, keypoints_y)]


def _offline_keypoint_params():
    """Keypoint settings shared by every offline pipeline.

    ``remove_invisible=False`` keeps keypoints that leave the frame, so the
    transformed keypoints stay in one-to-one correspondence with
    ``keypoints_v`` (``apply_transformation`` pairs them by index).
    """
    return A.KeypointParams(format='xy', remove_invisible=False)


offline_transformations = {
    "rot1": A.Compose([A.Rotate(limit=(0, 90), p=1)], keypoint_params=_offline_keypoint_params()),
    "rot2": A.Compose([A.Rotate(limit=(0, 90), p=1)], keypoint_params=_offline_keypoint_params()),
    "noise": A.Compose([A.GaussNoise(var_limit=(10.0, 50.0), p=1)],
                       keypoint_params=_offline_keypoint_params()),
    "contrast": A.Compose([
        A.RandomBrightnessContrast(
            brightness_limit=0.3,
            contrast_limit=0.15,
            p=1.0
        )
    ], keypoint_params=_offline_keypoint_params())
}

output_pdf_path2 = 'Offline_transformations.pdf'
# title_prefix and main_title are left at their defaults.
create_transformation_visualization(full_name, keypoints_v, keypoints_xy,
                                    offline_transformations, output_pdf_path2)

Additionally, during training you had the option to perform online transformations with Albumentations to prevent overfitting, ensuring that the images seen by our model at each epoch are never identical. Here is the visualization of this:

In [None]:
online_transformations = {
    "online_transformation": A.Compose([
        A.Rotate(limit=15, p=0.85, border_mode=cv2.BORDER_REFLECT),
        A.HorizontalFlip(p=0.15),
        A.VerticalFlip(p=0),
        A.ShiftScaleRotate(
            shift_limit=0.1,
            scale_limit=0.2,
            rotate_limit=15,
            border_mode=cv2.BORDER_REFLECT,
            p=0
        ),
        A.RandomBrightnessContrast(
            brightness_limit=0.3,
            contrast_limit=0.3,
            p=0.7
        ),
        A.ColorJitter(
            brightness=0.2,
            contrast=0.2,
            saturation=0.2,
            hue=0.1,
            p=0.5
        ),
        A.RandomGamma(gamma_limit=(60, 140), p=1),
        A.HueSaturationValue(
            hue_shift_limit=30,
            sat_shift_limit=40,
            val_shift_limit=30,
            p=0.7
        ),
        A.OneOf([
            A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
            A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5), p=0.3),
        ], p=0.4),
        A.OneOf([
            A.MotionBlur(blur_limit=3, p=0.3),
            A.GaussianBlur(blur_limit=3, p=0.3),
        ], p=0.3),
    # remove_invisible=False keeps keypoints rotated out of the frame, so the
    # transformed keypoints stay in one-to-one correspondence with
    # keypoints_v (apply_transformation pairs them by index).
    ], keypoint_params=A.KeypointParams(format='xy', remove_invisible=False))
}

output_pdf_path3 = 'Online_transformations_4.pdf'
# title_prefix and main_title are left at their defaults.
create_transformation_visualization(full_name, keypoints_v, keypoints_xy,
                                    online_transformations, output_pdf_path3)

Next, you likely applied a transformation simulating the curvature of a datamatrix code, here is the visualization:

In [None]:
# Reload the selected image as an RGB array.
img = Image.open(full_name).convert("RGB")
img_np = np.array(img)

# Bend the image and move the keypoints with the same curvature factor.
curvature = 0.2
deformed_img_np = apply_curvature_transform(img_np, curvature)
new_keypoints = deform_keypoints_curvature(keypoints, img_np.shape[0], img_np.shape[1], curvature)

x_black, y_black, x_white, y_white = separate_keypoints_by_visibility(new_keypoints)

fig, axs = plt.subplots(1, 3, figsize=(18, 6), dpi=300)

# Three panels: original, warped, warped with keypoints.
for panel_ax, panel_img, panel_title in zip(
        axs,
        (img_np, deformed_img_np, deformed_img_np),
        ('Original Image', 'Curvature Transformed Image', 'Transformed Image with Keypoints')):
    panel_ax.imshow(panel_img)
    panel_ax.set_title(panel_title, fontsize=14, pad=15)
    panel_ax.axis('off')

overlay_ax = axs[2]
overlay_ax.scatter(x_black, y_black, c='red', label='Black Keypoints (v=2)', s=12)
overlay_ax.scatter(x_white, y_white, c='blue', label='White Keypoints (v≠2)', s=12, edgecolors='black')
overlay_ax.legend(loc='upper right', fontsize=10)

plt.tight_layout(pad=3.0)

output_pdf_path4 = 'Curvature.pdf'
with PdfPages(output_pdf_path4) as pdf:
    pdf.savefig(fig, bbox_inches='tight')
    plt.show()
    plt.close(fig)

Finally, you likely applied a transformation simulating the crumpling of a datamatrix code, here is the visualization:

In [None]:
# Crumple the image and move the keypoints with the resulting displacement field.
wrinkle_intensity = 0.3
num_wrinkles = 3
deformed_img_np, displacement_x, displacement_y = apply_wrinkle_transform(img_np, wrinkle_intensity, num_wrinkles)
new_keypoints = deform_keypoints_wrinkle(keypoints, displacement_x, displacement_y, img_np.shape[1], img_np.shape[0])

x_black_orig, y_black_orig, x_white_orig, y_white_orig = separate_keypoints_by_visibility(keypoints)
x_black_new, y_black_new, x_white_new, y_white_new = separate_keypoints_by_visibility(new_keypoints)

fig, axs = plt.subplots(1, 3, figsize=(18, 6), dpi=300)

# Three panels: original, crumpled, crumpled with keypoints.
for panel_ax, panel_img, panel_title in zip(
        axs,
        (img_np, deformed_img_np, deformed_img_np),
        ('Original Image', 'Crumpled Image', 'Crumpled Image + Keypoints')):
    panel_ax.imshow(panel_img)
    panel_ax.set_title(panel_title, fontsize=14, pad=15)
    panel_ax.axis('off')

overlay_ax = axs[2]
overlay_ax.scatter(x_black_new, y_black_new, c='red', label='Black Cells (v=2)', s=12)
overlay_ax.scatter(x_white_new, y_white_new, c='blue', label='White Cells (v=1)', s=12, edgecolors='black')
overlay_ax.legend(loc='upper right', fontsize=10)

plt.tight_layout(pad=3.0)

output_pdf_path5 = 'Crumpling.pdf'
with PdfPages(output_pdf_path5) as pdf:
    pdf.savefig(fig, bbox_inches='tight')
    plt.show()
    plt.close(fig)

On the other hand, you had to build a ground truth made of Gaussians centred on the annotated keypoints. During training, the predicted heatmaps are progressively shaped by minimizing the loss between them and this ground truth. Additionally, you had to construct a binary heatmap, obtained by thresholding (binarizing) the ground truth, which allowed you to apply associative embedding. Here is the visualization of the ground truth and binary heatmaps:

In [None]:
# Pick the first image file of the directory (fall back to the selected one).
train_img_list = sorted([f for f in os.listdir(file) if f.endswith(('.jpg', '.jpeg', '.png', '.bmp', '.tiff'))])
image_name = train_img_list[0] if train_img_list else filename
base_name = os.path.splitext(image_name)[0]

# The same sigma drives both the ground-truth Gaussians and the
# binarization threshold below, so pass it to the dataset explicitly.
sigma = 2

train_dataset = KeypointDataset(
    coco_json=coco,
    img_dir=file,
    img_list=[image_name],
    sigma=sigma
)

if len(train_dataset) == 0:
    # Without this check the loader below fails with a bare StopIteration.
    raise ValueError(f"Image '{image_name}' has no annotations in '{coco}'.")

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=0)


def f(d):
    """Value of the unit-peak Gaussian at distance ``d`` from its centre."""
    return np.exp(-(d ** 2) / (2 * sigma ** 2))


# Pixels within 1.5 sigma of a keypoint are set to 1 in the binary heatmaps.
lim = f(1.5 * sigma)

images, targets = next(iter(train_loader))
image_tensor = images[0]
heatmap_black_tensor = targets[0]
heatmap_white_tensor = targets[1]

image = image_tensor.permute(1, 2, 0).numpy()
heatmap_black = heatmap_black_tensor.squeeze().numpy()
heatmap_white = heatmap_white_tensor.squeeze().numpy()

heatmap_black_norm = normalize_heatmap(heatmap_black)
heatmap_white_norm = normalize_heatmap(heatmap_white)

H, W = heatmap_black_norm.shape
binary_heatmap_black = (heatmap_black >= lim).astype(float)
binary_heatmap_white = (heatmap_white >= lim).astype(float)

# Black-cell heatmap in the red channel, white-cell heatmap in the blue channel.
overlay_rgb_norm = np.zeros((H, W, 3))
overlay_rgb_norm[..., 0] = heatmap_black_norm
overlay_rgb_norm[..., 2] = heatmap_white_norm
overlay_rgb_binary = np.zeros((H, W, 3))
overlay_rgb_binary[..., 0] = binary_heatmap_black
overlay_rgb_binary[..., 2] = binary_heatmap_white

fig, axs = plt.subplots(1, 3, figsize=(18, 6))

axs[0].imshow(image)
axs[0].set_title('Original Image', fontsize=14, pad=15)
axs[0].axis('off')

axs[1].imshow(overlay_rgb_norm)
axs[1].set_title('Overlay: Normalized Heatmaps', fontsize=14, pad=15)
axs[1].axis('off')

axs[2].imshow(overlay_rgb_binary)
axs[2].set_title('Overlay: Binary Heatmaps', fontsize=14, pad=15)
axs[2].axis('off')

fig.tight_layout(pad=3.0)
output_path6 = 'heatmaps.pdf'
fig.savefig(output_path6, format='pdf', dpi=300)
plt.show()
plt.close(fig)

Finally, before training starts, the model's backbone is initialized with the pretrained weights of a ResNet 18, 34, or 50, while the upsampling layers that produce the heatmaps are still untrained. Here is the visualization of the heatmaps predicted by the model at this stage:

In [None]:
model = KeypointDetector(backbone='resnet18', pretrained=True)

# Inference only: eval() disables dropout and freezes the batch-norm
# statistics, and no_grad() avoids building an autograd graph.
model.eval()
with torch.no_grad():
    predicted_heatmaps = model(image_tensor.unsqueeze(0))

heatmap_np = predicted_heatmaps.squeeze().cpu().detach().numpy()

fig, axs = plt.subplots(1, 3, figsize=(15, 5))
axs[0].imshow(image)
axs[0].set_title('Original Image', fontsize=12, pad=15)
axs[0].axis('off')

axs[1].imshow(heatmap_np[0], cmap='hot')
axs[1].set_title('Predicted Heatmap Black\n(Channel 0)', fontsize=12, pad=15)
axs[1].axis('off')

axs[2].imshow(heatmap_np[1], cmap='gray')
axs[2].set_title('Predicted Heatmap White\n(Channel 1)', fontsize=12, pad=15)
axs[2].axis('off')

fig.subplots_adjust(wspace=0.25, hspace=0.4)

output_path7 = 'Backbone.pdf'
fig.savefig(output_path7, format='pdf', dpi=300, bbox_inches='tight')

print(f"Saved to: {output_path7}")
plt.show()
plt.close(fig)