In [2]:
# LSUN scene categories to export. Each name is combined with "_train_lmdb" to
# locate its source database and used as the output sub-directory name.
classes = """
    bedroom bridge church_outdoor classroom conference_room
    dining_room kitchen living_room restaurant tower
""".split()

In [3]:
import io
import os
import pickle

import lmdb
from PIL import Image
from torch.utils.data import Dataset


class LmdbDataset(Dataset):
    """Map-style dataset returning RGB PIL images decoded from an LMDB database.

    Each LMDB value is expected to hold an encoded image; every value is decoded
    with PIL and converted to RGB.

    Args:
        path: directory of the LMDB environment. A ``cached_keys.pkl`` file with
            the list of all keys is written into this directory on first use and
            reused by later constructions.
        max_readers: forwarded to ``lmdb.open``.
        transform: optional callable applied to each decoded PIL image.
        max_images: if not None, caps the dataset length at this many images.
    """

    def __init__(self, path, max_readers=1, transform=None, max_images=None):
        super(LmdbDataset, self).__init__()
        self.path = path
        self.transform = transform
        self.max_readers = max_readers
        self.max_images = max_images

        # This environment is only needed to enumerate keys, so close it right
        # away instead of leaking the handle. Reads open their own environment
        # lazily in __getitem__ (presumably so that forked DataLoader workers do
        # not share a handle).
        env = self.init_env()
        try:
            with env.begin(write=False) as txn:
                self.keys = self._load_or_build_keys(txn)
                assert len(self.keys) == txn.stat()['entries']
        finally:
            env.close()

    def _load_or_build_keys(self, txn):
        """Return all database keys, reading or creating the on-disk key cache."""
        cache_path = os.path.join(self.path, "cached_keys.pkl")
        if os.path.exists(cache_path):
            # NOTE: pickle.load can execute arbitrary code; only point this class
            # at trusted dataset directories.
            with open(cache_path, 'rb') as fp:
                return pickle.load(fp)

        # Enumerate keys only: plain cursor iteration would also copy every value
        # (the encoded image) into Python memory just to discard it.
        keys = list(txn.cursor().iternext(keys=True, values=False))

        # Write to a temporary file and rename it into place, so an interrupted
        # run cannot leave a truncated cache behind for later runs to load.
        tmp_path = cache_path + ".tmp"
        with open(tmp_path, 'wb') as fp:
            pickle.dump(keys, fp)
        os.replace(tmp_path, cache_path)
        return keys

    def init_env(self):
        """Open the LMDB environment at ``self.path`` read-only."""
        return lmdb.open(self.path, max_readers=self.max_readers, readonly=True, lock=False, readahead=False,
                         meminit=False)

    def __getitem__(self, index):
        """Return the (optionally transformed) RGB image at ``index``.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``; this
                also respects ``max_images``.
            KeyError: if the key is not present in the database.
        """
        length = len(self)
        if index < 0:
            index += length
        if not 0 <= index < length:
            raise IndexError(f"index out of range for dataset of length {length}")

        if not hasattr(self, 'env'):
            self.env = self.init_env()

        key = self.keys[index]
        with self.env.begin(write=False) as txn:
            imgbuf = txn.get(key)
        if imgbuf is None:
            raise KeyError(key)

        # Decode from an in-memory buffer and normalize to 3-channel RGB.
        img = Image.open(io.BytesIO(imgbuf)).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img

    def __len__(self):
        """Number of keys, capped at ``max_images`` when it is set."""
        if self.max_images is not None:
            return min(len(self.keys), self.max_images)
        return len(self.keys)

In [14]:
import PIL
import os
from tqdm import tqdm
from torchvision import transforms

# Source LMDBs are read from data_dir; JPEGs are written to scenes_1M_dir/<class>/.
# Both default to the original scratch locations and can be overridden through
# environment variables instead of editing this cell.
data_dir = os.environ.get("LSUN_DATA_DIR", "/tmp/skoroki/data/lsun")
scenes_1M_dir = os.environ.get("LSUN_SCENES_1M_DIR", os.path.join(data_dir, "scenes-1M"))
NUM_IMGS = 100_000  # images exported per class
IMG_SIZE = 256      # side length (pixels) of the square output images

# Resize so the smaller edge is IMG_SIZE (Lanczos filter), then center-crop to
# an IMG_SIZE x IMG_SIZE square.
transform = transforms.Compose([
    transforms.Resize(IMG_SIZE, interpolation=PIL.Image.LANCZOS),
    transforms.CenterCrop(IMG_SIZE),
])

for class_name in classes:
    source_dir = os.path.join(data_dir, f"{class_name}_train_lmdb")
    target_dir = os.path.join(scenes_1M_dir, class_name)
    os.makedirs(target_dir, exist_ok=True)
    dataset = LmdbDataset(source_dir, max_readers=16, transform=transform)

    # Explicit check rather than `assert`, which is skipped under `python -O`.
    num_available = len(dataset)
    if num_available < NUM_IMGS:
        raise ValueError(
            f"Not enough images for {class_name!r}: found {num_available}, need {NUM_IMGS}")

    # Export the first NUM_IMGS images (in the dataset's key order) as
    # zero-padded JPEG files. NOTE(review): PIL's default JPEG quality is used;
    # pass `quality=...` to save() if higher fidelity is required.
    for i in tqdm(range(NUM_IMGS), desc=class_name):
        target_file = os.path.join(target_dir, f"{i:07d}.jpg")
        dataset[i].save(target_file)

100%|██████████| 100000/100000 [04:40<00:00, 356.00it/s]
100%|██████████| 100000/100000 [05:03<00:00, 330.00it/s]
100%|██████████| 100000/100000 [05:00<00:00, 332.37it/s]
100%|██████████| 100000/100000 [05:04<00:00, 328.68it/s]
100%|██████████| 100000/100000 [04:55<00:00, 338.85it/s]
100%|██████████| 100000/100000 [04:52<00:00, 341.43it/s]
100%|██████████| 100000/100000 [04:45<00:00, 350.83it/s]
100%|██████████| 100000/100000 [04:53<00:00, 341.10it/s]
100%|██████████| 100000/100000 [05:11<00:00, 320.89it/s]
100%|██████████| 100000/100000 [04:46<00:00, 348.76it/s]
