In [None]:
from pathlib import Path
import sys
current_work_directionary = Path('__file__').absolute().parent
sys.path.insert(0, str(current_work_directionary))

from dataset import build_dataloader, build_test_dataloader
from tqdm import tqdm
import torch
import numpy as np
import matplotlib.pyplot as plt

In [None]:
aug_hyp = {
            'data_aug_prespective_p': 1.0,
            'data_aug_scale': 0.,
            'data_aug_shear': 0,
            'data_aug_translate': 0.5,
            'data_aug_degree': 0,
            'data_aug_prespective': 1.0,
            'data_aug_hsv_p': 1,
            "data_aug_hsv_hgain": 0.015,
            "data_aug_hsv_sgain": 0.7,
            "data_aug_hsv_vgain": 0.4,
            'data_aug_mixup_p': 0.0,
            'data_aug_fliplr_p': 0,
            'data_aug_flipud_p': 0,
            'data_aug_fill_value': 128,
            'data_aug_mosaic_p': 1., 
            "data_aug_cutout_p": 1.0, 
            "data_aug_cutout_iou_thr": 0.2, 
            "data_aug_scale_jitting_p": 0.1, 
            'input_img_size': 448,
}

img_dir = '../../Dataset/COCO/val/img'
lab_dir = "../../Dataset/COCO/val/lab"
name_path = '../../Dataset/COCO/val/names.txt'
input_dim = [448, 448]
cache_num = 20
enable_data_aug = True
seed = 7
batch_size = 8
num_workers = 0
pin_memory = True
shuffle = True
drop_last = False


dataset, dataloader, prefetcher = build_dataloader(img_dir, lab_dir, name_path, input_dim, aug_hyp, cache_num, 
                                                    enable_data_aug, seed, batch_size, num_workers, pin_memory, shuffle, drop_last)

### dataset

In [None]:
img_id = np.random.randint(0, len(dataset))
x = dataset.pull_item(img_id)
ann = x[1]
img = x[0]
# img = img.permute(1, 2, 0)
# img = np.clip(img * 255.0, 0.0, 255.0)
img_mdy = np.ascontiguousarray(img.astype('uint8'))
h, w, _ = img.shape
save_path = current_work_directionary / "result" / "tmp" / f"batch_img_id_{img_id}.png"
dataset.cv2_save_fig(img_mdy, ann['bboxes'], ann['classes'], str(save_path))

### dataloader

In [None]:
with tqdm(total=len(dataloader), ncols=50) as t:
    for b, x in enumerate(dataloader):
        if b == 2:
            dataloader.close_data_aug()

        for i in range(batch_size):
            ann = x['ann'][i]
            title = x['img_id'][i]
            img = x['img'][i]
            img = img.permute(1, 2, 0)
            img = np.clip(img * 255.0, 0.0, 255.0)
            img_mdy = np.ascontiguousarray(img.numpy().astype('uint8'))
            h, w, _ = img.shape
            # 该笔数据中是否有object，ann[:, 4] == -1表示没有object
            valid_index = torch.nonzero(ann[:, 4] >= 0, as_tuple=False).squeeze(dim=1)
            # 如果该笔数据有object的话，就plot出来
            if valid_index.numel() > 0:
                ann_mdy = {'bboxes': ann[valid_index][:, :4].numpy(),
                            'classes': ann[valid_index][:, 4].numpy().astype('uint8')}
            # 如果该笔数据中没有发现object，则打印出图片的路径
            else:
                ann_mdy = {'bboxes': [], 'classes': []}
            save_path = current_work_directionary / "result" / "tmp" / f"batch_{b}_idx_{i}.png"
            dataset.cv2_save_fig(img_mdy, ann_mdy['bboxes'], ann_mdy['classes'], str(save_path))
            # print(f"{b*batch_size+i}\t{len(ann_mdy['bboxes'])}")
        
        if b >= 3:
            break
    t.update(batch_size)

  1%|▏            | 8/625 [00:07<10:07,  1.02it/s]


test dataloader

In [None]:
datadir = "./result/coco_test_imgs"
dataset, dataloader, prefetcher = build_test_dataloader(datadir, img_size=640, batch_size=2, num_workers=0)

with tqdm(total=len(dataloader), ncols=50) as t:
    for b, x in enumerate(dataloader):
        for i in range(len(x)):
            info = x['resize_info'][i]
            img = x['img'][i]
            img = img.permute(1, 2, 0)
            img = np.clip(img * 255.0, 0.0, 255.0)
            img_mdy = np.ascontiguousarray(img.numpy().astype('uint8'))
            h, w, _ = img.shape
            fig = plt.figure(figsize=[8, 8])
            plt.imshow(img_mdy)
            plt.show()
    t.update(batch_size)

In [38]:
import torch
a = torch.rand(4)
print(a, b.shape, b, sep='\n')

tensor([0.7524, 0.7114, 0.8233, 0.1799])
torch.Size([2, 2])
tensor([[0.0949, 0.9714],
        [0.9487, 0.6236]])


In [39]:
a.repeat(2, 2, 1)

tensor([[[0.7524, 0.7114, 0.8233, 0.1799],
         [0.7524, 0.7114, 0.8233, 0.1799]],

        [[0.7524, 0.7114, 0.8233, 0.1799],
         [0.7524, 0.7114, 0.8233, 0.1799]]])

In [35]:
c = torch.cat((b, b), dim=-1)
d = c > 0.5

In [36]:
d

tensor([[False,  True, False,  True],
        [ True,  True,  True,  True]])

In [37]:
d.all(dim=-1).any()

tensor(True)