In [1]:
import torch
import cv2
import numpy as np
import os
import glob as glob

from xml.etree import ElementTree as et
from config import (
    CLASSES, RESIZE_TO, 
    TRAIN_DIR_IMAGES, VALID_DIR_IMAGES, 
    TRAIN_DIR_LABELS, VALID_DIR_LABELS,
    BATCH_SIZE
)
from torch.utils.data import Dataset, DataLoader
from custom_utils import collate_fn, get_train_transform, get_valid_transform

In [2]:
import ast, json

In [3]:
def read_dicts_from_file(file_path):
    """
    Parse a text file that contains one Python literal (here: an annotation
    dict) and return the resulting object.

    The literal may span several lines and may be indented freely.

    :param file_path: path of the text file to read.
    :return: the object produced by ``ast.literal_eval`` for the file content.
    :raises SyntaxError, ValueError: if the content is not a valid literal.
    """
    with open(file_path, 'r') as file:
        text = file.read()

    try:
        # `ast.literal_eval` already ignores whitespace between tokens, so the
        # text is parsed as-is. This keeps spaces inside string values intact.
        return ast.literal_eval(text.strip())
    except (SyntaxError, ValueError):
        # Fallback for files that only parse once every space and newline is
        # removed (e.g. a token wrapped across two lines).
        compact = text.replace('\n', '').replace(' ', '')
        return ast.literal_eval(compact)

# Example usage: parse a single training annotation file
file_path = './input/train_txts/000032.txt'
dictionaries = read_dicts_from_file(file_path)

# NOTE: despite the plural name, `dictionaries` is a single dict (not a list);
# the printed output shows the keys 'bboxes' and 'pedestrian_class'.
print(dictionaries)

{'bboxes': [[[45.87416548610522, 321.0748212962833], [80.88455201672716, 418.18666836501575]], [[302.2567837910538, 306.2475239357931], [307.9105510864276, 331.9976051656832]], [[427.6693363071177, 306.5220569242455], [435.1162173731115, 333.624284056307]]], 'pedestrian_class': [9, 9, 9]}


