In [1]:
%%capture
# Mount Google Drive so the dataset files stored under /content/drive are accessible
from google.colab import drive
drive.mount('/content/drive')

# !pip install SoccerNet
!pip install lightning timm transformers torchmetrics

## Custom Dataset Model

In [None]:
from multiprocessing import process
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import pandas as pd
import os
from transformers import AutoImageProcessor
import pdb


class SoccerNetDataset(Dataset):
    """
    A dataset class for loading and preprocessing images from the SoccerNet dataset for object detection tasks.

    Every sub-directory of ``root`` is treated as one sequence and must contain
    ``gameinfo.ini`` (tracklet ID -> label), ``gt/gt.txt`` (one box per row) and
    ``img1/`` (frames named ``<frame number>.jpg``).

    Attributes:
        root (str): The root directory of the dataset (e.g., 'data/tracking/train').
        processor (callable, optional): A processor for preprocessing the images.
        data (list): One ``(img_data, annotations)`` tuple per frame, where
            ``img_data`` is ``{"id": int, "path": str}`` and ``annotations`` is a
            (possibly empty) list of COCO-style box dictionaries. Images are
            opened lazily in ``__getitem__``.
        labelsToId (dict): A dictionary mapping class labels to their respective IDs.
        id_to_label (dict): The inverse of ``labelsToId``.
    """

    # Misspelled labels that appear in some gameinfo.ini files -> canonical label.
    _LABEL_FIXES = {
        "goalkeepers_team_left": "goalkeeper_team_left",
        "goalkeepers_team_right": "goalkeeper_team_right",
    }

    def __init__(self, root, processor=None):
        """
        Initializes the SoccerNetDataset with the specified root directory and optional processor.

        Only file paths and annotations are collected here; no image is opened,
        so construction does not hold one open file handle per frame.

        Args:
            root (str): The root directory of the dataset.
            processor (callable, optional): A processor for preprocessing the images.
        """
        self.root = root
        self.processor = processor
        self.data = []
        self.labelsToId = {"player_team_left": 0, "player_team_right": 1, "ball": 2, "referee": 3, "goalkeeper_team_left": 4, "goalkeeper_team_right": 5, "other":6}
        self.id_to_label = {v: k for k, v in self.labelsToId.items()}
        # Sorted so that sample indices are reproducible across runs/filesystems.
        for folder in sorted(os.listdir(root)):
            seq_dir = os.path.join(root, folder)
            if os.path.isdir(seq_dir):
                self._load_sequence(folder, seq_dir)

    def _load_sequence(self, folder, seq_dir):
        """
        Collects the frames and annotations of one sequence directory into ``self.data``.

        Args:
            folder (str): The sequence directory name; the part after the first
                '-' is used as the prefix of every image ID (e.g. 'SNMOT-060').
            seq_dir (str): The full path of the sequence directory.

        Raises:
            KeyError: If a row references a tracklet ID missing from gameinfo.ini
                or a label missing from ``labelsToId``.
        """
        id_to_label_local = self._parse_labels(os.path.join(seq_dir, "gameinfo.ini"))
        img_folder = os.path.join(seq_dir, "img1")

        gt = pd.read_csv(os.path.join(seq_dir, "gt", "gt.txt"), header=None)
        # Only the first six columns are used; any trailing columns are ignored.
        gt = gt.iloc[:, :6].copy()
        # The "class" column holds the tracklet ID that gameinfo.ini maps to a label.
        gt.columns = ["frame", "class", "x", "y", "w", "h"]

        # Frame number -> list of box annotations. Columns are iterated directly
        # (instead of iterrows) so integer IDs are never upcast to floats.
        annotations = {}
        columns = (gt["frame"], gt["class"], gt["x"], gt["y"], gt["w"], gt["h"])
        for frame, track_id, x, y, w, h in zip(*columns):
            label = id_to_label_local[str(int(track_id))]
            x, y, w, h = float(x), float(y), float(w), float(h)
            annotations.setdefault(int(frame), []).append({
                "bbox": [x, y, w, h],
                "bbox_mode": 0,
                "category_id": self.labelsToId[label],
                "iscrowd": 0,
                "area": w * h,
            })

        sequence_id = folder.split('-')[1]
        for img_name in sorted(os.listdir(img_folder)):
            stem, ext = os.path.splitext(img_name)
            # Skip anything that is not a numbered frame (e.g. hidden/system files).
            if ext.lower() != ".jpg" or not stem.isdigit():
                continue
            img_data = {"id": int(sequence_id + stem),
                        "path": os.path.join(img_folder, img_name)}
            # Frames without any ground-truth row get an empty annotation list.
            self.data.append((img_data, annotations.get(int(stem), [])))

    def _parse_labels(self, filepath):
        """
        Parses the gameinfo.ini file to map class IDs to labels.

        Only lines starting with ``trackletID`` are read, e.g.
        ``trackletID_1= player team left;10`` -> ``{"1": "player_team_left"}``.

        Args:
            filepath (str): The path to the gameinfo.ini file.

        Returns:
            dict: A dictionary mapping class IDs (as strings) to labels.
        """
        labels = {}
        with open(filepath, "r") as file:
            for line in file:
                if not line.startswith("trackletID"):
                    continue
                key, sep, value = line.partition("=")
                class_id = key.strip().partition("_")[2]
                if not sep or not class_id:
                    continue  # malformed line: nothing usable to record
                label = value.split(";")[0].strip().replace(" ", "_")
                # Normalise the known label misspellings.
                labels[class_id] = self._LABEL_FIXES.get(label, label)
        print(labels)
        return labels

    def __len__(self):
        """
        Returns the number of samples in the dataset.

        Returns:
            int: The number of samples in the dataset.
        """
        return len(self.data)

    def __getitem__(self, idx):
        """
        Returns the image and corresponding annotations for the specified index.

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            tuple: ``(image, target)`` when no processor is set, where ``image``
            is an RGB PIL image and ``target`` is
            ``{"image_id": ..., "annotations": [...]}``. Otherwise
            ``(pixel_values, labels)`` as produced by the processor, each with
            the batch dimension removed.
        """
        img_data, annotations = self.data[idx]

        target = {
            "image_id": img_data["id"],
            "annotations": annotations
        }
        # Open on demand and copy the pixels into memory so the file handle is
        # released as soon as the sample has been read.
        with Image.open(img_data["path"]) as img_file:
            image = img_file.convert("RGB")

        if self.processor is None:
            return image, target
        inputs = self.processor(images=image, annotations=target, return_tensors="pt")
        pixel_values = inputs['pixel_values'].squeeze(0) # remove batch dimension
        labels = inputs['labels'][0] # remove batch dimension
        return pixel_values, labels



