In [22]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms

import json
import cv2
import numpy as np

In [39]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

In [3]:
#Create Torch Dataset

In [58]:
import json
import cv2
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms

import os
import json
import cv2
import numpy as np
from torch.utils.data import Dataset
import torchvision.transforms as transforms

class keypointsDataset(Dataset):
    """Dataset of images paired with flattened, rescaled keypoint coordinates.

    Every annotation record is read for two keys: ``id`` (the stem of the
    ``<id>.png`` image file inside ``img_dir``) and ``kps`` (keypoint
    coordinates in the pixel space of the original image). Images are resized
    to ``img_size`` and normalised; keypoints are multiplied by the same
    resize factors so they stay aligned with the resized image.

    Args:
        img_dir: Directory that holds the ``<id>.png`` images.
        data_file: Path to the JSON annotation file (assumed to decode to a
            list of records, since items are fetched by integer index).
        img_size: Output size, either a single int for a square image or a
            ``(height, width)`` pair. Defaults to 224, the previous
            hard-coded value.
    """

    def __init__(self, img_dir, data_file, img_size=224):
        self.img_dir = img_dir
        with open(data_file, "r") as f:
            self.data = json.load(f)

        # Keep the target size in one place so the image resize and the
        # keypoint rescale can never drift apart.
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        self.out_h, self.out_w = img_size

        self.transforms = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((self.out_h, self.out_w)),
            transforms.ToTensor(),
            # Channel statistics conventionally used with ImageNet-pretrained
            # torchvision backbones.
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.data)

    @staticmethod
    def _scale_keypoints(kps, width, height, out_w, out_h):
        """Flatten ``kps`` to float32 and rescale to the resized image.

        Even positions of the flattened array are treated as x coordinates
        (scaled by ``out_w / width``) and odd positions as y coordinates
        (scaled by ``out_h / height``).
        """
        flat = np.array(kps).flatten().astype(np.float32)
        flat[::2] *= float(out_w) / width    # x coordinates
        flat[1::2] *= float(out_h) / height  # y coordinates
        return flat

    def __getitem__(self, idx):
        """Return ``(image_tensor, keypoints)`` for record ``idx``.

        Returns ``(None, None)`` when the image file is missing or cannot be
        decoded, so that a collate function can drop the sample instead of
        aborting the whole run.
        """
        item = self.data[idx]
        img_path = os.path.join(self.img_dir, f"{item['id']}.png")

        # Check if the file exists
        if not os.path.exists(img_path):
            print(f"Warning: Image {item['id']}.png not found in {self.img_dir}, skipping...")
            return None, None  # Return None for both image and keypoints

        # cv2.imread signals a decode failure by returning None, not raising.
        img = cv2.imread(img_path)
        if img is None:
            print(f"Warning: Failed to load image {item['id']}.png, skipping...")
            return None, None

        # Capture the original size before the transform resizes the image.
        h, w = img.shape[:2]
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads as BGR
        img = self.transforms(img)

        kps = self._scale_keypoints(item['kps'], w, h, self.out_w, self.out_h)

        return img, kps



In [60]:
# Dataset location, kept in one place so it is easy to change.
# NOTE(review): this is a machine-specific absolute path; adjust it when
# running the notebook elsewhere.
DATA_ROOT = "D:/projects done/SpinSense-/tennis_court_det_dataset/data"
IMAGES_DIR = f"{DATA_ROOT}/images"

train_dataset = keypointsDataset(IMAGES_DIR, f"{DATA_ROOT}/data_train.json")
val_dataset = keypointsDataset(IMAGES_DIR, f"{DATA_ROOT}/data_val.json")

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
# Validation must read the held-out split (this previously wrapped
# train_dataset by mistake) and needs no shuffling.
# NOTE: the dataset returns (None, None) for unreadable images, which the
# default collate function cannot batch; pass collate_fn=custom_collate once
# that function is defined if the split may contain missing files.
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)



In [61]:
# Build the network: a pretrained ResNet-50 whose classification head is
# swapped for a regression head with one (x, y) output pair per keypoint.
NUM_KEYPOINTS = 14
model = models.resnet50(pretrained=True)
head_in_features = model.fc.in_features
model.fc = torch.nn.Linear(head_in_features, NUM_KEYPOINTS * 2)  # replace the last layer

In [62]:
# Move the network's parameters and buffers onto the selected device.
model = model.to(device=device)

In [53]:
# Training setup: mean-squared-error loss on the regressed coordinates,
# optimised with Adam over every model parameter.
LEARNING_RATE = 1e-4

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)



In [67]:
def custom_collate(batch):
    """Collate ``(image, keypoints)`` samples, dropping unreadable ones.

    Samples where either element is ``None`` (the dataset's marker for a
    missing or undecodable image) are discarded before batching.

    Args:
        batch: Iterable of ``(image_tensor, keypoints_array)`` pairs, where
            either element may be ``None``.

    Returns:
        ``(images, keypoints)`` with the samples stacked along a new leading
        dimension, or ``(None, None)`` when no valid sample remains.
    """
    valid = [(img, kp) for img, kp in batch if img is not None and kp is not None]
    if not valid:
        return None, None
    imgs, kps = zip(*valid)
    # Stack into one contiguous ndarray first: handing torch.tensor() a
    # sequence of separate ndarrays takes a slow element-by-element path.
    kps_array = np.stack(kps)
    return torch.stack(imgs), torch.from_numpy(kps_array)


In [70]:
# Rebuild the training loader so unreadable samples are filtered out by
# custom_collate. shuffle=True is kept: dropping it would feed the model the
# same sample order every epoch.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=custom_collate)


In [72]:
import logging

# Configure logging BEFORE the first logging call. Calling logging.warning()
# on an unconfigured root logger installs a default handler, after which
# basicConfig() is a no-op and nothing reaches training.log.
logging.basicConfig(filename="training.log", level=logging.WARNING, format="%(asctime)s - %(message)s")

epochs = 10
model.train()  # make sure layers such as batch norm are in training mode
for epoch in range(epochs):
    for i, (imgs, kps) in enumerate(train_loader):
        # custom_collate yields (None, None) when every sample of the batch
        # was unreadable; there is nothing to train on in that case.
        if imgs is None or kps is None:
            logging.warning(f"Epoch {epoch}, iter {i}: No valid data in batch, skipping...")
            continue

        try:
            # The loader already delivers batched tensors, so they only need
            # to be moved to the device. Re-wrapping them with torch.tensor()
            # raises, and that error used to be swallowed on every batch.
            imgs = imgs.to(device)
            kps = kps.to(device=device, dtype=torch.float32)

            # Forward and backward pass
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, kps)
            loss.backward()
            optimizer.step()

            # Report the loss every 10 iterations
            if i % 10 == 0:
                print(f"Epoch {epoch}, iter {i}, loss {loss.item()}")

        except Exception:
            # Best effort: skip the failing batch, but record the full
            # traceback so systematic failures are visible in the log.
            logging.exception(f"Error during training at epoch {epoch}, iter {i}")
            continue





In [77]:
# Persist only the learned weights (the state_dict), not the module object.
CHECKPOINT_PATH = "keypoints_model.pth"
torch.save(model.state_dict(), CHECKPOINT_PATH)