In [1]:
import torch
import os
import xml.etree.ElementTree as ET

from PIL import Image
from torch.utils.data import Dataset
from utilities import transform

In [2]:
# version 1.0

labels_idx = ["D00","D10","D20","D40"]

class RoadDamageDataset(Dataset):
    def __init__(self, folder):
        
        self.dir_xmls = os.path.join(folder, "annotations", "xmls")
        self.dir_imgs = os.path.join(folder, "images")
        
        self.annotations = [os.path.join(self.dir_xmls, xml) for xml in os.listdir(self.dir_xmls)]
        self.images_file = [os.path.join(self.dir_imgs, img) for img in os.listdir(self.dir_imgs)]

    def __getitem__(self,i):
        
        image = Image.open(self.images_file[i])
        tree = ET.parse(self.annotations[i])
        objects = tree.findall("object")

        bboxes = []
        labels = []
        
        for i, obj in enumerate(objects):
            objectID = "object_{:02d}".format(i)
            name = obj.find("name").text
            if name == 'Repair':
                continue
            bbox = obj.find("bndbox")
            xmin = int(bbox.find("xmin").text)
            ymin = int(bbox.find("ymin").text)
            xmax = int(bbox.find("xmax").text)
            ymax = int(bbox.find("ymax").text)
            
            label = labels_idx.index(name)
            labels.append(label)
            bboxes.append([xmin, ymin, xmax, ymax])
        
        bboxes = torch.FloatTensor(bboxes)
        labels = torch.FloatTensor(labels)
        
        image, bboxes, labels = transform(image, bboxes, labels)
        
        return image, bboxes, labels
        
    

In [3]:
train_folder ="D:\\Dataset\\CRDDC2022\\dataset\\China_MotorBike\\train\\"

train_RDD = RoadDamageDataset(train_folder)

In [4]:
image, bboxes, labels = train_RDD[0]

In [6]:
print(image)
print(bboxes)
print(labels)

tensor([[[ 1.5982,  1.5810,  1.5297,  ...,  0.6734,  0.5536,  0.4508],
         [ 1.5125,  1.5125,  1.4783,  ...,  0.5707,  0.6049,  0.4679],
         [ 1.4783,  1.4612,  1.4098,  ...,  0.6563,  0.5364,  0.4508],
         ...,
         [ 0.3994,  0.5536,  0.4166,  ...,  0.1597,  0.1083,  0.2796],
         [ 0.5536,  0.0056,  0.3652,  ..., -0.1828,  0.2624,  0.0398],
         [ 0.4851,  0.0398,  0.5364,  ...,  0.3138,  0.4851,  0.3309]],

        [[ 1.7633,  1.7633,  1.6933,  ...,  0.8179,  0.6954,  0.5903],
         [ 1.6933,  1.6933,  1.6408,  ...,  0.7129,  0.7479,  0.6078],
         [ 1.6583,  1.6408,  1.5882,  ...,  0.8004,  0.6779,  0.5903],
         ...,
         [ 0.5028,  0.6779,  0.5028,  ...,  0.2927,  0.2402,  0.4153],
         [ 0.6429,  0.0826,  0.4678,  ..., -0.0574,  0.3978,  0.1702],
         [ 0.6078,  0.1176,  0.6429,  ...,  0.4503,  0.6254,  0.4678]],

        [[ 1.9080,  1.9080,  1.8383,  ...,  1.0539,  0.9494,  0.8448],
         [ 1.8383,  1.8383,  1.7860,  ...,  0