In [8]:
import numpy as np
import cv2
from utils import CVTransform, xywh2xyxy, xyxy2xywh, letterbox_resize, iou_general, plot
from torchvision import transforms
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from V2.VOC.utils import parse_anchors
from pathlib import Path
from utils import traverse_voc
import pickle
from skimage import io

ModuleNotFoundError: No module named 'utils'

In [None]:
class VOC2007:
    """Index of a Pascal VOC 2007 style dataset on disk.

    Object annotations are read from a pickled cache file at ``opt.obj_path``.
    If that file does not exist yet, ``traverse_voc`` is first called on the
    ``Annotations`` directory; it is expected to create the cache at that path.

    Args:
        opt: configuration object. This class reads ``VOC_BBOX_LABEL_NAMES``,
            ``data_dir`` and ``obj_path`` from it.
    """

    def __init__(self, opt):
        self.label_names = opt.VOC_BBOX_LABEL_NAMES
        # Class name -> its position in ``label_names`` (used as the label id).
        self.label_names_dict = {name: index for index, name in enumerate(self.label_names)}
        self.ann_dir = Path(opt.data_dir) / 'Annotations'
        self.img_dir = Path(opt.data_dir) / 'JPEGImages'
        self.obj_dict_path = opt.obj_path
        if not Path(self.obj_dict_path).exists():
            traverse_voc(self.ann_dir, self.obj_dict_path)
        # SECURITY NOTE: unpickling can execute arbitrary code. Only load a cache
        # produced locally (e.g. by traverse_voc), never one from an untrusted source.
        # The ``with`` block guarantees the file handle is closed.
        with open(self.obj_dict_path, 'rb') as cache_file:
            self.obj_dicts = pickle.load(cache_file)
        # The cache keys are the image filenames; their order defines sample indices.
        self.filenames = list(self.obj_dicts)

    def __len__(self):
        """Return the number of images listed in the annotation cache."""
        return len(self.filenames)

    def get_example(self, idx):
        """Load one image together with its object labels and boxes.

        Args:
            idx: integer position into ``self.filenames``.

        Returns:
            ``(img, labels, boxes)``:
                - the image as returned by ``skimage.io.imread``;
                - an int64 array with one label id per object;
                - ``np.asarray`` of the cached ``'boxes'`` entry. The box
                  format is whatever the cache stores (TODO confirm; presumably
                  VOC xyxy pixel coordinates).

        Raises:
            KeyError: if an object name is not in ``VOC_BBOX_LABEL_NAMES``.
        """
        filename = self.filenames[idx]
        obj_dict = self.obj_dicts[filename]
        obj_boxes = obj_dict['boxes']
        obj_names = obj_dict['names']
        obj_labels = [self.label_names_dict[name] for name in obj_names]
        img_path = self.img_dir / f'{filename}'
        img = io.imread(str(img_path))
        # Explicit integer dtype: np.array([]) would otherwise be float64 for an
        # image without objects, which breaks use of the labels as class indices.
        return img, np.array(obj_labels, dtype=np.int64), np.asarray(obj_boxes)


def pytorch_normailze(img, mean, std):
    """Convert ``img`` to a tensor and normalize it channel-wise.

    The two steps are torchvision's ``ToTensor`` followed by
    ``Normalize(mean, std)``, applied in that order. The misspelled function
    name is kept for compatibility with existing callers.

    Args:
        img: image accepted by ``transforms.ToTensor`` (e.g. a numpy array or a
            PIL image).
        mean: per-channel means passed to ``transforms.Normalize``.
        std: per-channel standard deviations passed to ``transforms.Normalize``.

    Returns:
        The normalized tensor.
    """
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(mean, std)
    as_tensor = to_tensor(img)
    return normalize(as_tensor)


class VOC2007Dataset(Dataset):
    """
    :return
    training:
        1.img:(batch_size,3,448,448)/tensor
        2.gt_bbox:(batch_size,-1,4)/tensor
        3.gt_label:(batch_size,-1)/ndarray
        4.scale:(batch_size,1,2)/ndarray
        5.y_true['target']:(13,13,5,25)/tensor
    """

    def __init__(self, opt):
        """Build the dataset from the configuration object ``opt``.

        ``opt`` is stored on the instance because ``__getitem__`` later reads
        image size and normalization settings from it.
        """
        self.opt = opt
        # Image/annotation index backed by the pickled VOC object cache.
        self.database = VOC2007(self.opt)
        # Anchor priors. The original note says they are at 416 scale with shape
        # [5, 2]; TODO confirm against parse_anchors.
        self.anchor_base = parse_anchors(self.opt.anchors_path, self.opt)
        # Augmentation pipeline. NOTE(review): the meaning of the argument 1 is
        # defined by CVTransform (presumably an apply probability); confirm.
        self.image_aug = CVTransform(1)

    def __len__(self):
        return len(self.database)

    def __getitem__(self, index):
        """Return ``(normalized_image, target)`` for the sample at ``index``.

        Steps:
            1. Load the raw sample from the database.
            2. Augment it.
            3. Letterbox-resize image and boxes to ``[opt.img_h, opt.img_w]``.
            4. Build the training target with ``make_target``.
            5. Normalize the image with ``opt.mean`` / ``opt.std``.
        """
        raw_img, raw_labels, raw_boxes = self.database.get_example(index)
        # image_aug is called with (img, boxes, labels, mode) and its result is
        # unpacked as (img, labels, boxes); note the swapped order.
        aug_img, aug_labels, aug_boxes = self.image_aug(raw_img, raw_boxes, raw_labels, 'RGB')
        lb_img, lb_boxes = letterbox_resize(aug_img, aug_boxes, [self.opt.img_h, self.opt.img_w])
        # NOTE(review): the target is built for ``opt.img_size``, while the image
        # was resized to ``[opt.img_h, opt.img_w]``. Confirm these always agree.
        target = self.make_target(lb_boxes, aug_labels, self.opt.img_size)
        normalized = pytorch_normailze(lb_img, self.opt.mean, self.opt.std)
        return normalized, target