In [1]:
import glob
import logging
import os
import traceback
from itertools import tee
from typing import Any, Generator, List, Optional

import cv2
import fire
import jpeg4py as jpeg
import lmdb
import lz4framed
import numpy as np
import pandas as pd
import pyarrow
from PIL import Image
from tqdm import tqdm

In [2]:
# Dataset root (relative to the notebook's working directory) and the CSV that
# records the order in which images were inserted into the LMDB store.
DATA_DIRECTORY = '.././dataset/'
IMAGE_NAMES_FILE = 'image_names.csv'

# Every log record carries a timestamp, its source location and its level.
logging.basicConfig(
    format='[%(asctime)s] [%(pathname)s:%(lineno)d] %(levelname)s - %(message)s',
    datefmt='%H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

In [89]:
# Sanity check: the relative dataset path must resolve from the notebook's working directory.
os.path.exists(DATA_DIRECTORY)

True

In [94]:
# NOTE(review): this opens a writable LMDB environment at ./data.lmdb that no visible
# cell uses (convert() opens its own local `lmdb_connection`) and that is never closed.
# With readonly=False it likely creates ./data.lmdb as a side effect. Consider deleting
# this cell — TODO confirm nothing outside this view depends on it.
lmdb_connection = lmdb.open("./data.lmdb", subdir=False,
                                map_size=int(2e11), readonly=False,
                                meminit=False, map_async=True)

In [88]:




def list_files_in_folder(folder_path: str) -> Generator:
    """Yield the path of every entry in ``folder_path``, in sorted name order.

    Paths are built with ``os.path.join``, so ``folder_path`` works with or
    without a trailing separator. ``os.listdir`` returns entries in arbitrary
    order. Sorting makes the iteration order, and so any index assigned from
    it, reproducible across runs and machines.

    Note: subdirectories are yielded as well; they are not filtered out.
    """
    return (os.path.join(folder_path, file_name) for file_name in sorted(os.listdir(folder_path)))


def read_image_safely(image_file_name: str) -> np.ndarray:
    """Read an image file as an RGB uint8 array, never raising.

    Args:
        image_file_name: path of the image to read.

    Returns:
        The decoded pixels converted to RGB, or an empty uint8 array if the
        file could not be opened or decoded. The failure is logged as a
        warning rather than silently swallowed.
    """
    try:
        # Context manager so the file handle is released even if decoding fails.
        with Image.open(image_file_name) as image:
            return np.array(image.convert("RGB"), dtype=np.uint8)
    except Exception as exc:
        # Deliberate best-effort: callers store the empty array and skip it later.
        logging.getLogger(__name__).warning("Could not read image %s: %s", image_file_name, exc)
        return np.array([], dtype=np.uint8)


def serialize_and_compress(obj: Any):
    """Serialize ``obj`` with pyarrow, then compress the bytes with lz4 framing.

    Decoding is ``pyarrow.deserialize(lz4framed.decompress(value))``.
    NOTE(review): ``pyarrow.serialize`` is deprecated in newer pyarrow releases;
    pin pyarrow or migrate the format (the reader must change in lockstep).
    """
    serialized_buffer = pyarrow.serialize(obj).to_buffer()
    compressed_bytes = lz4framed.compress(serialized_buffer)
    return compressed_bytes


def extract_image_name(image_path: str) -> str:
    """Return the final path component (the file name) of ``image_path``.

    Uses ``os.path.basename``, so paths built with the platform separator
    (e.g. ``os.path.join`` on Windows) are handled, not only ``/``.
    A path ending in a separator yields ``''``.
    """
    return os.path.basename(image_path)


def resize(image_array, size=(256, 256)):
    """Resize an image array to ``size`` with bicubic interpolation.

    ``size`` is passed to ``cv2.resize`` as ``dsize``, which OpenCV interprets
    as (width, height). Empty arrays (the failure value of
    ``read_image_safely``) are returned unchanged.
    NOTE(review): ``convert`` does not call this; images are stored at their original size.
    """
    if not image_array.size:
        return image_array
    resized_array = cv2.resize(image_array, dsize=size, interpolation=cv2.INTER_CUBIC)
    return resized_array


def _write_image_records(lmdb_connection, image_paths: List[str], write_freq: int) -> int:
    """Store one record per path in ``image_paths`` and return how many were written.

    Record ``i`` is keyed by ``str(i)`` (ASCII bytes). Its value is
    ``serialize_and_compress((name_bytes, raw_pixel_bytes, shape))``.
    Puts are grouped into write transactions of at most ``write_freq`` records.
    Each transaction is used as a context manager, so it is committed when the
    batch succeeds and aborted if an exception escapes (e.g. MDB_MAP_FULL).
    """
    total_records = 0
    with tqdm(total=len(image_paths)) as progress_bar:
        for batch_start in range(0, len(image_paths), write_freq):
            batch_paths = image_paths[batch_start:batch_start + write_freq]
            with lmdb_connection.begin(write=True) as lmdb_txn:
                for idx, img_path in enumerate(batch_paths, start=batch_start):
                    # Unreadable files come back as empty arrays and are still stored,
                    # so record indices stay aligned with `image_paths`.
                    img_arr = read_image_safely(img_path)
                    img_idx = str(idx).encode('ascii')
                    # utf-8 (not ascii) so a non-ASCII file name cannot abort the run.
                    img_name = extract_image_name(image_path=img_path).encode('utf-8')
                    if idx < 5:
                        logger.info("Sample record %d: %s size=%d shape=%s",
                                    idx, img_name, img_arr.size, img_arr.shape)
                    lmdb_txn.put(img_idx, serialize_and_compress((img_name, img_arr.tobytes(), img_arr.shape)))
                    total_records += 1
                    progress_bar.update(1)
    return total_records


def convert(image_folder: str, lmdb_output_path: str, write_freq: int = 5000,
            map_size: int = int(1e11), image_names_path: Optional[str] = None):
    """Pack every entry of ``image_folder`` into a new single-file LMDB store.

    Each image is read as an RGB uint8 array. It is stored under its insertion
    index as the lz4-compressed pyarrow serialization of
    ``(name_bytes, raw_pixel_bytes, shape)``. A ``b'__keys__'`` record holding
    the list of all record keys is written last. The file names, in insertion
    order, are then written to a headerless CSV.

    Args:
        image_folder: directory containing the images.
        lmdb_output_path: path of the LMDB data file to create; must not exist.
        write_freq: number of records per write transaction.
        map_size: maximum LMDB map size in bytes. A run with the old fixed
            1e11 stopped with MDB_MAP_FULL after ~242k images; raise it for
            larger datasets.
        image_names_path: destination of the image-names CSV; defaults to
            ``os.path.join(DATA_DIRECTORY, IMAGE_NAMES_FILE)``.

    Raises:
        AssertionError: on invalid arguments.
        Exception: any LMDB / I/O error is logged with its traceback and
            re-raised after the environment has been closed; the names CSV is
            not written in that case.
    """
    assert os.path.isdir(image_folder), f"Image folder '{image_folder}' does not exist"
    assert not os.path.isfile(lmdb_output_path), f"LMDB store '{lmdb_output_path}' already exists"
    assert not os.path.isdir(lmdb_output_path), f"LMDB store name should be a file, found directory: {lmdb_output_path}"
    assert write_freq > 0, f"Write frequency should be a positive number, found {write_freq}"
    if image_names_path is None:
        image_names_path = os.path.join(DATA_DIRECTORY, IMAGE_NAMES_FILE)

    logger.info(f"Creating LMDB store: {lmdb_output_path}")

    # List the folder once: the same sequence drives the record keys, the
    # progress-bar total and the image-names CSV, so they cannot diverge.
    image_paths = list(list_files_in_folder(image_folder))

    lmdb_connection = lmdb.open(lmdb_output_path, subdir=False,
                                map_size=map_size, readonly=False,
                                meminit=False, map_async=True)
    try:
        total_records = _write_image_records(lmdb_connection, image_paths, write_freq)
        logger.info("Finished writing image data. Total records: {}".format(total_records))

        logger.info("Writing store metadata")
        image_keys__list = [str(k).encode('ascii') for k in range(total_records)]
        with lmdb_connection.begin(write=True) as lmdb_txn:
            lmdb_txn.put(b'__keys__', serialize_and_compress(image_keys__list))

        logger.info("Flushing data buffers to disk")
        lmdb_connection.sync()
    except Exception:
        logger.exception("Failed while writing LMDB store %s", lmdb_output_path)
        raise
    finally:
        # Always release the environment, whatever failed.
        lmdb_connection.close()

    # -- store the order in which files were inserted into LMDB store -- #
    pd.Series(image_paths, dtype=object).apply(extract_image_name).to_csv(
        image_names_path, index=False, header=False)
    logger.info("Finished creating LMDB store")


In [9]:
# Folder of source images to pack into LMDB (the sample output below shows 256x256 RGB images).
datadir=".././dataset/256/"

In [10]:
# Pack every image under `datadir` into a single-file LMDB store.
# NOTE(review): the recorded run below stopped with MDB_MAP_FULL after ~242k images, so the
# store it left behind has no __keys__ record. Delete the partial file before re-running
# (convert() refuses to overwrite an existing store) and use a larger LMDB map size.
convert(datadir,DATA_DIRECTORY+"lmdb-store.db")

[03:20:59] [<ipython-input-8-01c279a695e3>:32] INFO - Creating LMDB store: .././dataset/lmdb-store.db
0it [00:00, ?it/s]

0 b'000002b66c9c498e_resized.jpg' 196608 (256, 256, 3)


1it [00:00,  6.29it/s]

1 b'000002b97e5471a0_resized.jpg' 196608 (256, 256, 3)
2 b'000002c707c9895e_resized.jpg' 196608 (256, 256, 3)
3 b'0000048549557964_resized.jpg' 196608 (256, 256, 3)
4 b'000004f4400f6ec5_resized.jpg' 196608 (256, 256, 3)


241978it [42:56, 93.93it/s] 


MapFullError: mdb_put: MDB_MAP_FULL: Environment mapsize limit reached

In [40]:
# lmdbloader.py

import os
import lmdb
import pyarrow
import lz4framed
import numpy as np
from typing import Any
import nonechucks as nc
from torch.utils.data import Dataset, DataLoader


class InvalidFileException(Exception):
    """Raised by LMDBDataset.__getitem__ when a stored record decodes to an empty image."""


class LMDBDataset(Dataset):
    """Map-style dataset reading images from a single-file LMDB store.

    The store must contain a ``b'__keys__'`` record (lz4-compressed,
    pyarrow-serialized list of record keys). It must also hold one record per
    key whose decoded value is ``(name, raw_bytes, shape)``; ``raw_bytes`` is
    reinterpreted as uint8 and reshaped to ``shape``.

    The LMDB environment is opened lazily, re-opened whenever the current
    process id differs from the one that opened it, and dropped when the
    dataset is pickled. DataLoader worker processes (forked or spawned)
    therefore each get their own handle instead of sharing the parent's.

    Args:
        lmdb_store_path: path to the LMDB data file (``subdir=False`` layout).
        transform: optional callable applied to each decoded image array.
    """

    def __init__(self, lmdb_store_path, transform=None):
        super().__init__()
        # Check for a directory first so it is not misreported as "does not exist".
        assert not os.path.isdir(lmdb_store_path), \
            f"LMDB store name should be a file, found directory: {lmdb_store_path}"
        assert os.path.isfile(lmdb_store_path), f"LMDB store '{lmdb_store_path}' does not exist"

        self.lmdb_store_path = lmdb_store_path
        self.transform = transform
        # Environment handle plus the pid that opened it; see `lmdb_connection`.
        self._lmdb_connection = None
        self._connection_pid = None

        with self.lmdb_connection.begin(write=False) as lmdb_txn:
            raw_keys = lmdb_txn.get(b'__keys__')
        assert raw_keys is not None, \
            f"LMDB store '{lmdb_store_path}' has no __keys__ record (was it fully written?)"
        self.keys = LMDBDataset.decompress_and_deserialize(lmdb_value=raw_keys)
        # Length comes from the key list itself so __len__ and __getitem__ always agree.
        self.length = len(self.keys)
        print(f"Total records: {self.length}")

    @property
    def lmdb_connection(self):
        """Read-only LMDB environment owned by the current process, opened on first use."""
        if self._lmdb_connection is None or self._connection_pid != os.getpid():
            self._lmdb_connection = lmdb.open(self.lmdb_store_path, subdir=False, readonly=True,
                                              lock=False, readahead=False, meminit=False)
            self._connection_pid = os.getpid()
        return self._lmdb_connection

    def __getstate__(self):
        """Pickle without the environment handle; the receiving process reopens it lazily."""
        state = self.__dict__.copy()
        state['_lmdb_connection'] = None
        state['_connection_pid'] = None
        return state

    def __getitem__(self, index):
        """Decode record ``self.keys[index]`` into an image array (after ``transform``, if set).

        Raises:
            AssertionError: if the key has no record in the store.
            InvalidFileException: if the stored image is empty (unreadable at conversion time).
        """
        key = self.keys[index]
        with self.lmdb_connection.begin(write=False) as txn:
            lmdb_value = txn.get(key)
        assert lmdb_value is not None, f"Read empty record for key: {key}"

        _img_name, img_bytes, img_shape = LMDBDataset.decompress_and_deserialize(lmdb_value=lmdb_value)
        image = np.frombuffer(img_bytes, dtype=np.uint8).reshape(img_shape)
        if image.size == 0:
            raise InvalidFileException("Invalid file found, skipping")
        # frombuffer views immutable bytes, so the array is read-only; copy so
        # transforms and torch tensor conversion get a writable array.
        image = image.copy()
        if self.transform is not None:
            image = self.transform(image)
        return image

    @staticmethod
    def decompress_and_deserialize(lmdb_value: Any):
        """Invert the writer's encoding: lz4-decompress, then pyarrow-deserialize."""
        return pyarrow.deserialize(lz4framed.decompress(lmdb_value))

    def __len__(self):
        return self.length







In [3]:
# Open the store written by convert() for reading.
# NOTE(review): this cell's execution count (In [3]) is lower than that of the cell defining
# LMDBDataset (In [40]), so the saved state comes from out-of-order execution; re-run top to bottom.
# If the store is the partial one left by the failed convert() run above, it has no __keys__ record.
dataset = LMDBDataset(DATA_DIRECTORY+"lmdb-store.db")

In [51]:
# Shuffled batches of 16 images, loaded by a single DataLoader worker (num_workers=1).
data_loader = DataLoader(dataset, shuffle=True, batch_size=16, num_workers=1, pin_memory=False)

In [1]:
# dataset[100]

In [2]:
# Smoke test: one full pass over the loader, asserting that no batch comes back empty.
for _ in range(1):
    for batch in data_loader:
        assert len(batch) > 0

In [None]:
if __name__ == "__main__":
    # Long-running smoke test: iterate over a whole store for several epochs.
    # nc.SafeDataset wraps LMDBDataset, presumably so records that raise (e.g.
    # InvalidFileException for empty images) are skipped — confirm against nonechucks.
    # NOTE(review): this reads './data/lmdb-tmp.db', not the store convert() wrote above.
    dataset = nc.SafeDataset(LMDBDataset('./data/lmdb-tmp.db'))
    batch_size = 64
    data_loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=4, pin_memory=False)
    n_epochs = 50

    for _ in range(n_epochs):
        for batch in data_loader:
            assert len(batch) > 0