# One image processor is shared by both datasets, so train and test samples
# go through the same preprocessing.
processor = AutoImageProcessor.from_pretrained('SenseTime/deformable-detr')

train_dataset_full = SoccerNetDataset(
    root="/content/drive/MyDrive/deformable-detr-soccer-analysis/data/tracking/train",
    processor=processor,
)
test_dataset = SoccerNetDataset(
    root="/content/drive/MyDrive/deformable-detr-soccer-analysis/data/tracking/test",
    processor=processor,
)



In [None]:
# Randomly hold out 20% of the training samples for validation.
# NOTE: this is a plain random split; it is not stratified by class.
train_size = int(0.8 * len(train_dataset_full))
val_size = len(train_dataset_full) - train_size
train_dataset, val_dataset = random_split(
    dataset=train_dataset_full,
    lengths=[train_size, val_size],
)

# Report how many samples ended up in each split.
print(f"train_dataset size: {len(train_dataset)}")
print(f"val_dataset size: {len(val_dataset)}")
print(f"test_dataset size: {len(test_dataset)}")

In [None]:
def collate_fn(batch):
    """
    Collates dataset samples into one model-ready batch.

    Args:
        batch (list): Samples as returned by the dataset; element 0 of each
            sample is the image tensor and element 1 is its label dictionary.

    Returns:
        dict: ``pixel_values`` and ``pixel_mask`` as returned by
        ``processor.pad`` for the images, plus ``labels`` as the untouched
        list of per-image label dictionaries.
    """
    images = [sample[0] for sample in batch]
    targets = [sample[1] for sample in batch]
    # Uses the module-level `processor` to pad the images into a single batch.
    padded = processor.pad(images, return_tensors='pt')
    return {
        'pixel_values': padded['pixel_values'],
        'pixel_mask': padded['pixel_mask'],
        'labels': targets,
    }

