# Imports

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms

import json
import cv2
import numpy as np

In [None]:
# Pick the compute device once: CUDA when a GPU is visible to PyTorch, CPU otherwise
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Dataset

In [None]:
# Retrieves zipped dataset from google drive
!wget --header="Host: drive.usercontent.google.com" --header="User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.0.0 Safari/537.36" --header="Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7" --header="Accept-Language: en-US,en;q=0.9,ar;q=0.8" --header="Cookie: HSID=Ag2OIHvsd2Wub4C7z; SSID=AWnBcQKwDHiTrZAU1; APISID=pltrFZgE9lJ0o1gq/AN9feEHYvs8oHd519; SAPISID=zgF45F21ZPWzYWZw/AgUMJ8b7QQXuWGn19; __Secure-1PAPISID=zgF45F21ZPWzYWZw/AgUMJ8b7QQXuWGn19; __Secure-3PAPISID=zgF45F21ZPWzYWZw/AgUMJ8b7QQXuWGn19; SID=g.a000fwgYx1PcnW-rFyFhg3x6mQHzCrwXz-KFhoOLogUl7YTWI-uttBbVDRolhF-hY16nwHXw0gACgYKAWISAQASFQHGX2MivNTw_E_toJuIRy6LMpKNOBoVAUF8yKpFSmvq7AMjvEWeNc50Zff40076; __Secure-1PSID=g.a000fwgYx1PcnW-rFyFhg3x6mQHzCrwXz-KFhoOLogUl7YTWI-utbSY2jBY1VXuw8gYl5hIO2QACgYKAXsSAQASFQHGX2MihVCJ1PwLozGqZgdSatM9QhoVAUF8yKpgrsTvI8i_UE-YHpoN7Gx-0076; __Secure-3PSID=g.a000fwgYx1PcnW-rFyFhg3x6mQHzCrwXz-KFhoOLogUl7YTWI-utwVfPl2imdPimZJ9tdDZGQAACgYKAUESAQASFQHGX2MiEJ49mV4jME2kttDAV5hwWBoVAUF8yKp80mIgju1lu-q4nI7VsFDM0076; NID=511=efI9IZpxtyJ7Dw1MAUXU8FlzS5jXGewY4Er8HliWc3A0RSWdgvNDyKY66ETjgRyTGWPbWODSmiSeYSBab5SPHVwqbJxd6ZeGW2f6BkHi61UKksXPH0CVJRM1hKpMjHPU5qw7tboM2Mi87NrosV8COB-GCLulLLbjOoSAEQewTe8NVZ5Owq8IkwvxFGfJkmUKEMkFWrw9yb5nTDl3wbZEsGFI92iEdNTSxSRovNCIPN2US-SCFdQ0m2BtvwdiWZbgnn7dSQ8yPA145Kk2BA-ATpJNJ6SJHEHLQY-9CPail9D5qgJgxR925EUg5RGCpEu9wS5xbA62KTa19wAvbAq7Dk3TWc-iX4p1s7ESFyDC7yMpFxiFPJjqkWwFi_ZfiK2TW2t0TQ60DFBxqOytQaLyHrkEvD-CQPVj6OCOP22cZY0Cu61HaAQgFO9pXH-kJUlywzVdbirJumN5gswyaQ49b3KdLcG0Jb7brOMTM24T2nGtQ10hJzsnTwX7dBk3ujqQrI_DGuURvPassPUrIZ0; AEC=Ae3NU9MOEGeKAZjP6INpOYbyMraWAWztmx5pJB_1ILu1furiTy1K37k15u0; __Secure-1PSIDTS=sidts-CjEBYfD7Z9twEKTWJ9gU7KG-rLbxJGNRQIoG3wH6JVu6yiCC2fsRrm7tN8L6d5WlILrnEAA; __Secure-3PSIDTS=sidts-CjEBYfD7Z9twEKTWJ9gU7KG-rLbxJGNRQIoG3wH6JVu6yiCC2fsRrm7tN8L6d5WlILrnEAA; 1P_JAR=2024-02-18-08; 
SIDCC=ABTWhQExCxkfmwCkG1RaEgz8U1ZkPeh3HmLMUdMt8S5cNSsLY5U5rAL6wlvq7dtjRw7zrtAbqsFI; __Secure-1PSIDCC=ABTWhQH0jLeRIS6Tu3LS8DXB5Q3gGDq9LTmlk60FKu795Bf0UbzsOcYWVAE96clq5aAL8i724Q0; __Secure-3PSIDCC=ABTWhQHIFcyv3nZYwp78WXEQal71jCE_ZsGT5lXs8VLr7XDIfFqHcLTIPz4HxzJb9ZnYQ5l2s9eU" --header="Connection: keep-alive" "https://drive.usercontent.google.com/download?id=1lhAaeQCmk2y440PmagA0KmIVBIysVMwu&export=download&authuser=0&confirm=t&uuid=3077628e-fc9b-4ef2-8cde-b291040afb30&at=APZUnTU9lSikCSe3NqbxV5MVad5T%3A1708243355040" -c -O 'tennis_court_det_dataset.zip'

