In [None]:
import sys
import os

# Make the project root (the parent of the notebook's working directory)
# importable so the `src.*` imports in this cell resolve. The membership check
# keeps the cell idempotent: re-running it no longer appends a duplicate
# entry to sys.path each time.
current_dir = os.getcwd()

parent_dir = os.path.dirname(current_dir)

if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from src.model import SSD
from src.loss import MultiBoxLoss
from src.utils import encode, decode

In [2]:
# The 20 Pascal VOC object categories, in their conventional order.
VOC_CLASSES = (
    'aeroplane', 'bicycle', 'bird', 'boat',
    'bottle', 'bus', 'car', 'cat', 'chair',
    'cow', 'diningtable', 'dog', 'horse',
    'motorbike', 'person', 'pottedplant',
    'sheep', 'sofa', 'train', 'tvmonitor',
)

# Class name -> contiguous integer index (0 .. len(VOC_CLASSES) - 1).
class_to_ind = {cls_name: idx for idx, cls_name in enumerate(VOC_CLASSES)}

In [3]:
import os
import torch
import torch.utils.data as data
import numpy as np
import cv2
import xml.etree.ElementTree as ET

class VOCDetection(data.Dataset):
    """Pascal VOC detection dataset.

    ``__getitem__`` returns ``(image, targets)``.
    - ``image`` is a channel-first tensor.
    - Each target row is ``[xmin, ymin, xmax, ymax, label_ind]``.
    - Box coordinates are divided by the original image width/height.
    """

    def __init__(self, root, image_sets, transform=None, target_transform=None):
        """
        root: path to VOCdevkit (e.g., '/data/VOCdevkit/')
        image_sets: list of tuples, e.g., [('2007', 'trainval'), ('2012', 'trainval')]
        transform: optional callable ``(img, boxes, labels) -> (img, boxes, labels)``
            used to augment/normalize the image. It receives ``boxes`` as an
            (N, 4) array and ``labels`` as an (N,) array.
        target_transform: stored on the instance.
            NOTE(review): never applied by this class; annotations are parsed
            inline in ``pull_item``.
        """
        self.root = root
        self.image_set_index = image_sets
        self.transform = transform
        self.target_transform = target_transform
        self._annopath = os.path.join('%s', 'Annotations', '%s.xml')
        self._imgpath = os.path.join('%s', 'JPEGImages', '%s.jpg')
        self.ids = list()

        # Load the list of image IDs from the ImageSets/Main/<split>.txt files.
        for (year, name) in image_sets:
            rootpath = os.path.join(self.root, 'VOC' + year)
            list_file = os.path.join(rootpath, 'ImageSets', 'Main', name + '.txt')
            # `with` closes the handle; a bare `for line in open(...)` leaks it.
            with open(list_file) as f:
                for line in f:
                    image_id = line.strip()
                    if image_id:  # skip blank lines, which would yield an empty ID
                        self.ids.append((rootpath, image_id))

    def __getitem__(self, index):
        im, gt, _, _ = self.pull_item(index)
        return im, gt

    def __len__(self):
        return len(self.ids)

    def pull_item(self, index):
        """Load image ``index`` together with its parsed annotations.

        Returns:
            ``(img_tensor, targets, height, width)``:
            - ``img_tensor`` is the image permuted to channel-first (C, H, W).
            - ``targets`` is a list of ``[xmin, ymin, xmax, ymax, label_ind]``
              rows when no transform is set, or an (N, 5) array after the
              transform.
            - ``height``/``width`` are the original image size.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        img_id = self.ids[index]

        # 1. Load Annotation
        target = ET.parse(self._annopath % img_id).getroot()

        # 2. Load Image. cv2.imread returns None (it does not raise) on failure.
        img = cv2.imread(self._imgpath % img_id)
        if img is None:
            raise FileNotFoundError(f"Could not find image: {self._imgpath % img_id}")
        height, width, _ = img.shape

        # 3. Parse Annotation to List
        res = []
        for obj in target.iter('object'):
            name = obj.find('name').text.lower().strip()
            bbox = obj.find('bndbox')

            pts = ['xmin', 'ymin', 'xmax', 'ymax']
            bndbox = []
            for i, pt in enumerate(pts):
                cur_pt = int(bbox.find(pt).text) - 1  # VOC is 1-based, convert to 0-based
                # Scale coordinates to 0-1 for SSD (x by width, y by height)
                cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
                bndbox.append(cur_pt)

            label_idx = class_to_ind[name]
            bndbox.append(label_idx)  # Format: [xmin, ymin, xmax, ymax, label_ind]
            res.append(bndbox)

        # 4. Apply Transforms (Augmentation + Resize)
        if self.transform is not None:
            # `res` is a Python list, so `res[:, :4]` raised TypeError. Convert it
            # to an (N, 5) array first; reshape also handles images with no objects.
            targets = np.array(res, dtype=np.float64).reshape(-1, 5)
            img, boxes, labels = self.transform(img, targets[:, :4], targets[:, 4])
            res = np.hstack((boxes, np.expand_dims(labels, axis=1)))

        return torch.from_numpy(img).permute(2, 0, 1), res, height, width

In [4]:
def detection_collate(batch):
    """
    Custom collate fn for batches whose images carry different numbers of
    object annotations (bounding boxes).

    Arguments:
        batch: (sequence) of samples, where ``sample[0]`` is an image tensor and
            ``sample[1]`` is an array-like of ``[xmin, ymin, xmax, ymax, label_ind]``
            rows (possibly empty).

    Return:
        A tuple containing:
            1) (tensor) batch of images stacked on their 0 dim
            2) (list of float32 tensors) one (N_i, 5) annotation tensor per image
    """
    targets = []
    imgs = []
    for sample in batch:
        imgs.append(sample[0])  # The image tensor
        # torch.tensor copies (like the legacy torch.FloatTensor constructor) and
        # accepts both nested lists and numpy arrays.
        annotations = torch.tensor(sample[1], dtype=torch.float32)
        if annotations.numel() == 0:
            # An empty list would otherwise become shape (0,); keep it 2-D so
            # column slicing such as t[:, :4] still works downstream.
            annotations = annotations.reshape(0, 5)
        targets.append(annotations)

    return torch.stack(imgs, 0), targets

In [None]:
import os
import sys
import torch
import torch.utils.data as data
import cv2
import numpy as np
import xml.etree.ElementTree as ET

# --- Configuration ---
voc_root = 'VOCdevkit/'  # relative to the notebook's working directory
BATCH_SIZE = 4
IMG_SIZE = 300

# The 20 Pascal VOC object categories, in their conventional order.
VOC_CLASSES = (
    'aeroplane', 'bicycle', 'bird', 'boat',
    'bottle', 'bus', 'car', 'cat', 'chair',
    'cow', 'diningtable', 'dog', 'horse',
    'motorbike', 'person', 'pottedplant',
    'sheep', 'sofa', 'train', 'tvmonitor',
)

# Class name -> contiguous integer index (0 .. len(VOC_CLASSES) - 1).
class_to_ind = {cls_name: idx for idx, cls_name in enumerate(VOC_CLASSES)}
class BaseTransform:
    """Resize an image to ``size`` x ``size`` and subtract a per-channel mean.

    Boxes and labels are returned untouched. In this notebook the dataset
    normalizes box coordinates to [0, 1], so a resize does not move them.
    """

    def __init__(self, size, mean):
        """size: output side length in pixels; mean: per-channel values to subtract."""
        self.size = size
        # Stored as float32 so the subtraction below yields a float32 image.
        self.mean = np.array(mean, dtype=np.float32)

    def __call__(self, image, boxes=None, labels=None):
        square = (self.size, self.size)
        resized = cv2.resize(image, square)
        normalized = resized.astype(np.float32) - self.mean
        return normalized, boxes, labels

# --- 4. DATASET CLASS ---
class VOCDetection(data.Dataset):
    """Pascal VOC detection dataset yielding ``(image_tensor, targets)`` pairs.

    ``targets`` is an (N, 5) array of ``[xmin, ymin, xmax, ymax, label_ind]``
    rows with box coordinates normalized to [0, 1]. N may be 0 for images
    without objects.
    """

    def __init__(self, root, image_sets, transform=None):
        """
        root: path to the VOCdevkit directory.
        image_sets: iterable of (year, split) tuples, e.g. [('2007', 'trainval')].
            Missing split files are reported with a warning and skipped.
        transform: optional callable ``(img, boxes, labels) -> (img, boxes, labels)``.
            It receives ``boxes`` as an (N, 4) array and ``labels`` as an (N,) array.
        """
        self.root = root
        self.image_set_index = image_sets
        self.transform = transform
        self._annopath = os.path.join('%s', 'Annotations', '%s.xml')
        self._imgpath = os.path.join('%s', 'JPEGImages', '%s.jpg')
        self.ids = list()

        for (year, name) in image_sets:
            rootpath = os.path.join(self.root, 'VOC' + year)
            # Check if file exists to avoid immediate crash
            list_file = os.path.join(rootpath, 'ImageSets', 'Main', name + '.txt')
            if not os.path.exists(list_file):
                print(f"Warning: Could not find {list_file}")
                continue

            # `with` closes the handle; a bare `for line in open(...)` leaks it.
            with open(list_file) as f:
                for line in f:
                    image_id = line.strip()
                    if image_id:  # skip blank lines, which would yield an empty ID
                        self.ids.append((rootpath, image_id))

    def __getitem__(self, index):
        im, gt, _, _ = self.pull_item(index)

        # Build an (N, 5) array up front. An image with no objects then becomes
        # (0, 5); previously np.hstack was handed a 1-D empty list next to a 2-D
        # (0, 1) array and raised.
        gt = np.array(gt, dtype=np.float64).reshape(-1, 5)

        # separate boxes and labels
        boxes = gt[:, :4]
        labels = gt[:, 4]

        # Apply Resize/Transform
        if self.transform is not None:
            im, boxes, labels = self.transform(im, boxes, labels)

        # Convert to Tensor (Channel First: C, H, W)
        im = torch.from_numpy(im).permute(2, 0, 1)

        # One (N, 5) array per image; N varies, so batching is left to the collate fn.
        targets = np.hstack((boxes, np.expand_dims(labels, axis=1)))

        return im, targets

    def __len__(self):
        return len(self.ids)

    def pull_item(self, index):
        """Load image ``index`` and parse its annotation XML.

        Returns:
            ``(img, res, height, width)``:
            - ``img`` is the raw array from ``cv2.imread``.
            - ``res`` is a list of ``[xmin, ymin, xmax, ymax, label_ind]`` rows
              with coordinates divided by width/height.
            - ``height``/``width`` are the original image size.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        img_id = self.ids[index]
        target = ET.parse(self._annopath % img_id).getroot()
        img = cv2.imread(self._imgpath % img_id)

        # cv2.imread returns None (it does not raise) when the file is unreadable.
        if img is None:
            raise FileNotFoundError(f"Could not find image: {self._imgpath % img_id}")

        height, width, _ = img.shape

        res = []
        for obj in target.iter('object'):
            name = obj.find('name').text.lower().strip()
            bbox = obj.find('bndbox')

            pts = ['xmin', 'ymin', 'xmax', 'ymax']
            bndbox = []
            for i, pt in enumerate(pts):
                cur_pt = int(bbox.find(pt).text) - 1  # VOC is 1-based, convert to 0-based
                # Scale to 0-1 (x by width, y by height)
                cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
                bndbox.append(cur_pt)

            label_idx = class_to_ind[name]
            bndbox.append(label_idx)
            res.append(bndbox)

        return img, res, height, width