In [4]:
class CustomDataset(Dataset):
    """
    Object-detection dataset built from image files and per-image ``.txt``
    annotation files.

    Every annotation file is parsed with ``read_dicts_from_file`` and must
    yield a dict with the keys

    * ``'bboxes'``: list of ``[[xmin, ymin], [xmax, ymax]]`` corner pairs,
      expressed in pixels of the original (un-resized) image, and
    * ``'pedestrian_class'``: one integer label per box.

    The annotation file of an image is looked up in ``labels_path`` under the
    image's file name with the extension replaced by ``.txt``.
    """

    def __init__(
        self, images_path, labels_path, 
        width, height, classes, transforms=None
    ):
        """
        :param images_path: directory that holds the image files.
        :param labels_path: directory that holds the ``.txt`` annotation files.
        :param width: width the images (and boxes) are resized to.
        :param height: height the images (and boxes) are resized to.
        :param classes: class list; stored on the instance, not used by the
            methods of this class.
        :param transforms: optional callable invoked as
            ``transforms(image=..., bboxes=..., labels=...)`` that returns a
            mapping with the keys ``'image'`` and ``'bboxes'``.
        """
        self.transforms = transforms
        self.images_path = images_path
        self.labels_path = labels_path
        self.height = height
        self.width = width
        self.classes = classes
        self.image_file_types = ['*.jpg', '*.jpeg', '*.png', '*.ppm']
        self.all_image_paths = []
        self.dictionary = dict()
        
        # collect every image path with a supported extension
        for file_type in self.image_file_types:
            self.all_image_paths.extend(glob.glob(f"{self.images_path}/{file_type}"))
        self.all_annot_paths = glob.glob(f"{self.labels_path}/*.txt")
        
        # Remove all annotations and images when no object is present.
        self.read_and_clean()
        # keep only the file names, in sorted order
        self.all_images = sorted(
            os.path.basename(image_path) for image_path in self.all_image_paths
        )

    @staticmethod
    def _stem(path):
        """Return the file name of `path` without directory and extension."""
        return os.path.splitext(os.path.basename(path))[0]
    
    def read_and_clean(self):
        """
        Discard every annotation file whose ``'bboxes'`` list is empty,
        together with the image(s) that share its file-name stem.
        """
        kept_annot_paths = []
        empty_stems = set()

        # Build new lists instead of removing from the list being iterated:
        # removing during iteration skips the element after each removal.
        for annot_path in self.all_annot_paths:
            self.dictionary = read_dicts_from_file(annot_path)
            
            if len(self.dictionary['bboxes']) == 0:
                print(f"Removing {annot_path} and corresponding image")
                empty_stems.add(self._stem(annot_path))
            else:
                kept_annot_paths.append(annot_path)

        self.all_annot_paths = kept_annot_paths
        # Images are matched by stem because they live in a different
        # directory than the labels and may use any supported extension.
        self.all_image_paths = [
            image_path for image_path in self.all_image_paths
            if self._stem(image_path) not in empty_stems
        ]

    def __getitem__(self, idx):
        """
        Load sample `idx`.

        :return: ``(image_resized, target)`` where ``image_resized`` is the RGB
            image as float32 scaled to [0, 1] and resized to
            ``(self.height, self.width)`` (or whatever ``self.transforms``
            returns under ``'image'``), and ``target`` is a dict with the keys
            ``'boxes'``, ``'labels'``, ``'area'``, ``'iscrowd'`` and
            ``'image_id'``.
        :raises FileNotFoundError: if the image file cannot be read.
        """
        # capture the image name and the full image path
        image_name = self.all_images[idx]
        image_path = os.path.join(self.images_path, image_name)

        # read the image; cv2.imread returns None instead of raising
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        # convert BGR to RGB color format
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image_resized = cv2.resize(image, (self.width, self.height))
        image_resized /= 255.0
        
        # annotation file: same stem as the image, '.txt' extension
        # (splitext also handles 5-character extensions such as '.jpeg')
        annot_filename = os.path.splitext(image_name)[0] + '.txt'
        annot_file_path = os.path.join(self.labels_path, annot_filename)
        
        # get the height and width of the original image
        image_width = image.shape[1]
        image_height = image.shape[0]

        annotation = read_dicts_from_file(annot_file_path)
        
        boxes = []
        # Collect labels in a fresh list: appending to the annotation's own
        # list while zipping over it would duplicate every label.
        labels = []
        for bbox, label in zip(annotation['bboxes'], annotation['pedestrian_class']):
            labels.append(label)
            
            # bbox is [[xmin, ymin], [xmax, ymax]]
            xmin = bbox[0][0]
            xmax = bbox[1][0]
            ymin = bbox[0][1]
            ymax = bbox[1][1]
            
            # resize the bounding boxes according to the...
            # ... desired `width`, `height`
            xmin_final = (xmin/image_width)*self.width
            xmax_final = (xmax/image_width)*self.width
            ymin_final = (ymin/image_height)*self.height
            ymax_final = (ymax/image_height)*self.height
            
            boxes.append([xmin_final, ymin_final, xmax_final, ymax_final])

        # bounding box to tensor; keep a (0, 4) shape when there are no boxes
        # so the column indexing below stays valid
        if len(boxes) == 0:
            boxes = torch.zeros((0, 4), dtype=torch.float32)
        else:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
        # area of the bounding boxes
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # no crowd instances
        iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)
        # labels to tensor
        labels = torch.as_tensor(labels, dtype=torch.int64)

        # prepare the final `target` dictionary
        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["area"] = area
        target["iscrowd"] = iscrowd
        image_id = torch.tensor([idx])
        target["image_id"] = image_id
        # apply the image transforms
        if self.transforms:
            sample = self.transforms(image=image_resized,
                                     bboxes=target['boxes'],
                                     labels=labels)
            image_resized = sample['image']
            # NOTE(review): only the boxes are refreshed from the transform
            # output; if a transform can drop boxes, 'labels', 'area' and
            # 'iscrowd' would go out of sync -- confirm against the transforms.
            target['boxes'] = torch.Tensor(sample['bboxes'])
            
        return image_resized, target

    def __len__(self):
        """Return the number of images kept after cleaning."""
        return len(self.all_images)