In [None]:
!unzip tennis_court_det_dataset.zip

In [None]:
class PointsDataset(Dataset):
    """Map-style dataset pairing images with their flattened keypoint coordinates.

    The annotation file is loaded with ``json.load`` and indexed by integer, so it
    is assumed to contain a list of entries. Each entry must provide:

    * ``id``  - file stem of the image, read from ``{img_dir}/{id}.png``
    * ``kps`` - keypoints as (x, y) pairs in the original image's pixel
      coordinates (even flattened positions are scaled by the image width, odd
      positions by the image height).

    Args:
        img_dir: Directory containing the ``.png`` images.
        data_file: Path to the JSON annotation file.
    """

    # Side length (in pixels) of the square image the samples are resized to
    IMG_SIZE = 224

    def __init__(self, img_dir, data_file):
        self.img_dir = img_dir
        # Reads the annotation entries from the json file
        with open(data_file, "r") as f:
            self.data = json.load(f)

        # Transforms an RGB image array into a fixed-size, normalised tensor.
        # Compose expects ONE list of transforms, not separate positional arguments.
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((self.IMG_SIZE, self.IMG_SIZE)),
            transforms.ToTensor(),
            # Normalise pixel values for easier training
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        """Return the number of annotated samples (needed by samplers/DataLoader)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, points)`` for the sample at position ``idx``.

        Returns:
            image: The transformed image produced by ``self.transform``.
            points: 1-D ``float32`` numpy array ``[x0, y0, x1, y1, ...]`` with the
                keypoints rescaled to the ``IMG_SIZE`` x ``IMG_SIZE`` image.

        Raises:
            FileNotFoundError: If the image cannot be read from disk.
        """
        item = self.data[idx]
        img_path = f"{self.img_dir}/{item['id']}.png"
        img = cv2.imread(img_path)
        # cv2.imread signals failure by returning None instead of raising
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")

        # Remember the original size before resizing; it is needed to rescale the points
        height, width = img.shape[:2]
        # OpenCV loads BGR; the transform pipeline expects RGB
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.transform(img)

        # Array becomes 1D: [x0, y0, x1, y1, ...]
        points = np.array(item['kps'], dtype=np.float32).flatten()
        # Scale points to the resized image: x by width ratio, y by height ratio
        points[::2] *= self.IMG_SIZE / width
        points[1::2] *= self.IMG_SIZE / height

        return img, points

    # Backward-compatible alias for the original (mis-capitalised) method name
    __getItem__ = __getitem__

In [None]:
# Build the datasets from the separate training and validation annotation files
training_dataset = PointsDataset("data/images", "data/data_train.json")
validation_dataset = PointsDataset("data/images", "data/data_val.json")
# Shuffle only the training data; validation order is kept fixed so that
# evaluation runs are deterministic and comparable between epochs
training_loader = DataLoader(training_dataset, batch_size=32, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=32, shuffle=False)

# Model

In [None]:
# Start from a ResNet-50 with pre-trained weights
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=` - confirm the installed version before switching.
model = models.resnet50(pretrained=True)
# Swap the classification head for a regression head that outputs
# an (x, y) pair for each of the 14 keypoints
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=14 * 2)
# Move model to device (Module.to returns the same module, so re-binding is equivalent)
model = model.to(device=device)

In [None]:
# Define loss function and optimiser: mean squared error between predicted
# and target keypoint coordinates, minimised with Adam
criterion = torch.nn.MSELoss()
optimiser = torch.optim.Adam(model.parameters(), lr=0.001)
epochs = 20
for epoch in range(epochs):
    # Explicitly enable training mode (batch-norm statistics are updated). Without
    # this, re-running the cell after any `model.eval()` call would silently train
    # with frozen normalisation layers.
    model.train()
    for i, (imgs, points) in enumerate(training_loader):
        # Inputs and targets must live on the same device as the model
        imgs = imgs.to(device)
        points = points.to(device)
        # Zero gradients before each iteration (they accumulate otherwise)
        optimiser.zero_grad()
        # Model outputs the predicted points
        outputs = model(imgs)
        # Calculate loss
        loss = criterion(outputs, points)
        # Backpropagation followed by the parameter update
        loss.backward()
        optimiser.step()

        # Report the current batch loss every 10 iterations
        if (i % 10) == 0:
            print(f"Epoch {epoch}, Iteration {i}, Loss: {loss.item()}")

In [None]:
# Persist only the learned parameters (the state dict), not the whole module object
torch.save(obj=model.state_dict(), f="court_points_model.pth")