# Example of train with a KeypointRCNN_ResNet50_FPN pre-trained model

In the code below, we are going to use the KeypointRCNN_ResNet50_FPN_Weights model pre-trained on the COCO dataset to train on our own dataset.

Import libraries

In [1]:
import json
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import kornia as K
import matplotlib.pyplot as plt
import torchvision
from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights
from tqdm import tqdm
import torch.optim as optim
from tqdm import tqdm
import datetime
import os
import torchmetrics


In [5]:
class CustomDataset(Dataset):
    def __init__(self, images_path, annotations_path, use_augmentation, device):
        self.images_path = images_path
        self.annotations_path = annotations_path
        self.device = device
        self.image_filenames = [filename for filename in os.listdir(images_path) if filename.endswith('.jpg')]

        if use_augmentation:
            # Declare an augmentation pipeline
            self.transform =K.augmentation.AugmentationSequential(
                K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]).to(device), torch.tensor([0.229, 0.224, 0.225]).to(device)),
                K.augmentation.RandomVerticalFlip(), 
                K.augmentation.RandomHorizontalFlip(),
                K.augmentation.RandomRotation(30),
                data_keys=["input", "bbox", "keypoints"])
        else:
            self.transform = K.augmentation.AugmentationSequential(
                K.augmentation.RandomRotation(0),
                data_keys=["input", "bbox", "keypoints"])

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):

        image_filename = self.image_filenames[idx]
        image_tensor = self.load_image(os.path.join(self.images_path, image_filename))
        annotation_filename = os.path.join(self.annotations_path, image_filename[:-4] + '.json')

        with open(annotation_filename, 'r') as f:
            bbox_tensor, keypoint_tensor = self.convert_to_kornia_format(json.load(f))
    
        out_tensor = self.transform(image_tensor.float(), bbox_tensor.float(), keypoint_tensor.float())
        
        # plot augmentation to test it
        if False:
            img_out = self.plot_resulting_image(
            out_tensor[0][0],
            out_tensor[1].int(),
            out_tensor[2].int(),
            )
            # plot the image
            plt.imshow(K.tensor_to_image(image_tensor.mul(255).byte()).copy())
            plt.show()
            plt.imshow(img_out)
            plt.show()

        # get the torch format from kornia format
        target = self.kornia_to_torch_format( out_tensor[1], out_tensor[2])

        # return tensors
        return image_tensor, target

    def load_image(self, image_path: str)-> torch.Tensor:
        """
        Method to load image
        params:
            image_path: str = path of the image
        return:
            tensor: torch.Tensor = image tensor
        """

        # load image and convert to tensor
        image: np.ndarray = cv2.imread(image_path)

        # convert to tensor
        image_tensor: torch.Tensor = K.image_to_tensor(image)

        # bgr to rgb
        image_tensor = K.color.bgr_to_rgb(image_tensor)

        return K.enhance.normalize(image_tensor, torch.tensor(0.), torch.tensor(255.)).to(self.device)
    
    def convert_to_kornia_format(self, data):
        """
        Method to convert the bounding boxes and keypoints to the Kornia format
        params:
            data: dict = dictionary containing the bounding boxes and keypoints
        return:
            bbox_tensor: torch.Tensor = tensor containing the bounding boxes
            keypoint_tensor: torch.Tensor = tensor containing the keypoints
        """

        # Extract the bounding boxes and keypoints from the dictionary
        bboxes = data['bboxes']
        keypoints = data['keypoints']

        # Convert the bounding boxes to the Kornia format
        bbox_list = []
        for bbox in bboxes:
            x1, y1, x2, y2 = bbox
            bbox_list.append([[x1, y1], [x2, y1], [x2, y2], [x1, y2]])
        bbox_tensor = torch.tensor(bbox_list).unsqueeze(0).to(self.device)

        # Convert the keypoints to the Kornia format
        keypoint_list = []
        for kpts in keypoints:
            for kpt in kpts:
                x, y, _ = kpt
                keypoint_list.append([x, y])
        keypoint_tensor = torch.tensor(keypoint_list).unsqueeze(0).to(self.device)
        return bbox_tensor, keypoint_tensor
    
    def plot_resulting_image(self, img, bbox, keypoints):
        """
        Plot the resulting image with bounding boxes and keypoints.
        params:
            img: torch.Tensor = image tensor
            bbox: torch.Tensor = bounding box tensor
            keypoints: torch.Tensor = keypoints tensor
        return:
            img_draw = image with bounding boxes and keypoints
        """
        img_array = K.tensor_to_image(img.mul(255).byte()).copy()
        img_draw = cv2.polylines(img_array, bbox.reshape(-1, 4, 2).cpu().numpy(), isClosed=True, color=(255, 0, 0))
        for k in keypoints[0]:
            img_draw = cv2.circle(img_draw, tuple(k.cpu().numpy()[:2]), radius=6, color=(255, 0, 0), thickness=-1)
        return img_draw

    def kornia_to_torch_format(self, bbox_tensor, keypoint_tensor, labels=None):
        """
        Convert bbox_tensor and keypoint_tensor in Kornia format to torch's expected format.
        
        Parameters:
        - bbox_tensor (torch.Tensor): Bounding box tensor in Kornia format
        - keypoint_tensor (torch.Tensor): Keypoint tensor in Kornia format
        - labels (list[int]): List of class labels for each bounding box. If None, default to label=1 for all boxes.
        
        Returns:
        - dict: A dictionary with 'boxes', 'labels', and 'keypoints' in the format expected by torch.
        """
        # Convert bbox_tensor from Kornia's format to torch's [x1, y1, x2, y2] format
        boxes = torch.stack([bbox_tensor[0,:,0,0], bbox_tensor[0,:,0,1], bbox_tensor[0,:,2,0], bbox_tensor[0,:,2,1]], dim=1)
        
        # If labels aren't provided, assume a default label of 1 for all bounding boxes
        if labels is None:
            labels = torch.ones((bbox_tensor.shape[1],), dtype=torch.int64).to(self.device)
        else:
            labels = torch.tensor(labels, dtype=torch.int64).to(self.device)
        
        # Convert keypoint_tensor to the desired [x, y, visibility] format
        keypoints = torch.zeros((bbox_tensor.shape[1], keypoint_tensor.shape[1]//bbox_tensor.shape[1], 3)).to(self.device)
        for i in range(bbox_tensor.shape[1]):
            keypoints[i, :, :2] = keypoint_tensor[0, i*2:(i+1)*2, :]
            keypoints[i, :, 2] = 1  # setting visibility to 1
        
        return {"boxes": boxes, "labels": labels, "keypoints": keypoints}

def collate_fn(batch):
    """
    Collate function to format the batch in the desired manner.
    
    Parameters:
    - batch (list): List of tuples where each tuple contains an image tensor and its associated target.
    
    Returns:
    - tuple: Tuple containing a tensor of images and a list of targets.
    """
    # Separate images and targets in the batch
    images, targets = zip(*batch)
    
    # # Stack images into a single tensor
    # images = torch.stack(images, 0)
    
    return images, targets


# if cuda is avaliable, use it
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

train_images_path = 'dataset/train/images'
train_annotations_path = 'dataset/train/annotations'
test_images_path = 'dataset/test/images'
test_annotations_path = 'dataset/test/annotations'

train_dataset = CustomDataset(train_images_path, train_annotations_path,use_augmentation = True, device = device)
test_dataset = CustomDataset(test_images_path, test_annotations_path,  use_augmentation = False, device = device)

train_dataloader = DataLoader(train_dataset, batch_size = 8, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size = 8, shuffle=False, collate_fn=collate_fn)

# Get a batch of images and annotations from the train dataloader
images, targets = next(iter(test_dataloader))

print(len(images))
print(len(targets))

cuda
8
8


# TRAIN

In [6]:
model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=KeypointRCNN_ResNet50_FPN_Weights.DEFAULT)
model.to(device).train()
print(model)

KeypointRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(640, 672, 704, 736, 768, 800), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.

In [7]:
# Define some hyperparameters
num_epochs = 100
lr = 0.001

# Create a directory with the current timestamp
now = datetime.datetime.now()
timestamp = now.strftime('%Y-%m-%d_%H-%M-%S')
save_dir = os.path.join("saved_models", timestamp)
os.makedirs(save_dir, exist_ok=True)

# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)

# Define the loss function
criterion = torch.nn.CrossEntropyLoss()

best_loss = float('inf')  # Initialize with a high value

# Training loop
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    
    model.train()  # Set model to training mode
    
    running_loss = 0.0
    for images, targets in tqdm(train_dataloader):
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward pass
        loss_dict = model(images, targets)
        
        # Compute total loss
        losses = sum(loss for loss in loss_dict.values())
        
        # Backward pass and optimize
        losses.backward()
        optimizer.step()
        
        # Print statistics
        running_loss += losses.item()
    
    # Compute average loss for the epoch
    avg_train_loss = running_loss / len(train_dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}] Training Loss: {avg_train_loss}")

    # Validation phase
    val_loss = 0.0
    
    with torch.no_grad():
        for images, targets in tqdm(test_dataloader):

            # Forward pass
            loss_dict = model(images, targets)
            
            # Compute total loss
            losses = sum(loss for loss in loss_dict.values())

            # Accumulate validation loss
            val_loss += losses.item()

    # Compute average validation loss for the epoch
    avg_val_loss = val_loss / len(test_dataloader)
    print(f"Validation Loss: {avg_val_loss}")

    # Save the model if the validation loss improved
    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        print(f"Improved validation loss at epoch {epoch+1}. Saving model...")
        torch.save(model.state_dict(), os.path.join(save_dir, f"model_best_epoch_{epoch+1}.pth"))
    
    # empty cuda cache
    torch.cuda.empty_cache()