# prepare the final datasets and data loaders
def create_train_dataset():
    """Return a `CustomDataset` over the training images/labels with the training transforms."""
    return CustomDataset(
        images_path=TRAIN_DIR_IMAGES,
        labels_path=TRAIN_DIR_LABELS,
        width=RESIZE_TO,
        height=RESIZE_TO,
        classes=CLASSES,
        transforms=get_train_transform(),
    )
def create_valid_dataset():
    """Return a `CustomDataset` over the validation images/labels with the validation transforms."""
    return CustomDataset(
        images_path=VALID_DIR_IMAGES,
        labels_path=VALID_DIR_LABELS,
        width=RESIZE_TO,
        height=RESIZE_TO,
        classes=CLASSES,
        transforms=get_valid_transform(),
    )

def create_train_loader(train_dataset, num_workers=0):
    """
    Wrap `train_dataset` in a shuffling `DataLoader`.

    :param train_dataset: dataset to iterate over.
    :param num_workers: number of loader worker processes (default 0).
    :return: a `DataLoader` using `BATCH_SIZE` and the project `collate_fn`.
    """
    loader_options = dict(
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn,
    )
    return DataLoader(train_dataset, **loader_options)
def create_valid_loader(valid_dataset, num_workers=0):
    """
    Wrap `valid_dataset` in a non-shuffling `DataLoader`.

    :param valid_dataset: dataset to iterate over.
    :param num_workers: number of loader worker processes (default 0).
    :return: a `DataLoader` using `BATCH_SIZE` and the project `collate_fn`.
    """
    loader_options = dict(
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn,
    )
    return DataLoader(valid_dataset, **loader_options)

In [7]:
# Training split for visual inspection; `transforms` is left at its default (None).
dataset = CustomDataset(
    images_path=TRAIN_DIR_IMAGES,
    labels_path=TRAIN_DIR_LABELS,
    width=RESIZE_TO,
    height=RESIZE_TO,
    classes=CLASSES,
)
print(f"Number of training images: {len(dataset)}")

# function to visualize a single sample
def visualize_sample(image, target, rgb_input=True):
    """
    Draw the boxes of `target` on `image` and show the result in an OpenCV
    window, blocking until a key is pressed.

    :param image: image array as returned by the dataset.
    :param target: dict whose ``'boxes'`` entry holds one
        ``(xmin, ymin, xmax, ymax)`` row per object.
    :param rgb_input: when True (default) the image is treated as RGB and
        converted to BGR for display; the conversion yields a new array, so
        the caller's image is not drawn on. When False the image is shown
        as-is and the boxes are drawn directly onto it.
    """
    # cv2.imshow interprets 3-channel arrays as BGR, while the dataset hands
    # out RGB images; without the conversion red and blue appear swapped.
    canvas = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) if rgb_input else image
    for box_num in range(len(target['boxes'])):
        box = target['boxes'][box_num]
        # fixed caption; the class lookup is disabled
        label = 'Pedestrian'#CLASSES[target['labels'][box_num]]
        cv2.rectangle(
            canvas, 
            (int(box[0]), int(box[1])), (int(box[2]), int(box[3])),
            (0, 255, 0), 2
        )
        cv2.putText(
            canvas, label, (int(box[0]), int(box[1]-5)), 
            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2
        )
    cv2.imshow('Image', canvas)
    cv2.waitKey(0)
    
NUM_SAMPLES_TO_VISUALIZE = 20
# Clamp to the dataset size so a small dataset does not raise IndexError.
for i in range(min(NUM_SAMPLES_TO_VISUALIZE, len(dataset))):
    image, target = dataset[i]
    visualize_sample(image, target)

Number of training images: 75
