In [1]:
# import sys
# import os
# import csv

# import xml.etree.ElementTree as Et
# from xml.etree.ElementTree import Element, ElementTree
# from PIL import Image

# import json

# from xml.etree.ElementTree import dump

In [268]:
## https://deepbaksuvision.github.io/Modu_ObjectDetection/posts/03_01_dataloader.html

from torch.utils.data import Dataset
import torch.nn.functional as F
import torch
import glob
import random
import os
import warnings
import numpy as np
from PIL import Image
from PIL import ImageFile
import os.path as osp
import sys
import torch
import torchvision
import torch.utils.data as data
import cv2

from Format import YOLO as cvtYOLO
from Format import VOC as cvtVOC

ImageFile.LOAD_TRUNCATED_IMAGES = True

DATA_ROOT = osp.join("./", "data/fire/")

def pad_to_square(img, pad_value):
    c, h, w = img.shape
    dim_diff = np.abs(h - w)
    # (upper / left) padding and (lower / right) padding
    pad1, pad2 = dim_diff // 2, dim_diff - dim_diff // 2
    # Determine padding
    pad = (0, 0, pad1, pad2) if h <= w else (pad1, pad2, 0, 0)
    # Add padding
    img = F.pad(img, pad, "constant", value=pad_value)

    return img, pad

def resize(image, size):
    image = F.interpolate(image.unsqueeze(0), size=size, mode="nearest").squeeze(0)
    return image


class ImageFolder(Dataset):
    def __init__(self, folder_path, transform=None):
        self.files = sorted(glob.glob("%s/*.*" % folder_path))
        self.transform = transform

    def __getitem__(self, index):

        img_path = self.files[index % len(self.files)]
        img = np.array(
            Image.open(img_path).convert('RGB'),
            dtype=np.uint8)

        # Label Placeholder
        boxes = np.zeros((1, 5))

        # Apply transforms
        if self.transform:
            img, _ = self.transform((img, boxes))

        return img_path, img

    def __len__(self):
        return len(self.files)


class ListDataset(Dataset):
    def __init__(self, root, train=True, transform=None, target_transform=None, resize=448, class_path='./fire.name.txt'):

      self.root = root
      self.transform = transform
      self.target_transform = target_transform
      self.train = train
      self.resize_factor = resize
      self.class_path = class_path
      

      with open(class_path) as f:
        self.classes = f.read().splitlines()
  
      self.image_dir = osp.join(DATA_ROOT, 'imgs/')
      self.annopath_dir = osp.join(DATA_ROOT, 'annotations/')

      self.batch_count = 0

      self.data = self.cvtData()
      
      self.image_ids = os.listdir(self.image_dir)
      self.annotation_ids = os.listdir(self.annopath_dir)
      ##self.labeled_ids_len = len(self.labeled_ids)
    def __getitem__(self, index):
        img_id = self.image_ids[index]
        img_path = osp.join(self.image_dir,img_id)
        ##img = cv2.imread(img_path)
        img = Image.open(img_path).convert('RGB')
        
        try:
            key = list(self.data[index].keys())[0]
            target = self.data[index][key]
            semi = np.array([1])
            
        except:
            semi = np.array([0])
            target = np.zeros([1,5]) 
       
        print(img_path)
        current_size = img.size
        img = img.resize((self.resize_factor, self.resize_factor))
        img = torchvision.transforms.ToTensor()(img)
        # -----------
        #  Transform
        # -----------
        if self.transform is not None:
          img = self.transform(img)
          
        if self.target_transform is not None:
        # Future works
          pass
          
         
        return  img, target, semi, current_size 

    def __len__(self):
        return (len(self.image_ids))
    
    
    def cvtData(self):

      result = []
      voc = cvtVOC()

      yolo = cvtYOLO(os.path.abspath(self.class_path))
      flag, self.dict_data =voc.parse(os.path.join(self.annopath_dir))

      try:
        if flag:
          flag, data =yolo.generate(self.dict_data)

          keys = list(data.keys())
#           keys = sorted(keys, key=lambda key: int(key.split("_")[-1]))

          for key in keys:
            contents = list(filter(None, data[key].split("\n")))
            target = []
            for i in range(len(contents)):
              tmp = contents[i]
              tmp = tmp.split(" ")
              for j in range(len(tmp)):
                tmp[j] = float(tmp[j])
              target.append(tmp)
            result.append({key : target})

        return result

      except Exception as e:
        raise RuntimeError("Error : {}".format(e))
        