# Collate function: stacks only the images; targets are returned as a list of
# per-image tensors because the number of objects varies between images.
def detection_collate(batch):
    """Collate ``(image_tensor, annotations)`` samples into a batch.

    Returns:
        ``(images, targets)``:
        - ``images`` is ``torch.stack`` of the image tensors along dim 0.
        - ``targets`` is a list of float32 tensors, one per sample, of shape
          (N_i, 5). Empty annotations become (0, 5).
    """
    targets = []
    imgs = []
    for sample in batch:
        imgs.append(sample[0])
        # torch.tensor copies (like the legacy torch.FloatTensor constructor) and
        # accepts both nested lists and numpy arrays.
        annotations = torch.tensor(sample[1], dtype=torch.float32)
        if annotations.numel() == 0:
            # Keep empty targets 2-D so column slicing (t[:, :4]) still works.
            annotations = annotations.reshape(0, 5)
        targets.append(annotations)
    return torch.stack(imgs, 0), targets


if __name__ == '__main__':
    # Smoke test: build the dataset with a plain resize + mean-subtraction
    # transform, then pull the first two batches through a DataLoader.
    transform = BaseTransform(IMG_SIZE, (104, 117, 123))

    dataset = VOCDetection(
        root=voc_root,
        image_sets=[('2007', 'trainval')],
        transform=transform,
    )

    if len(dataset) > 0:
        # num_workers=0 keeps loading in the main process; per the original
        # note this avoids a RuntimeError on Windows.
        data_loader = data.DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=0,
            collate_fn=detection_collate,
        )

        print("Starting loop...")
        # zip with range(2) stops after two batches without drawing a third.
        for batch_idx, (images, targets) in zip(range(2), data_loader):
            print(f"Batch {batch_idx} success!")
            print(f" - Image batch shape: {images.shape}")  # expected [BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE]
            print(f" - Number of objects in first image: {len(targets[0])}")
    else:
        print("ERROR: Dataset is empty. Check your 'voc_root' path at the top of the script.")

Starting loop...
Batch 0 success!
 - Image batch shape: torch.Size([4, 3, 300, 300])
 - Number of objects in first image: 6
Batch 1 success!
 - Image batch shape: torch.Size([4, 3, 300, 300])
 - Number of objects in first image: 1