# Data loaders for the three splits. Only the training loader shuffles, so
# validation/test batches come out in dataset order.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_fn,
)
val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=4,
    shuffle=False,
    collate_fn=collate_fn,
)
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=4,
    shuffle=False,
    collate_fn=collate_fn,
)


def view_labels(dataset, idx, predictions=None, score_threshold=0.0):
    """
    Shows one image of ``dataset`` with its boxes and class labels drawn on top.

    Ground-truth boxes are drawn in red when ``predictions`` is None; otherwise
    the boxes in ``predictions[idx]`` are drawn in green.

    Args:
        dataset: An indexable dataset whose items are ``(image, labels)`` with
            ``image`` a channels-first tensor and ``labels`` a dictionary holding
            ``"boxes"`` and ``"class_labels"``. It may be a ``Subset`` (as
            returned by ``random_split``) wrapping a dataset that exposes
            ``id_to_label``.
        idx (int): Index of the sample to display; also used to index
            ``predictions``.
        predictions (list, optional): One dictionary per sample with ``"boxes"``
            and ``"labels"`` (and optionally ``"scores"``).
        score_threshold (float): When positive and scores are available, only
            predictions scoring at least this value are drawn. The default
            draws every prediction.

    Boxes are expected as (center_x, center_y, width, height) normalized to
    [0, 1] relative to the image size.
    """
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches

    # Read the sample from the dataset that was passed in, so predictions for
    # the validation/test split are drawn over the matching image.
    img, labels = dataset[idx]

    # random_split wraps the dataset in Subset objects; follow `.dataset`
    # until the object that carries the id -> label mapping is found.
    source = dataset
    while not hasattr(source, "id_to_label") and hasattr(source, "dataset"):
        source = source.dataset
    id_to_label = getattr(source, "id_to_label", {})

    if predictions is None:
        boxes, class_ids = labels["boxes"], labels["class_labels"]
        edge_color, text_color = "r", "red"
    else:
        pred = predictions[idx]
        boxes, class_ids = pred["boxes"], pred["labels"]
        scores = pred.get("scores")
        if scores is not None and score_threshold > 0:
            keep = scores >= score_threshold
            boxes, class_ids = boxes[keep], class_ids[keep]
        edge_color, text_color = "g", "green"

    height, width = img.shape[1], img.shape[2]
    # matplotlib expects channels-last (H, W, C); the tensor is (C, H, W).
    plt.imshow(img.permute(1, 2, 0))
    ax = plt.gca()
    for bbox, label in zip(boxes, class_ids):
        # Scale the normalized box to pixel units.
        cx, cy, w, h = (float(v) for v in bbox)
        cx, w = cx * width, w * width
        cy, h = cy * height, h * height
        # Rectangle is anchored at its top-left corner, the box at its center.
        rect = patches.Rectangle(
            (cx - w / 2, cy - h / 2), w, h, linewidth=1, edgecolor=edge_color, facecolor="none"
        )
        ax.add_patch(rect)
        class_id = int(label)
        # Fall back to the numeric ID when no name is known for the class.
        ax.text(cx, cy, f"{id_to_label.get(class_id, class_id)}", color=text_color)
    plt.show()

