Skip to content

Commit

Permalink
allow customzied decoding func; don't assume __len__
Browse files Browse the repository at this point in the history
  • Loading branch information
w-hc committed Aug 1, 2020
1 parent 7781ecc commit ce24638
Showing 1 changed file with 27 additions and 10 deletions.
37 changes: 27 additions & 10 deletions fabric/io/lmdb_tools.py
Expand Up @@ -30,7 +30,7 @@
from dataflow.utils import logger # TODO: add a consistent logger for fabric itself

dumps = lambda x: pickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL)
loads = pickle.loads
# loads = pickle.loads

__all__ = ['save_to_lmdb', 'LMDBData', 'ImageLMDB']

Expand Down Expand Up @@ -123,13 +123,33 @@ class LMDBData():
https://github.com/pytorch/vision/issues/689 provides solution on how to
deal with the un-picklable db Environment.
"""
def __init__(self, db_fname, readahead=False):
def __init__(self, db_fname, readahead=False, decoding_func=pickle.loads):
self.db_fname = str(db_fname)
self.readahead = readahead
self.loads = decoding_func
# disabling readahead improves random read performance

self.read_txn = self.make_read_transaction()
self.length = self._retrieve_item(b'__len__')

# attempt to retrieve stored db size
try:
length = self._retrieve_item(b'__len__')
except KeyError as e:
print(f"{e}")
length = None

self.length = length

def __len__(self):
if self.length is None:
raise ValueError("db length is unknown")
return self.length

def keys(self):
keys = self._retrieve_item(b'__keys__')
if self.length is None:
self.length = len(keys)
return keys

def make_read_transaction(self):
db = lmdb.open(
Expand Down Expand Up @@ -162,18 +182,15 @@ def __getitem__(self, key):
def _retrieve_item(self, key):
"""this method is private and not user-facing, not customizable"""
assert isinstance(key, (bytes, str)), f"{key} is not string or bytes"
orig_key = key
if isinstance(key, str):
key = key.encode("ascii")
res = self.read_txn.get(key)
res = loads(res)
if res is None:
raise KeyError(f"key {orig_key} is not present in the lmdb")
res = self.loads(res)
return res

def __len__(self):
return self.length

def keys(self):
return self._retrieve_item(b'__keys__')


class ImageLMDB(LMDBData):
"""
Expand Down

0 comments on commit ce24638

Please sign in to comment.