In [2]:
import os
import lmdb
import torchvision
import multiprocessing

from torch import nn
from tqdm import tqdm
from PIL import Image
from io import BytesIO
from functools import partial
from torch.utils.data import Dataset
from torchvision.transforms import functional as trans_fn

In [2]:
# Dataset locations. The defaults are the original hard-coded paths; set the
# environment variables below to run the notebook on another machine without
# editing this cell.
data_dir = os.environ.get("CELEBA_DATA_DIR", "/home/pervinco/Datasets/CelebA")  # image root handed to ImageFolder
lmdb_save_dir = os.environ.get("CELEBA_LMDB_DIR", data_dir)  # directory the LMDB files are written to
num_worker = 8  # number of worker processes used by the resize pool

In [3]:
# Index the images found under data_dir; prepare() later reads its file list from `imgset.imgs`.
# NOTE(review): ImageFolder presumably expects the images to sit in sub-directories of data_dir — confirm the dataset layout.
imgset = torchvision.datasets.ImageFolder(data_dir)

In [4]:
def resize_and_convert(img, size, quality=100):
    """Resize and center-crop ``img`` to ``size`` and return it JPEG-encoded.

    Args:
        img: image accepted by the torchvision functional transforms and
            providing a PIL-style ``save`` method.
        size: target size forwarded to both ``resize`` and ``center_crop``.
        quality: JPEG quality setting passed to ``save``.

    Returns:
        The encoded JPEG as a ``bytes`` object.
    """
    resized = trans_fn.resize(img, size, Image.LANCZOS)
    cropped = trans_fn.center_crop(resized, size)

    # Encode into memory instead of a file so the bytes can go straight into the database.
    with BytesIO() as stream:
        cropped.save(stream, format='jpeg', quality=quality)
        return stream.getvalue()


def resize_multiple(img, sizes=(8, 16, 32, 64, 128, 256, 512, 1024), quality=100):
    """Encode ``img`` once per entry of ``sizes``.

    Args:
        img: source image, forwarded unchanged to ``resize_and_convert``.
        sizes: iterable of target sizes.
        quality: JPEG quality forwarded to ``resize_and_convert``.

    Returns:
        A list of ``resize_and_convert`` results, in the same order as ``sizes``.
    """
    return [resize_and_convert(img, edge, quality) for edge in sizes]


def resize_worker(img_file, sizes):
    """Load one image file and encode it at every size in ``sizes``.

    Args:
        img_file: ``(index, path)`` pair; the index is handed back untouched so
            the caller can match the result to its input.
        sizes: target sizes forwarded to ``resize_multiple``.

    Returns:
        ``(index, encoded)`` where ``encoded`` is the ``resize_multiple`` result.
    """
    i, file = img_file

    # Close the source file deterministically, even if convert() raises.
    # The converted image is kept and used after the source has been closed.
    with Image.open(file) as src:
        img = src.convert('RGB')

    out = resize_multiple(img, sizes=sizes)

    return i, out