print('Finished Training')


Epoch 1/100


100%|██████████| 14/14 [00:12<00:00,  1.11it/s]


Epoch [1/100] Training Loss: 10.680099964141846


100%|██████████| 3/3 [00:01<00:00,  1.98it/s]


Epoch [1/100] Validation Loss: 8.159403165181478
Improved validation loss at epoch 1. Saving model...
Epoch 2/100


100%|██████████| 14/14 [00:11<00:00,  1.22it/s]


Epoch [2/100] Training Loss: 8.259827273232597


100%|██████████| 3/3 [00:01<00:00,  2.01it/s]


Epoch [2/100] Validation Loss: 8.162141799926758
Epoch 3/100


100%|██████████| 14/14 [00:10<00:00,  1.34it/s]


Epoch [3/100] Training Loss: 8.227414608001709


100%|██████████| 3/3 [00:01<00:00,  1.93it/s]


Epoch [3/100] Validation Loss: 8.136830965677897
Improved validation loss at epoch 3. Saving model...
Epoch 4/100


100%|██████████| 14/14 [00:11<00:00,  1.21it/s]


Epoch [4/100] Training Loss: 8.241299492972237


100%|██████████| 3/3 [00:01<00:00,  1.92it/s]


Epoch [4/100] Validation Loss: 8.130262692769369
Improved validation loss at epoch 4. Saving model...
Epoch 5/100