# Quick visual check: draw the ground-truth boxes (no predictions passed) for sample 1.
view_labels(train_dataset, 1, None)

# Model

In [None]:
import torch
from transformers import AutoImageProcessor, DeformableDetrForObjectDetection

class DeformableDetrForObjectDetectionModule(torch.nn.Module):
    """
    Deformable DETR model for object detection.

    Attributes:
        processor (AutoImageProcessor): A processor for preprocessing the images.
        model (DeformableDetrForObjectDetection): A Deformable DETR model for object detection
    """
    def __init__(self, num_labels=None):
        """
        Loads the pretrained "SenseTime/deformable-detr" checkpoint with a
        classification head sized for ``num_labels`` classes.

        Args:
            num_labels (int, optional): Number of object classes. Defaults to
                the number of entries in the module-level
                ``test_dataset.labelsToId``.
        """
        super().__init__()
        if num_labels is None:
            num_labels = len(test_dataset.labelsToId)
        self.processor = AutoImageProcessor.from_pretrained("SenseTime/deformable-detr")
        # ignore_mismatched_sizes lets the checkpoint load even though the
        # requested number of labels differs from the pretrained head.
        self.model = DeformableDetrForObjectDetection.from_pretrained(
            "SenseTime/deformable-detr",
            num_labels=num_labels,
            ignore_mismatched_sizes=True,
        )
        self.model.config.num_classes = num_labels

    def forward(self, batch):
        """
        Forward pass of the model.

        Args:
            batch (dict): Keyword arguments for the underlying model. The
                training loop in this file passes ``pixel_values``,
                ``pixel_mask`` and ``labels`` (a list with one label
                dictionary per image).

        Returns:
            The output object of the underlying model; the training loop in
            this file reads ``loss``, ``loss_dict``, ``logits`` and
            ``pred_boxes`` from it.
        """
        # **batch unpacks the dictionary into keyword arguments for the model.
        return self.model(**batch)

# Training

In [None]:
#import mean average precision
from networkx import number_attracting_components
from torchmetrics.detection import MeanAveragePrecision

# Release unused GPU memory cached by PyTorch before the trainer is set up.
torch.cuda.empty_cache()

