In [2]:
import numbers
import os
import queue as Queue
import threading

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from prefetch_generator import BackgroundGenerator

In [4]:
import os
import time
import numpy as np

import albumentations as A
import albumentations.pytorch
import kornia as K
import torch.nn as nn
import torch.utils.data
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader

from set_loader import CustomDataset, CustomAlbDataset, CustomKorDataset

import cv2

# Set OpenCV's thread count to 0 and switch off its OpenCL usage for this process.
# NOTE(review): presumably done so OpenCV does not compete with the DataLoader
# worker processes used later in this notebook -- confirm.
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)

In [3]:
class DataLoaderX(DataLoader):
    """DataLoader that prefetches batches and copies them to a CUDA device ahead of use.

    Batches produced by the regular ``DataLoader`` iterator are pulled on a
    background thread (``BackgroundGenerator``) and moved to ``local_rank`` on
    a dedicated CUDA stream, so the copy of the next batch can overlap with
    work on the current one.

    Args:
        local_rank: CUDA device index the batches are copied to.
        max_prefetch: Maximum number of batches the background thread may hold
            in its queue. Previously ``local_rank`` was passed in this
            position, which made the queue unbounded for ``local_rank=0``.
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``.

    Note:
        The instance is its own iterator (``__iter__`` returns ``self``), so
        only one iteration over a given loader can be active at a time.
    """

    def __init__(self, local_rank, *, max_prefetch=6, **kwargs):
        super(DataLoaderX, self).__init__(**kwargs)
        self.stream = torch.cuda.Stream(local_rank)
        self.local_rank = local_rank
        self.max_prefetch = max_prefetch

    def __iter__(self):
        """Start a new pass over the data and stage the first batch."""
        self.iter = super(DataLoaderX, self).__iter__()
        # Pass the queue depth by keyword: the second positional parameter of
        # BackgroundGenerator is max_prefetch, not a device index.
        self.iter = BackgroundGenerator(self.iter, max_prefetch=self.max_prefetch)
        self.preload()
        return self

    @staticmethod
    def _to_device(batch, device):
        """Return ``batch`` with everything that has a ``.to`` method moved to ``device``.

        Lists, tuples (including namedtuples) and dicts are rebuilt with their
        elements moved recursively; objects without ``.to`` (e.g. strings) are
        returned unchanged.
        """
        mover = getattr(batch, "to", None)
        if callable(mover):
            return mover(device=device, non_blocking=True)
        if isinstance(batch, dict):
            return {key: DataLoaderX._to_device(value, device)
                    for key, value in batch.items()}
        if isinstance(batch, tuple) and hasattr(batch, "_fields"):
            # namedtuple: the constructor takes the fields as separate arguments
            return type(batch)(*(DataLoaderX._to_device(item, device)
                                 for item in batch))
        if isinstance(batch, (list, tuple)):
            return type(batch)(DataLoaderX._to_device(item, device)
                               for item in batch)
        return batch

    def preload(self):
        """Fetch the next batch and enqueue its host-to-device copy on the side stream.

        Sets ``self.batch`` to ``None`` when the underlying iterator is exhausted.
        """
        self.batch = next(self.iter, None)
        if self.batch is None:
            return None
        with torch.cuda.stream(self.stream):
            self.batch = self._to_device(self.batch, self.local_rank)

    def __next__(self):
        """Return the staged batch and start staging the following one.

        Raises:
            StopIteration: when no batch is left.
        """
        # Make the consumer's stream wait for the copy issued on the side stream.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        if batch is None:
            raise StopIteration
        self.preload()
        return batch

In [5]:
# Probability shared by every augmentation below; at 1.0 each one is applied
# to every sample.
p = 1.0

# Augmentation pipeline: 256x256 random crop, horizontal flip, normalisation
# with mean/std 0.5 per channel, 8 cut-out holes of at most 8x8 filled with 0,
# then conversion to a torch tensor.
albumentations_transform = A.Compose(
    transforms=[
        A.RandomCrop(height=256, width=256, p=p),
        A.HorizontalFlip(p=p),
        A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), p=p),
        A.dropout.Cutout(
            num_holes=8,
            max_h_size=8,
            max_w_size=8,
            fill_value=0,
            p=p,
        ),
        A.pytorch.ToTensorV2(),
    ]
)



In [8]:
# Benchmark setup: jpeg4py decoding + albumentations augmentation.
# The dataset location can be overridden with the AFHQ_TRAIN_DIR environment
# variable so the notebook also runs on other machines; the original path is
# kept as the fallback.
root_path = os.environ.get('AFHQ_TRAIN_DIR', '/home/aiteam/tykim/dataset/afhq/train')
custom_ds = CustomAlbDataset(root_path, loader_type='jpeg4py', transform=albumentations_transform)
# local_rank=0 makes DataLoaderX copy every batch to CUDA device 0.
dataloader = DataLoaderX(dataset=custom_ds, batch_size=128, shuffle=False, num_workers=8, local_rank=0)

In [9]:
%%timeit -r 3 -n 3
simple_load_times = []
start_time = time.time()
for image, label in dataloader:
    image = image.cuda()
    label = label.cuda()
    pass
jpeg4py_alb_time = time.time() - start_time
simple_load_times.append(jpeg4py_alb_time)
print(str(simple_load_times) + ' sec')

[7.47588586807251] sec
[7.071631669998169] sec
[8.352059841156006] sec
[6.941926717758179] sec
[30.283140897750854] sec
[7.166435241699219] sec
[7.380136013031006] sec
[7.071781396865845] sec
[13.012507915496826] sec
10.5 s ± 3.08 s per loop (mean ± std. dev. of 3 runs, 3 loops each)
