First we need to download the dataset and unpack the zip file. The archive is very large (about 5.8 GB), so the download takes quite some time.

In [2]:
import os
import zipfile


if not os.path.exists('data'):#check if we have already unpacked the data
    if not os.path.exists('HollywoodHeads.zip'):
        print('Downloading dataset, this might take a while')
        !wget https://www.di.ens.fr/willow/research/headdetection/release/HollywoodHeads.zip
    with zipfile.ZipFile('HollywoodHeads.zip') as file:
        print('Unzipping dataset')
        file.extractall()
        #!mv HollywoodHeads data
    os.rename('HollywoodHeads','data')
    !rm HollywoodHeads.zip

Downloading dataset, this might take a while


--2023-10-28 17:46:09--  https://www.di.ens.fr/willow/research/headdetection/release/HollywoodHeads.zip
Herleiden van www.di.ens.fr (www.di.ens.fr)... 129.199.99.14
Verbinding maken met www.di.ens.fr (www.di.ens.fr)|129.199.99.14|:443... verbonden.
HTTP-verzoek is verzonden; wachten op antwoord... 200 OK
Lengte: niet-opgegeven [application/zip]
Wordt opgeslagen als: ‘HollywoodHeads.zip’

HollywoodHeads.zip      [  <=>               ]   5,36G  5,88MB/s    in 16m 15s 

2023-10-28 18:02:24 (5,63 MB/s) - '‘HollywoodHeads.zip’' opgeslagen [5755744131]

Unzipping dataset


In [None]:
import torch
import torchvision
from torch.utils.data import DataLoader, Dataset
from xml.etree import ElementTree as et

class HollywoodHeadDataset(Dataset):
    """Head-detection dataset read from a HollywoodHeads directory tree.

    The constructor reads ``<root>/Splits/<mode>.txt`` (one image name per
    line, without extension). Each sample pairs
    ``<root>/JPEGImages/<name>.jpeg`` with the bounding boxes found in
    ``<root>/Annotations/<name>.xml``.

    Args:
        root: Directory containing the ``Splits``, ``JPEGImages`` and
            ``Annotations`` sub-directories.
        transforms: Optional callable invoked as ``transforms(img, target)``;
            whatever it returns is unpacked as the ``(img, target)`` pair.
        mode: Name of the split file to read, case-insensitive
            (``'train'`` reads ``Splits/train.txt``).

    Raises:
        OSError: If the split file cannot be opened.
    """

    def __init__(self, root, transforms=None, mode='train') -> None:
        super().__init__()
        self.transforms = transforms
        self.root = root

        split_path = os.path.join(root, 'Splits', mode.lower() + '.txt')
        with open(split_path, 'r') as f:
            # strip() also removes a stray '\r'; blank lines are skipped so a
            # trailing newline never turns into an empty image name.
            self.imgs = [line.strip() for line in f if line.strip()]

        self.imgs_dir = os.path.join(root, 'JPEGImages')
        self.annot_dir = os.path.join(root, 'Annotations')

    def __len__(self):
        """Return the number of image names listed in the split file."""
        return len(self.imgs)

    def _build_target(self, idx):
        """Parse the annotation file of sample ``idx`` into a target dict.

        Args:
            idx: Position of the sample in ``self.imgs``.

        Returns:
            dict with the keys
            ``boxes``    float32 tensor of shape ``(N, 4)``, one
                         ``[xmin, ymin, xmax, ymax]`` row per annotated object,
            ``labels``   int64 ones of length ``N`` (there is a single class),
            ``area``     float32 tensor with width * height of every box,
            ``iscrowd``  int64 zeros of length ``N``,
            ``image_id`` ``torch.tensor([idx])``.
        """
        annot_file_path = os.path.join(self.annot_dir, self.imgs[idx] + '.xml')

        boxes = []
        for obj in et.parse(annot_file_path).getroot().findall('object'):
            bndbox = obj.find('bndbox')
            if bndbox is None:
                # An <object> without a <bndbox> carries no box; skip it
                # instead of failing with AttributeError on `.find`.
                continue
            # float() parses both "12" and "12.5"; the values end up in a
            # float32 tensor either way.
            xmin, ymin, xmax, ymax = (
                float(bndbox.find(tag).text)
                for tag in ('xmin', 'ymin', 'xmax', 'ymax')
            )
            boxes.append([xmin, ymin, xmax, ymax])

        # reshape keeps the (N, 4) layout even when the image has no objects,
        # so the column indexing below stays valid for N == 0.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        num_boxes = boxes.shape[0]

        target = {}
        target["boxes"] = boxes
        # We only have one class
        target["labels"] = torch.ones(num_boxes, dtype=torch.int64)
        target["area"] = area
        target["iscrowd"] = torch.zeros(num_boxes, dtype=torch.int64)
        target["image_id"] = torch.tensor([idx])
        return target

    def __getitem__(self, idx):
        """Load the image and detection target of sample ``idx``.

        Args:
            idx: Position of the sample in ``self.imgs``.

        Returns:
            ``(img, target)`` where ``img`` is the tensor returned by
            ``torchvision.io.read_image`` in RGB mode and ``target`` is the
            dict built by ``_build_target``. When ``transforms`` was given,
            the pair it returns is handed out instead.
        """
        image_path = os.path.join(self.imgs_dir, self.imgs[idx] + '.jpeg')
        # ImageReadMode is referenced through the already-imported torchvision
        # package; the bare name is not defined in this notebook.
        img = torchvision.io.read_image(image_path, torchvision.io.ImageReadMode.RGB)

        target = self._build_target(idx)

        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target