In [1]:
import torch

# Report which device training will use; get_device_name is only queried
# when CUDA is actually present.
cuda_present = torch.cuda.is_available()
device_message = (
    f"GPU: {torch.cuda.get_device_name(0)} is available."
    if cuda_present
    else "No GPU available. Training will run on CPU."
)
print(device_message)

No GPU available. Training will run on CPU.


In [2]:
import os
import xml.etree.ElementTree as ET
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as T
import torch


class RobotDataset(Dataset):
    """Detection dataset pairing images with XML (``<object>/<bndbox>``) annotations.

    Every image in ``images_dir`` whose name ends in ``.jpg``, ``.jpeg`` or
    ``.png`` (case-insensitive) is expected to have a same-stem ``.xml`` file in
    ``xml_dir``. The file list is sorted so that indices -- and therefore any
    ``random_split`` built on them -- are reproducible across machines.

    Args:
        images_dir: Directory containing the images (``str`` or ``Path``).
        xml_dir: Directory containing the XML annotation files.
        transform: Optional callable applied to the RGB ``PIL.Image``.
        rescale_boxes: If True (default), box coordinates are rescaled from the
            original image size to the size of the transformed image so they
            stay aligned after e.g. ``T.Resize``. Only size changes are handled;
            crops, flips or padding are NOT reflected in the boxes.
    """

    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, images_dir, xml_dir, transform=None, rescale_boxes=True):
        self.images_dir = images_dir
        self.xml_dir = xml_dir
        self.transform = transform
        self.rescale_boxes = rescale_boxes
        # os.listdir order is filesystem-dependent; sort for reproducible indexing.
        self.image_files = sorted(
            f
            for f in os.listdir(images_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def parse_xml(self, xml_path):
        """Parse one annotation file.

        Returns:
            dict with ``"boxes"``: list of ``[xmin, ymin, xmax, ymax]`` int lists,
            and ``"labels"``: list of ``"<pose>_<name>"`` strings, one per
            ``<object>`` element, in document order.
        """
        root = ET.parse(xml_path).getroot()
        boxes = []
        labels = []

        for obj in root.findall("object"):
            # Label combines the object's <pose> and <name> text.
            label = obj.find("pose").text + "_" + obj.find("name").text
            bbox = obj.find("bndbox")
            # float() first: some annotation tools write "12.0", which int() rejects.
            boxes.append(
                [int(float(bbox.find(tag).text)) for tag in ("xmin", "ymin", "xmax", "ymax")]
            )
            labels.append(label)

        return {"boxes": boxes, "labels": labels}

    def __len__(self):
        return len(self.image_files)

    @staticmethod
    def _image_size(image):
        """Return ``(width, height)`` of a tensor (last two dims) or PIL image, else None."""
        if torch.is_tensor(image):
            if image.dim() < 2:
                return None
            return int(image.shape[-1]), int(image.shape[-2])
        size = getattr(image, "size", None)  # PIL exposes (width, height)
        if isinstance(size, tuple) and len(size) == 2:
            return size
        return None

    @staticmethod
    def _scale_boxes(boxes, orig_size, new_size):
        """Rescale ``[xmin, ymin, xmax, ymax]`` boxes from ``orig_size`` to ``new_size``.

        Sizes are ``(width, height)``. Boxes are returned unchanged when the new
        size is unknown or equal to the original.
        """
        if new_size is None or tuple(orig_size) == tuple(new_size):
            return boxes
        sx = new_size[0] / orig_size[0]
        sy = new_size[1] / orig_size[1]
        return [
            [xmin * sx, ymin * sy, xmax * sx, ymax * sy]
            for xmin, ymin, xmax, ymax in boxes
        ]

    def __getitem__(self, idx):
        """Return ``{"image", "boxes", "labels"}`` for sample ``idx``."""
        image_filename = self.image_files[idx]
        image_path = os.path.join(self.images_dir, image_filename)
        # Close the file handle promptly; convert() returns an independent image.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        orig_size = image.size  # (width, height) before any transform

        xml_filename = os.path.splitext(image_filename)[0] + ".xml"
        annotations = self.parse_xml(os.path.join(self.xml_dir, xml_filename))

        if self.transform is not None:
            image = self.transform(image)

        boxes = annotations["boxes"]
        if self.rescale_boxes:
            # Keep boxes in the coordinate frame of the (possibly resized) image.
            boxes = self._scale_boxes(boxes, orig_size, self._image_size(image))

        return {
            "image": image,
            "boxes": boxes,
            "labels": annotations["labels"],
        }

In [3]:
from pathlib import Path

DATA_ROOT = Path("ATLAS_Dione_ObjectDetection")
images_path = DATA_ROOT / "JPEGImages"
xml_path = DATA_ROOT / "Annotations"

# Fixed 224x224 size so collate_fn can torch.stack a batch, then PIL -> tensor.
# NOTE(review): Resize changes image size; confirm box coordinates end up in
# the resized frame.
transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])

dataset = RobotDataset(images_dir=images_path, xml_dir=xml_path, transform=transform)

In [4]:
from torch.utils.data import DataLoader


def collate_fn(batch):
    """Collate detection samples whose box/label lists differ in length.

    PyTorch's default collate would try to combine ``boxes``/``labels`` across
    samples and fail when images contain different numbers of objects, so
    those are kept as per-sample lists.

    Args:
        batch: list of sample dicts with ``"image"``, ``"boxes"`` and ``"labels"``.

    Returns:
        dict with
          - ``"image"``: a tensor of shape ``(len(batch), *image.shape)`` when every
            image is a tensor of the same shape; otherwise the list of images
            as-is (e.g. PIL images or variable-size tensors, which
            ``torch.stack`` cannot combine).
          - ``"boxes"``, ``"labels"``: lists with one entry per sample.
    """
    images = [item["image"] for item in batch]
    boxes = [item["boxes"] for item in batch]
    labels = [item["labels"] for item in batch]

    stackable = (
        len(images) > 0
        and all(torch.is_tensor(img) for img in images)
        and len({tuple(img.shape) for img in images}) == 1
    )
    if stackable:
        images = torch.stack(images, dim=0)

    return {"image": images, "boxes": boxes, "labels": labels}


# Shuffled batches of 8 over the whole dataset; collate_fn keeps the
# variable-length box/label lists per sample.
# NOTE(review): this loader spans every sample (train and test alike).
data_loader = DataLoader(
    dataset=dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn,
)

In [5]:
# 80/20 train/test split. A seeded generator makes the split reproducible across
# kernel restarts; without it random_split draws from the global RNG.
SPLIT_SEED = 42
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [0.8, 0.2], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

# Per-split loaders. `data_loader` iterates the full dataset (train + test), so
# train on `train_loader` to keep test samples out of training.
train_loader = DataLoader(
    train_dataset, batch_size=data_loader.batch_size, shuffle=True, collate_fn=collate_fn
)
test_loader = DataLoader(
    test_dataset, batch_size=data_loader.batch_size, shuffle=False, collate_fn=collate_fn
)