class Trainer:
    """
    Trainer for training the Deformable DETR model for object detection.

    The loss is taken from the model output (``outputs.loss``), so no separate
    criterion is needed.

    Attributes:
        module (DeformableDetrForObjectDetectionModule): The Deformable DETR model for object detection.
        optimizer (torch.optim.Optimizer): The optimizer used to update the model.
        device (torch.device): The device the module and every batch are moved to.
    """
    def __init__(self, module, optimizer, device):
        self.module = module
        self.optimizer = optimizer
        self.device = device
        self.module.to(device)

    def _to_device(self, batch):
        """Moves the tensors of ``batch`` (modified in place) to ``self.device``."""
        # pixel_values/pixel_mask are tensors; labels is a list with one
        # dictionary of tensors per image.
        batch['pixel_values'] = batch['pixel_values'].to(self.device)
        batch['pixel_mask'] = batch['pixel_mask'].to(self.device)
        batch['labels'] = [{k: v.to(self.device) for k, v in t.items()} for t in batch['labels']]
        return batch

    @staticmethod
    def _decode_predictions(logits, pred_boxes):
        """
        Converts raw model outputs into one prediction dictionary per image.

        Deformable DETR scores each class with an independent sigmoid and has
        no extra "no-object" logit, so every query is assigned its
        best-scoring class and that class's sigmoid probability as confidence.

        Args:
            logits (tensor): Class logits of shape (batch, queries, classes).
            pred_boxes (tensor): Boxes of shape (batch, queries, 4).

        Returns:
            list: One ``{"boxes", "labels", "scores"}`` dictionary of CPU
            tensors per image.
        """
        scores, labels = logits.sigmoid().max(dim=-1)
        return [
            {"boxes": boxes_i.cpu(), "labels": labels_i.cpu(), "scores": scores_i.cpu()}
            for boxes_i, labels_i, scores_i in zip(pred_boxes, labels, scores)
        ]

    def train_epoch(self, train_loader):
        """
        Runs one optimisation pass over ``train_loader``.

        Args:
            train_loader: Iterable of batches as produced by ``collate_fn``.

        Returns:
            float: The mean loss over the processed batches (0.0 if there were none).
        """
        torch.cuda.empty_cache()
        self.module.train()
        total_loss = 0.0
        num_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            batch = self._to_device(batch)

            self.optimizer.zero_grad()
            outputs = self.module(batch)
            loss = outputs.loss
            loss.backward()
            self.optimizer.step()

            loss_dict = outputs.loss_dict
            total_loss += loss.item()
            num_batches += 1

            print(f"batch: {batch_idx}, train_loss: {loss.item()}, train_loss_dict: {loss_dict}")

        # Guard against an empty loader instead of dividing by zero.
        return total_loss / num_batches if num_batches else 0.0

    def val_epoch(self, val_loader):
        """
        Evaluates the module on every batch of ``val_loader`` without gradient tracking.

        Args:
            val_loader: Iterable of batches as produced by ``collate_fn``.

        Returns:
            tuple: ``(avg_loss, metric, all_preds)`` where ``avg_loss`` is the
            mean loss over the processed batches, ``metric`` is the result of
            ``MeanAveragePrecision.compute()`` and ``all_preds`` is the list of
            per-image prediction dictionaries in loader order.
        """
        torch.cuda.empty_cache()
        self.module.eval()
        # Predicted and target boxes are both (center_x, center_y, w, h).
        metric = MeanAveragePrecision(box_format='cxcywh')
        total_loss = 0.0
        num_batches = 0
        all_preds = []

        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                batch = self._to_device(batch)
                outputs = self.module(batch)

                loss = outputs.loss
                loss_dict = outputs.loss_dict
                total_loss += loss.item()
                num_batches += 1

                preds = self._decode_predictions(outputs.logits, outputs.pred_boxes)
                targets = [
                    {"boxes": img_labels["boxes"].cpu(), "labels": img_labels["class_labels"].cpu()}
                    for img_labels in batch["labels"]
                ]
                metric.update(preds, targets)
                all_preds.extend(preds)

                # The metric is computed once after the loop, not per batch.
                print(f"batch: {batch_idx}, val_loss: {loss.item()}, val_loss_dict: {loss_dict}")

        avg_loss = total_loss / num_batches if num_batches else 0.0
        return avg_loss, metric.compute(), all_preds

def main():
    """
    Fine-tunes the model, validating and visualizing after every epoch, then
    evaluates and visualizes on the test set.

    Uses the module-level ``train_dataloader``, ``val_dataloader`` and
    ``test_dataloader``; the batch size is therefore the one those loaders
    were created with, not something configured here.
    """
    learning_rate = 1e-5
    num_epochs = 5

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    module = DeformableDetrForObjectDetectionModule()
    optimizer = torch.optim.AdamW(module.model.parameters(), lr=learning_rate)
    trainer = Trainer(module, optimizer, device)

    for epoch in range(num_epochs):
        train_loss = trainer.train_epoch(train_dataloader)
        val_loss, val_metrics, val_preds = trainer.val_epoch(val_dataloader)

        print(f'epoch: {epoch+1}/{num_epochs}, train_loss: {train_loss:.4f}, val_loss: {val_loss:.4f}, val_metric: {val_metrics}')
        # Draw the predictions for validation sample 1 to follow progress by eye.
        view_labels(val_dataset, 1, val_preds)

    # test on test_dataset and visualize
    test_loss, test_metrics, all_preds = trainer.val_epoch(test_dataloader)
    print(f'test_loss: {test_loss:.4f}, test_metric: {test_metrics}')

    view_labels(test_dataset, 1, all_preds)

main()

# Training

# Evaluation on Test

# Visualization