In [None]:
# Dataset locations, relative to the notebook's working directory.
TRAIN_IMG_DIR, TRAIN_LABEL_DIR = "data/train/images", "data/train/labels"
VAL_IMG_DIR, VAL_LABEL_DIR = "data/val/images", "data/val/labels"

# Expected number of keypoints per label file.
NUM_KEYPOINTS = 11

In [None]:
# Check whether all the images have the same resolution
import os
from PIL import Image
def check_images_resolution(folder_path, extensions=('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
    """Check whether every image in ``folder_path`` has the same resolution.

    File extensions are matched case-insensitively, so ``.JPG``/``.PNG`` files
    are included rather than silently skipped.

    Args:
        folder_path: Directory to scan (non-recursive).
        extensions: File extensions treated as images.

    Returns:
        A tuple ``(all_same, resolution)``. ``resolution`` is the PIL
        ``(width, height)`` size shared by all images when ``all_same`` is True,
        otherwise the most common size found. Returns ``(False, None)`` when
        no images were found.
    """
    from collections import Counter

    wanted = tuple(ext.lower() for ext in extensions)
    size_counts = Counter()
    # sorted() makes tie-breaking in most_common() independent of listdir order.
    for filename in sorted(os.listdir(folder_path)):
        if filename.lower().endswith(wanted):
            with Image.open(os.path.join(folder_path, filename)) as img:
                size_counts[img.size] += 1
    if not size_counts:
        return False, None
    return len(size_counts) == 1, size_counts.most_common(1)[0][0]

# Use the configured directories; the validation split goes through the same
# resize pipeline, so check it too.
{
    "train": check_images_resolution(TRAIN_IMG_DIR),
    "val": check_images_resolution(VAL_IMG_DIR),
}

In [None]:
# Check whether all the label files have the same number of keypoint lines
def check_labels_keypoints(folder_path):
    """Check whether every ``.txt`` label file in ``folder_path`` has the same line count.

    Each line of a label file is counted as one keypoint. The ``.txt`` suffix is
    matched case-insensitively, so ``.TXT`` files are not silently skipped.

    Args:
        folder_path: Directory containing the label files (non-recursive).

    Returns:
        A tuple ``(all_same, count)``. ``count`` is the shared line count when
        ``all_same`` is True, otherwise the most common count. Returns
        ``(False, None)`` when no label files were found.
    """
    from collections import Counter

    line_counts = Counter()
    # sorted() makes tie-breaking in most_common() independent of listdir order.
    for filename in sorted(os.listdir(folder_path)):
        if filename.lower().endswith('.txt'):
            with open(os.path.join(folder_path, filename), 'r') as f:
                line_counts[len(f.readlines())] += 1
    if not line_counts:
        return False, None
    return len(line_counts) == 1, line_counts.most_common(1)[0][0]

# Use the configured directories and check both splits.
{
    "train": check_labels_keypoints(TRAIN_LABEL_DIR),
    "val": check_labels_keypoints(VAL_LABEL_DIR),
}

In [None]:
# Remove the labels with the wrong number of keypoints
def remove_wrong_labels(folder_path, expected_keypoints, dry_run=False):
    """Delete ``.txt`` label files whose line count differs from ``expected_keypoints``.

    Each line of a label file is counted as one keypoint. The ``.txt`` suffix is
    matched case-insensitively. Each file is closed before it is deleted,
    because removing a file that is still open raises an error on Windows.

    Args:
        folder_path: Directory containing the label files (non-recursive).
        expected_keypoints: Required number of lines per label file.
        dry_run: If True, only report which files would be removed.

    Returns:
        List of label paths that were removed (or would be, when ``dry_run``).
    """
    removed = []
    for filename in sorted(os.listdir(folder_path)):
        if not filename.lower().endswith('.txt'):
            continue
        label_path = os.path.join(folder_path, filename)
        with open(label_path, 'r') as f:
            num_lines = len(f.readlines())
        # The `with` block has exited here, so the handle is closed before deletion.
        if num_lines == expected_keypoints:
            continue
        if dry_run:
            print(f"Would remove {label_path} with {num_lines} keypoints")
        else:
            os.remove(label_path)
            print(f"Removed {label_path} with {num_lines} keypoints")
        removed.append(label_path)
    return removed

# Delete training label files whose keypoint count differs from NUM_KEYPOINTS.
remove_wrong_labels(TRAIN_LABEL_DIR, expected_keypoints=NUM_KEYPOINTS)

In [None]:
# remove from data/train/images those images that don't have a corresponding .txt file in data/train/labels
import os

def remove_unlabelled_images(image_dir, label_dir):
    """Delete images in ``image_dir`` that have no same-stem ``.txt`` file in ``label_dir``.

    Only ``.jpg``, ``.jpeg`` and ``.png`` files are considered, and extensions
    are compared case-insensitively. An image ``foo.png`` is kept if and only if
    ``label_dir`` contains a ``foo.txt`` label.
    """
    candidates = [
        name for name in os.listdir(image_dir)
        if name.lower().endswith(('.jpg', '.jpeg', '.png'))
    ]
    labelled_stems = {
        os.path.splitext(name)[0]
        for name in os.listdir(label_dir)
        if name.lower().endswith('.txt')
    }

    for name in candidates:
        stem, _ext = os.path.splitext(name)
        if stem in labelled_stems:
            continue
        image_path = os.path.join(image_dir, name)
        print(f"Removing unlabelled image: {image_path}")
        os.remove(image_path)

# Drop training images that no longer have a label file (e.g. after the label cleanup above).
remove_unlabelled_images(TRAIN_IMG_DIR, TRAIN_LABEL_DIR)

# Training pipeline for the ResNet

In [None]:
from resnet import ResNet18, KeypointDataset, train_model
from torch.utils.data import DataLoader
from torchvision import transforms

In [None]:
# Data loading
BATCH_SIZE, NUM_WORKERS, PIN_MEMORY = 16, 4, True

# Training schedule (forwarded to train_model)
NUM_EPOCHS = 100
SAVE_INTERVAL = 5
# NOTE(review): what patience=0 means depends on train_model (early stopping
# disabled vs. stop at the first non-improving epoch) — confirm.
PATIENCE = 0

# Output artifacts
TRAIN_LOSS_PATH, VAL_LOSS_PATH = "train_loss.npy", "val_loss.npy"
LOSS_PLOT_PATH = "loss_plot.png"

In [None]:
# Create the model.
# Seed the Python, NumPy and torch RNGs first. This makes any randomly
# initialised layers (presumably the keypoint head — confirm in resnet.py), the
# DataLoader shuffle order and torch-based augmentations such as ColorJitter
# reproducible across runs. Note: this does not force deterministic cuDNN kernels.
import random

import numpy as np
import torch

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

model = ResNet18(num_keypoints=NUM_KEYPOINTS, pretrained=True)

In [None]:
# Define data transformations.
#
# TODO: geometric augmentations (RandomHorizontalFlip, and RandomAffine with
# degrees=10, translate=(0.1, 0.1)) are disabled for now. torchvision applies
# them to the image only, so the keypoint labels would no longer match. Re-add
# them with Albumentations, which can transform the keypoints as well.
#
# NOTE(review): torchvision's Resize takes (height, width), so (1920, 1080)
# produces portrait images 1080 wide x 1920 tall. PIL reports sizes as
# (width, height), so if check_images_resolution shows (1920, 1080), this
# should probably be (1080, 1920). Confirm.

# Standard ImageNet channel mean/std, shared by both pipelines.
imagenet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

train_transform = transforms.Compose(
    [
        transforms.Resize((1920, 1080)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        imagenet_normalize,
    ]
)

# No augmentation at validation time, so the metric reflects true performance.
val_transform = transforms.Compose(
    [
        transforms.Resize((1920, 1080)),
        transforms.ToTensor(),
        imagenet_normalize,
    ]
)

In [None]:
# Create datasets and dataloaders: training batches are shuffled, validation batches are not.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

train_dataset = KeypointDataset(TRAIN_IMG_DIR, TRAIN_LABEL_DIR, transform=train_transform)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)

val_dataset = KeypointDataset(VAL_IMG_DIR, VAL_LABEL_DIR, transform=val_transform)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

In [None]:
# Train the model. train_model returns the training and validation loss
# histories (presumably one entry per epoch — confirm in resnet.py).
train_loss, val_loss = train_model(
    model,
    train_loader,
    val_loader,
    num_epochs=NUM_EPOCHS,
    save_interval=SAVE_INTERVAL,
    patience=PATIENCE,
)

In [None]:
# Save the training and validation loss histories (the only metrics train_model returns here).

import numpy as np

for _out_path, _history in ((TRAIN_LOSS_PATH, train_loss), (VAL_LOSS_PATH, val_loss)):
    np.save(_out_path, np.array(_history))

# To reload them later:
# train_loss = np.load(TRAIN_LOSS_PATH)
# val_loss = np.load(VAL_LOSS_PATH)

In [None]:
# Plot the results

import matplotlib.pyplot as plt

# Plot the training and validation loss curves and save the figure to disk.
fig, ax = plt.subplots()
ax.plot(train_loss, label="Train loss")
ax.plot(val_loss, label="Validation loss")
ax.set(title="Training and validation losses over epochs", xlabel="Epoch", ylabel="Loss")
ax.legend()
fig.savefig(LOSS_PLOT_PATH)
plt.show()