100%|██████████| 14/14 [00:11<00:00,  1.24it/s]


Epoch [5/100] Training Loss: 8.246586799621582


100%|██████████| 3/3 [00:01<00:00,  1.93it/s]


Epoch [5/100] Validation Loss: 8.201896667480469
Epoch 6/100


100%|██████████| 14/14 [00:10<00:00,  1.30it/s]


Epoch [6/100] Training Loss: 8.21510832650321


100%|██████████| 3/3 [00:01<00:00,  1.92it/s]


Epoch [6/100] Validation Loss: 8.176630020141602
Epoch 7/100


100%|██████████| 14/14 [00:11<00:00,  1.21it/s]


Epoch [7/100] Training Loss: 8.337576729910714


100%|██████████| 3/3 [00:01<00:00,  1.91it/s]


Epoch [7/100] Validation Loss: 8.224010467529297
Epoch 8/100


100%|██████████| 14/14 [00:10<00:00,  1.29it/s]


Epoch [8/100] Training Loss: 8.276910645621163


100%|██████████| 3/3 [00:01<00:00,  1.90it/s]


Epoch [8/100] Validation Loss: 8.12479559580485
Improved validation loss at epoch 8. Saving model...
Epoch 9/100


100%|██████████| 14/14 [00:11<00:00,  1.23it/s]


Epoch [9/100] Training Loss: 8.269008840833392


100%|██████████| 3/3 [00:01<00:00,  1.94it/s]


Epoch [9/100] Validation Loss: 8.212945938110352
Epoch 10/100


 86%|████████▌ | 12/14 [00:09<00:01,  1.29it/s]


KeyboardInterrupt: 