In [5]:
## Alternative ladder that also includes 1024: sizes=(8, 16, 32, 64, 128, 256, 512, 1024)
def prepare(transaction, dataset, n_worker, sizes=(8, 16, 32, 64, 128, 256, 512)):
    """Resize every image of ``dataset`` in parallel and store the results.

    Each image is written once per size under the key
    ``'<size>-<index zero-padded to 5 digits>'``; the number of processed
    images is stored under the key ``'length'``.

    Args:
        transaction: object with a ``put(key, value)`` method that receives
            the encoded images (the notebook passes an LMDB write transaction).
        dataset: object whose ``imgs`` attribute is a sequence of
            ``(img_file_path, label)`` pairs.
        n_worker: number of processes in the multiprocessing pool.
        sizes: target sizes produced for every image.
    """
    ## Bind the `sizes` argument of resize_worker up front so the pool can call it with one argument.
    resize_fn = partial(resize_worker, sizes=sizes)

    ## Sort by path so the index given to each image does not depend on listing order.
    files = sorted(dataset.imgs, key=lambda x: x[0])  ## (img_file_path, label)
    files = [(i, file) for i, (file, _label) in enumerate(files)]  ## (idx, img_file_path)

    total = 0
    with multiprocessing.Pool(n_worker) as pool:  ## Multi Processing
        ## pool.imap_unordered lets the worker processes run resize_fn asynchronously.
        ## Results are yielded as soon as they finish, so their order can change from
        ## run to run; the index returned with each result keeps the keys stable.
        results = pool.imap_unordered(resize_fn, files)

        ## The iterator has no length, so give tqdm the total explicitly;
        ## otherwise the bar cannot show a percentage or an ETA.
        for i, imgs in tqdm(results, total=len(files)):
            for size, img in zip(sizes, imgs):
                ## zfill pads but never truncates, so indices above 99999 just give longer keys.
                key = f'{size}-{str(i).zfill(5)}'.encode('utf-8')
                transaction.put(key, img)

            total += 1

        transaction.put('length'.encode('utf-8'), str(total).encode('utf-8'))

In [6]:
"""
lmdb_save_dir에 lmdb 데이터베이스를 생성.
 - map_size는 데이터베이스의 최대 크기.(1024 ** 4 = 1TiB)
 - readahead는 사전 읽기 기능을 활성화할지 비활성화할지 여부
"""
# Build the database only when lmdb_save_dir holds neither of the two LMDB files yet.
if not (os.path.exists(f"{lmdb_save_dir}/data.mdb") or os.path.exists(f"{lmdb_save_dir}/lock.mdb")):
    # A single write transaction covers the whole import.
    with lmdb.open(lmdb_save_dir, map_size=1024 ** 4, readahead=False) as env, env.begin(write=True) as txn:
        prepare(txn, imgset, num_worker)

In [None]:
class MultiResolutionDataset(Dataset):
    """Read-only dataset serving the images of one resolution from an LMDB store.

    Images are looked up under the key
    ``'<resolution>-<index zero-padded to 5 digits>'`` and the number of
    samples is read from the ``'length'`` record.

    Args:
        path: location of the LMDB environment to open.
        transform: callable applied to every decoded image before it is returned.
        resolution: resolution prefix used when building record keys.

    Raises:
        IOError: if the environment cannot be opened or has no ``'length'`` record.
    """

    def __init__(self, path, transform, resolution=8):
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)

        with self.env.begin(write=False) as txn:
            raw_length = txn.get('length'.encode('utf-8'))

        # A store without the 'length' record would otherwise fail with an
        # opaque AttributeError on None; report the real problem instead.
        if raw_length is None:
            self.env.close()
            raise IOError('lmdb dataset has no "length" record', path)

        self.length = int(raw_length.decode('utf-8'))
        self.resolution = resolution
        self.transform = transform

    def __len__(self):
        """Return the number of samples recorded in the store."""
        return self.length

    def __getitem__(self, index):
        """Return the transformed image stored at ``index``.

        Negative indices count from the end, as for sequences.

        Raises:
            IndexError: if ``index`` is outside ``[-len, len)``. This is also
                what lets plain iteration over the dataset terminate.
            IOError: if no record exists for this index at the current resolution.
        """
        # Build a new value instead of using `+=` so a mutable index object is never modified.
        position = index + self.length if index < 0 else index
        if not 0 <= position < self.length:
            raise IndexError(
                f'index {index} is out of range for dataset of length {self.length}'
            )

        with self.env.begin(write=False) as txn:
            key = f'{self.resolution}-{str(position).zfill(5)}'.encode('utf-8')
            img_bytes = txn.get(key)

        # A missing record would otherwise surface as a confusing decode error.
        if img_bytes is None:
            raise IOError(
                f'no image stored under key {key!r}; '
                f'was resolution {self.resolution} written to this lmdb dataset?'
            )

        buffer = BytesIO(img_bytes)
        img = Image.open(buffer)
        img = self.transform(img)

        return img