# Training a Model to Detect Court Keypoints

## Download Dataset

In [None]:
# Google Drive share link for the dataset used in this notebook. It presumably
# provides the `data/images` directory and the `data/data_train.json` /
# `data/data_val.json` annotation files loaded below — confirm before use.
data_link = (
    "https://drive.google.com/file/d/1lhAaeQCmk2y440PmagA0KmIVBIysVMwu/view"
)

## Import libraries

In [3]:
import torch
from torchvision import models,transforms
from torch.utils.data import Dataset,DataLoader

import numpy as np
import cv2
import json

In [7]:
# Device-agnostic setup: use the CUDA device when one is available, otherwise
# fall back to the CPU so the rest of the notebook runs unchanged on either.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device  # echo the selected device as the cell output

device(type='cuda')

## Pre-process Data

### Create a Torch Dataset

In [5]:
class KeypointsDataset(Dataset):
    """Image / keypoint pairs for court-keypoint regression.

    ``data_file`` is a JSON file whose top-level value is indexed by position
    (presumably a list of records — confirm against the dataset). Each record
    must provide:

    * ``'id'``  — the image is read from ``{img_dir}/{id}.png``;
    * ``'kps'`` — a nested sequence of ``(x, y)`` pixel coordinates in the
      original image.

    Images are resized to ``img_size`` x ``img_size``, converted to a tensor
    and normalised with the standard ImageNet mean/std. Keypoints are rescaled
    into the same resized coordinate frame and returned as a flat float32
    array ``[x0, y0, x1, y1, ...]``.

    Args:
        img_dir: Directory containing the ``<id>.png`` images.
        data_file: Path to the JSON annotation file.
        img_size: Side length of the square network input (default 224).
    """

    def __init__(self, img_dir, data_file, img_size=224):
        self.img_dir = img_dir
        self.img_size = img_size

        with open(data_file, "r", encoding="utf-8") as f:
            self.data = json.load(f)

        self.transforms = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            # Standard ImageNet channel statistics.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image_tensor, flat_keypoints)`` for record ``idx``.

        Raises:
            FileNotFoundError: if the record's image cannot be read.
        """
        item = self.data[idx]
        img_path = f"{self.img_dir}/{item['id']}.png"
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning None
            # instead of raising; fail loudly with the offending path.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        h, w = img.shape[:2]

        # OpenCV loads images in BGR channel order; convert to RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Resize, convert to tensor and normalise.
        img = self.transforms(img)

        # The image was resized, so the keypoints must be rescaled to match.
        kps = self._scale_keypoints(item['kps'], w, h, self.img_size)

        return img, kps

    @staticmethod
    def _scale_keypoints(kps, w, h, size):
        """Flatten (x, y) keypoints and rescale them from a w x h image to size x size."""
        kps = np.asarray(kps, dtype=np.float32).flatten()
        kps[::2] *= size / w   # x coordinates scale with the width
        kps[1::2] *= size / h  # y coordinates scale with the height
        return kps


In [None]:
# Build the training and validation datasets. Both read images from the same
# directory and differ only in their annotation file.
train_dataset, val_dataset = (
    KeypointsDataset("data/images", "data/data_train.json"),
    KeypointsDataset("data/images", "data/data_val.json"),
)


### Create DataLoaders

In [None]:
# Shuffle only the training data. Validation batches are kept in a fixed order
# so their composition (including the possibly smaller last batch) and hence
# the averaged validation loss are identical from epoch to epoch.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

## Model Creation and Training

### Get a Pretrained Model

In [None]:
# Start from a pretrained ResNet-50 and swap its classification head for a
# linear regression head that predicts every keypoint coordinate directly.
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour
# of `weights=`; consider migrating once the pinned torchvision allows it.
model = models.resnet50(pretrained=True)
model.fc = torch.nn.Linear(
    in_features=model.fc.in_features,
    out_features=14 * 2,  # 14 keypoints, each an (x, y) pair
)
model = model.to(device)

### Set Up the Training Loop

In [None]:
# Keypoint coordinates are regressed directly, so train with mean-squared
# error, optimised by Adam over all model parameters.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
loss_fn = torch.nn.MSELoss()

In [None]:
from tqdm import tqdm  # the callable progress-bar wrapper, not the `tqdm` module

epochs = 20
for epoch in tqdm(range(epochs)):
    train_loss, valid_loss = 0.0, 0.0

    ## Training
    # Switch to train mode once per epoch (it only needs undoing after eval()).
    model.train()
    for imgs, kps in train_loader:
        # Tensor.to() returns a new tensor rather than moving in place; without
        # reassignment the batch would stay on the CPU while the model is on
        # `device`.
        imgs, kps = imgs.to(device), kps.to(device)

        preds = model(imgs)
        loss = loss_fn(preds, kps)
        train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    train_loss /= len(train_loader)  # average loss per batch for this epoch

    ## Validation
    model.eval()
    with torch.inference_mode():
        for imgs, kps in val_loader:
            imgs, kps = imgs.to(device), kps.to(device)
            preds = model(imgs)
            valid_loss += loss_fn(preds, kps).item()
    valid_loss /= len(val_loader)

    if epoch % 2 == 0:
        print(f"Epoch: {epoch} | Train Loss: {train_loss:.2f} | Validation Loss: {valid_loss:.2f}")
    

### Save Model

In [None]:
# Persist only the learned parameters. To reload them, build the same
# architecture and call model.load_state_dict(torch.load("keypoints_model.pth")).
torch.save(model.state_dict(), "keypoints_model.pth")