# Example of training with a pre-trained KeypointRCNN_ResNet50_FPN model

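In the code below, we train torchvision's `keypointrcnn_resnet50_fpn` model on our own dataset. The model is created with `pretrained=False` and `pretrained_backbone=True`, so only the backbone starts from pre-trained weights; the full COCO-trained `KeypointRCNN_ResNet50_FPN_Weights` checkpoint is not loaded.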

Import libraries

In [None]:
import json
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import kornia as K
import matplotlib.pyplot as plt
import torchvision
from tqdm import tqdm
from tqdm import tqdm
import datetime
import os


In [None]:
class CustomDataset(Dataset):
    """
    Dataset of images with bounding-box and keypoint annotations.

    Every ``.jpg`` file found in ``images_path`` is paired with the ``.json``
    file of the same base name in ``annotations_path``. Each JSON file is read
    through its ``'bboxes'`` entry (one ``[x1, y1, x2, y2]`` list per box) and its
    ``'keypoints'`` entry (one list of ``[x, y, <ignored>]`` triplets per box).

    params:
        images_path: str = directory containing the ``.jpg`` images
        annotations_path: str = directory containing the ``.json`` annotations
        use_augmentation: bool = if True, apply random flips, rotation, brightness
            and contrast changes; otherwise apply a no-op pipeline
        device = device on which every returned tensor is placed
        downsample_factor = stored on the instance, but not used by any method
            of this class
    """

    def __init__(self, images_path, annotations_path, use_augmentation, device, downsample_factor=1):
        self.images_path = images_path
        self.annotations_path = annotations_path
        self.device = device
        self.downsample_factor = downsample_factor
        # os.listdir returns entries in arbitrary order; sort them so that an
        # index always maps to the same file across runs and machines.
        self.image_filenames = sorted(
            filename for filename in os.listdir(images_path) if filename.endswith('.jpg'))

        if use_augmentation:
            # Declare an augmentation pipeline
            # NOTE(review): the second positional argument of RandomBrightness /
            # RandomContrast may bind to `clip_output` rather than to an upper
            # bound - confirm against the installed Kornia version.
            self.transform = K.augmentation.AugmentationSequential(
                K.augmentation.RandomVerticalFlip(),
                K.augmentation.RandomHorizontalFlip(),
                K.augmentation.RandomRotation(30),
                K.augmentation.RandomBrightness(0.5, 1.5),
                K.augmentation.RandomContrast(0.5, 1.5),
                data_keys=["input", "bbox", "keypoints"])
        else:
            # A zero-degree rotation keeps the same call signature and output
            # layout as the augmenting pipeline without altering the data.
            self.transform = K.augmentation.AugmentationSequential(
                K.augmentation.RandomRotation(0),
                data_keys=["input", "bbox", "keypoints"])

    def __len__(self):
        """Return the number of ``.jpg`` images found in ``images_path``."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """
        Load one sample, run it through the transform pipeline and build its target.
        params:
            idx: int = index of the sample
        return:
            image: torch.Tensor = transformed image tensor
            target: dict = target dictionary built by kornia_to_torch_format
        """
        image_filename = self.image_filenames[idx]
        image_tensor = self.load_image(os.path.join(self.images_path, image_filename))
        # splitext copes with any extension length, unlike slicing off 4 characters
        annotation_filename = os.path.join(
            self.annotations_path, os.path.splitext(image_filename)[0] + '.json')

        with open(annotation_filename, 'r') as f:
            bbox_tensor, keypoint_tensor = self.convert_to_kornia_format(json.load(f))

        out_tensor = self.transform(image_tensor.float(), bbox_tensor.float(), keypoint_tensor.float())

        # The boxes and keypoints below come from the transformed sample, so the
        # image returned with them must be the transformed one as well.
        augmented_image = out_tensor[0]
        if augmented_image.dim() == 4:
            # drop the batch dimension added by the pipeline
            augmented_image = augmented_image[0]

        # get the torch format from kornia format
        target = self.kornia_to_torch_format(out_tensor[1], out_tensor[2], idx)

        # return tensors
        return augmented_image, target

    def load_image(self, image_path: str) -> torch.Tensor:
        """
        Method to load image
        params:
            image_path: str = path of the image
        return:
            tensor: torch.Tensor = RGB image tensor normalized with mean 0 and std 255
        raises:
            FileNotFoundError: if OpenCV cannot read the file
        """

        # cv2.imread does not raise on failure, it returns None
        image: np.ndarray = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")

        # convert to tensor
        image_tensor: torch.Tensor = K.image_to_tensor(image)

        # bgr to rgb
        image_tensor = K.color.bgr_to_rgb(image_tensor)

        return K.enhance.normalize(image_tensor, torch.tensor(0.), torch.tensor(255.)).to(self.device)

    def convert_to_kornia_format(self, data):
        """
        Method to convert the bounding boxes and keypoints to the Kornia format
        params:
            data: dict = dictionary containing the bounding boxes and keypoints
        return:
            bbox_tensor: torch.Tensor = tensor containing the four corners of each box
            keypoint_tensor: torch.Tensor = tensor containing the keypoints of all
                boxes, flattened in box order
        """

        # Each [x1, y1, x2, y2] box becomes its four corners, listed as
        # (x1, y1), (x2, y1), (x2, y2), (x1, y2)
        bbox_list = [
            [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
            for x1, y1, x2, y2 in data['bboxes']
        ]
        bbox_tensor = torch.tensor(bbox_list).unsqueeze(0).to(self.device)

        # Keep only x and y of every keypoint; the third value is discarded
        keypoint_list = [
            [x, y]
            for kpts in data['keypoints']
            for x, y, _ in kpts
        ]
        keypoint_tensor = torch.tensor(keypoint_list).unsqueeze(0).to(self.device)
        return bbox_tensor, keypoint_tensor

    def plot_resulting_image(self, img, bbox, keypoints):
        """
        Plot the resulting image with bounding boxes and keypoints.
        Useful to visually check the output of the transform pipeline.
        params:
            img: torch.Tensor = image tensor
            bbox: torch.Tensor = bounding box tensor (integer valued)
            keypoints: torch.Tensor = keypoints tensor (integer valued)
        return:
            img_draw = image with bounding boxes and keypoints
        """
        img_array = K.tensor_to_image(img.mul(255).byte()).copy()
        img_draw = cv2.polylines(img_array, bbox.reshape(-1, 4, 2).cpu().numpy(), isClosed=True, color=(255, 0, 0))
        for k in keypoints[0]:
            # plain Python ints are accepted as a point by every OpenCV version
            center = tuple(int(v) for v in k.cpu().numpy()[:2])
            img_draw = cv2.circle(img_draw, center, radius=6, color=(255, 0, 0), thickness=-1)
        return img_draw

    def kornia_to_torch_format(self, bbox_tensor, keypoint_tensor, idx,  labels=None):
        """
        Convert bbox_tensor and keypoint_tensor in Kornia format to torch's expected format.

        Parameters:
        - bbox_tensor (torch.Tensor): Bounding box corners, indexed as [0, box, corner, xy]
        - keypoint_tensor (torch.Tensor): Keypoints of all boxes, indexed as [0, keypoint, xy]
        - idx = index of the image
        - labels (list[int]): List of class labels for each bounding box. If None, default to label=1 for all boxes.

        Returns:
        - dict: A dictionary containing the following keys:
            - boxes (torch.Tensor): Bounding boxes as [x1, y1, x2, y2]
            - labels (torch.Tensor): Class label tensor
            - image_id (torch.Tensor): Image ID tensor
            - area (torch.Tensor): Area tensor
            - iscrowd (torch.Tensor): IsCrowd tensor (all zeros)
            - keypoints (torch.Tensor): Keypoints as [x, y, visibility] per box
        """
        num_boxes = bbox_tensor.shape[1]

        # Take the axis-aligned extent of the four corners. A flip or rotation can
        # reorder or tilt the corners, so reading only corner 0 and corner 2 could
        # yield a box with x1 > x2 or one that does not enclose the object.
        corners = bbox_tensor[0]
        boxes = torch.cat([corners.min(dim=1).values, corners.max(dim=1).values], dim=1)

        # If labels aren't provided, assume a default label of 1 for all bounding boxes
        if labels is None:
            labels = torch.ones((num_boxes,), dtype=torch.int64).to(self.device)
        else:
            labels = torch.tensor(labels, dtype=torch.int64).to(self.device)

        # Regroup the flat keypoint list per box, using the actual number of
        # keypoints per box instead of assuming exactly two.
        kpts_per_box = keypoint_tensor.shape[1] // num_boxes if num_boxes else 0
        keypoints = torch.zeros((num_boxes, kpts_per_box, 3)).to(self.device)
        if num_boxes:
            keypoints[:, :, :2] = keypoint_tensor[0, :num_boxes * kpts_per_box].reshape(
                num_boxes, kpts_per_box, 2)
            keypoints[:, :, 2] = 1  # setting visibility to 1

        return {
            "boxes": boxes,
            "labels": labels,
            "keypoints": keypoints,
            "image_id": torch.tensor([idx]).to(self.device),
            "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
            "iscrowd": torch.zeros((num_boxes,), dtype=torch.int64).to(self.device),
        }

def collate_fn(batch):
    """
    Regroup a batch of (image, target) pairs into parallel sequences.

    Nothing is stacked: the images stay separate entries of a tuple.

    Parameters:
    - batch (list): List of (image, target) pairs, one pair per sample.

    Returns:
    - tuple: (images, targets), where images is a tuple holding the first
      element of every pair and targets a tuple holding the second one.
    """
    # Transpose the list of pairs into one tuple per field
    transposed = tuple(zip(*batch))
    images, targets = transposed

    return images, targets


# Use CUDA when it is available, otherwise fall back to the CPU
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(device)

# Locations of the images and annotations of each split
train_images_path = 'dataset/train/images'
train_annotations_path = 'dataset/train/annotations'
test_images_path = 'dataset/test/images'
test_annotations_path = 'dataset/test/annotations'

# Only the training split is augmented
train_dataset = CustomDataset(
    train_images_path,
    train_annotations_path,
    use_augmentation=True,
    device=device,
)
test_dataset = CustomDataset(
    test_images_path,
    test_annotations_path,
    use_augmentation=False,
    device=device,
)

# Training batches are shuffled; test batches keep the dataset order
train_dataloader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=collate_fn,
)

# Pull one batch from the train dataloader as a sanity check
batch_iterator = iter(train_dataloader)
images, targets = next(batch_iterator)

print(len(images))
print(len(targets))

# TRAIN

In [None]:
# Build the Keypoint R-CNN: only the backbone is created with pre-trained
# weights, and the heads are sized for 2 keypoints and 2 classes.
model_kwargs = dict(
    pretrained=False,
    pretrained_backbone=True,
    num_keypoints=2,
    num_classes=2,
    trainable_backbone_layers=3,
)
model = torchvision.models.detection.keypointrcnn_resnet50_fpn(**model_kwargs)

# Move the model to the selected device and switch it to training mode
model.to(device)
model.train()
print(model)

In [None]:
# Define some hyperparameters
num_epochs = 300
lr = 0.001

# Create a directory with the current timestamp
now = datetime.datetime.now()
timestamp = now.strftime('%Y-%m-%d_%H-%M-%S')
save_dir = os.path.join("saved_models", timestamp)
os.makedirs(save_dir, exist_ok=True)

# select parameters to finetune
params = [p for p in model.parameters() if p.requires_grad]

# Define the optimizer; the scheduler multiplies the learning rate by 0.3
# every 5 calls to lr_scheduler.step(), i.e. every 5 epochs (see below)
optimizer = torch.optim.Adam(params, lr=lr)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.3)

best_loss = float('inf')  # Initialize with a high value

# Training loop
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")

    model.train()  # Set model to training mode

    running_loss = 0.0
    for images, targets in tqdm(train_dataloader):

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass: called with targets, the model returns a dict of losses
        loss_dict = model(images, targets)

        # Compute total loss
        losses = sum(loss for loss in loss_dict.values())

        # Backward pass and optimize
        losses.backward()
        optimizer.step()

        # Accumulate the training loss
        running_loss += losses.item()

    # Step the scheduler once per epoch. Stepping it once per batch would
    # shrink the learning rate by 0.3 every 5 batches and drive it to
    # practically zero within the first few dozen batches.
    lr_scheduler.step()

    # Validation phase
    # The model is deliberately left in training mode: the loop below reads
    # loss_dict.values(), which relies on the loss dict returned in that mode.
    # torch.no_grad() ensures no gradients are computed.
    val_loss = 0.0

    with torch.no_grad():
        for images, targets in tqdm(test_dataloader):

            # Forward pass
            loss_dict = model(images, targets)

            # Compute total loss
            losses = sum(loss for loss in loss_dict.values())

            # Accumulate validation loss
            val_loss += losses.item()

    # Compute average loss for the epoch
    avg_train_loss = running_loss / len(train_dataloader)
    print(f"Training Loss: {avg_train_loss}")

    # Compute average validation loss for the epoch
    avg_val_loss = val_loss / len(test_dataloader)
    print(f"Validation Loss: {avg_val_loss}")

    # Save the model if the validation loss improved
    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        print(f"Improved validation loss at epoch {epoch+1}. Saving model...")
        torch.save(model.state_dict(), os.path.join(save_dir, f"model_best_epoch_{epoch+1}.pth"))

    # empty cuda cache
    torch.cuda.empty_cache()

print('Finished Training')


# Inference
The code below loads the most recently saved model and runs it on an image from the test set.

In [None]:
# Get the list of directories in saved_models
dirs = os.listdir("saved_models")

# Sort the directories by creation time
dirs = sorted(dirs, key=lambda x: os.path.getctime(os.path.join("saved_models", x)))

# Get the path of the most recent directory
latest_dir = os.path.join("saved_models", dirs[-1])

# Get the path of the last saved model inside the most recent directory.
# Pick it by modification time: a lexicographic sort of the file names would
# rank "model_best_epoch_9.pth" after "model_best_epoch_10.pth".
checkpoint_paths = [
    os.path.join(latest_dir, name)
    for name in os.listdir(latest_dir)
    if name.endswith(".pth")
]
model_path = max(checkpoint_paths, key=os.path.getmtime)

# load model
model = torchvision.models.detection.keypointrcnn_resnet50_fpn(pretrained=False,
                                                                   pretrained_backbone=True,
                                                                   num_keypoints=2,
                                                                   num_classes = 2,
                                                                   trainable_backbone_layers=3)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval().to(device)

# load image
img_path = 'dataset/test/images/IMG_4913_JPG_jpg.rf.4f67c223e9cbf0ed07236bfe142aaaee.jpg'
image = cv2.imread(img_path)
# cv2.imread does not raise on failure, it returns None
if image is None:
    raise FileNotFoundError(f"Could not read image: {img_path}")

# convert to tensor
image_tensor: torch.Tensor = K.image_to_tensor(image).to(device)

# bgr to rgb
image_tensor = K.color.bgr_to_rgb(image_tensor)

# normalize
image_tensor = K.enhance.normalize(image_tensor, torch.tensor(0.), torch.tensor(255.)).to(device)

# add batch dimension
image_tensor = image_tensor.unsqueeze(0)

# inference

with torch.no_grad():
    predictions = model(image_tensor)
    print(predictions)

In [None]:
# Work on an RGB copy of the image so that matplotlib shows the right colours
output_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Bring the predictions of the first (only) image to NumPy
keypoints = predictions[0]['keypoints'].detach().cpu().numpy()
boxes = predictions[0]['boxes'].detach().cpu().numpy()

# Draw every detection whose score is above 0.5
for score, box, box_keypoints in zip(predictions[0]['scores'], boxes, keypoints):
    if not score > 0.5:
        continue

    top_left = (int(box[0]), int(box[1]))
    bottom_right = (int(box[2]), int(box[3]))
    cv2.rectangle(output_image, top_left, bottom_right, (0, 255, 0), 2)

    for kpt in box_keypoints:
        center = (int(kpt[0]), int(kpt[1]))
        cv2.circle(output_image, center, 2, (0, 0, 255), 2)

# Show the annotated image
plt.imshow(output_image)
plt.show()