def detection_collate(batch):
    targets = []
    imgs = []
    sizes = []
    semis = []

    for sample in batch:
        imgs.append(sample[0])
        semis.append(torch.from_numpy(sample[2]))
        sizes.append(sample[3])
        
        np_label = np.zeros((7, 7, 6), dtype=np.float32)
        for object in sample[1]:
            objectness = 1
            classes = object[0]
            x_ratio = object[1]
            y_ratio = object[2]
            w_ratio = object[3]
            h_ratio = object[4]

            scale_factor = (1 / 7)
            grid_x_index = int(x_ratio // scale_factor)
            grid_y_index = int(y_ratio // scale_factor)
            x_offset = (x_ratio / scale_factor) - grid_x_index
            y_offset = (y_ratio / scale_factor) - grid_y_index

            np_label[grid_x_index][grid_y_index] = np.array([objectness, x_offset, y_offset, w_ratio, h_ratio, classes])

        label = torch.from_numpy(np_label)
        targets.append(label)
    torch_imgs = torch.stack(imgs, 0)
    torch_target = torch.stack(targets, 0)
    return torch_imgs, torch_target, semis, sizes

In [269]:
train_dataset = ListDataset(root = "./" )


VOC Parsing:  |----------------------------------------| 0.0% (0/53)  CompleteVOC Parsing:   |----------------------------------------| 1.9% (1/53)  CompleteVOC Parsing:   |█---------------------------------------| 3.8% (2/53)  CompleteVOC Parsing:   |██--------------------------------------| 5.7% (3/53)  CompleteVOC Parsing:   |███-------------------------------------| 7.5% (4/53)  CompleteVOC Parsing:   |███-------------------------------------| 9.4% (5/53)  CompleteVOC Parsing:   |████------------------------------------| 11.3% (6/53)  CompleteVOC Parsing:   |█████-----------------------------------| 13.2% (7/53)  CompleteVOC Parsing:   |██████----------------------------------| 15.1% (8/53)  CompleteVOC Parsing:   |██████----------------------------------| 17.0% (9/53)  CompleteVOC Parsing:   |███████---------------------------------| 18.9% (10/53)  CompleteVOC Parsing:   |████████--------------------------------| 20.8% (11/53)  CompleteVOC Parsing:   |███

In [270]:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=4,
                                           shuffle=True,
                                           collate_fn=detection_collate)

In [271]:
iteration = iter(train_loader)

In [274]:
next(iteration)

./data/fire/imgs/[E] - [AUF GEGENFAHRBAHN GERATEN]  T철dlicher VU bei Tiefenbronn wegen unangepasster Geschwindigkeit.mp4_2250.jpg
./data/fire/imgs/Monroeville Volunteer Fire Department Co 4.mp4_4375.jpg
./data/fire/imgs/[E] - [AUF GEGENFAHRBAHN GERATEN]  T철dlicher VU bei Tiefenbronn wegen unangepasster Geschwindigkeit.mp4_250.jpg
./data/fire/imgs/Salva la vida gracias a los bomberos y vecinos.mp4_3000.jpg


(tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039],
           [0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039],
           [0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039],
           ...,
           [0.0235, 0.0235, 0.0196,  ..., 0.1765, 0.1647, 0.1686],
           [0.0235, 0.0235, 0.0196,  ..., 0.1725, 0.1725, 0.1725],
           [0.0235, 0.0235, 0.0196,  ..., 0.1843, 0.1529, 0.1725]],
 
          [[0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
           [0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
           [0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
           ...,
           [0.0275, 0.0275, 0.0235,  ..., 0.1176, 0.1176, 0.1216],
           [0.0275, 0.0275, 0.0235,  ..., 0.1137, 0.1059, 0.1176],
           [0.0275, 0.0275, 0.0235,  ..., 0.1020, 0.1137, 0.1137]],
 
          [[0.0196, 0.0196, 0.0196,  ..., 0.0353, 0.0353, 0.0353],
           [0.0196, 0.0196, 0.0196,  ..., 0.0353, 0.0353, 0.0353],
           [0.0196, 0.01