Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ with open("data.im", "wb") as f:
# You can skip computing or checking CRCs, e.g. if your
# embedded object already contains CRCs
mb = MapBuffer(..., check_crc=False, compute_crc=False)

# If your access pattern is such that the index and the
# download are similar in size (e.g. watershed meshes)
# you can cache the index.
mb = MapBuffer(..., index_cache="/tmp/helloworld.mbi")
```

## Installation
Expand Down
93 changes: 93 additions & 0 deletions automated_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
import mmap
import os
import random
from unittest.mock import patch

import numpy as np

from mapbuffer import ValidationError, IntMap, MapBuffer, HEADER_LENGTH

CACHE_PATH = "./test_index_cache.mbi"

@pytest.mark.parametrize("compress", (None, "gzip", "br", "zstd", "lzma"))
def test_empty(compress):
mbuf = MapBuffer({}, compress=compress)
Expand Down Expand Up @@ -248,6 +251,96 @@ def test_set_object_intmap():
except KeyError:
pass

@pytest.fixture(autouse=True)
def cleanup_cache():
    """Delete the shared index cache file before and after every test."""
    def _wipe():
        if os.path.exists(CACHE_PATH):
            os.remove(CACHE_PATH)

    _wipe()
    yield
    _wipe()


def make_mapbuffer(data=None, **kwargs):
    """Build a MapBuffer wired to the shared test cache path.

    data: optional int->bytes mapping; defaults to a small two-entry dict.
        An explicit ``is None`` check is used (instead of ``data or ...``)
        so a caller can pass an empty dict without it being silently
        replaced by the default.
    kwargs: forwarded to the MapBuffer constructor.
    """
    if data is None:
        data = {1: b"hello", 2: b"world"}
    return MapBuffer(data, index_cache=CACHE_PATH, **kwargs)


def test_index_cache_file_is_created():
    """Accessing the index for the first time should write the cache file."""
    buf = make_mapbuffer()
    buf.index()
    assert os.path.exists(CACHE_PATH)


def test_index_cache_header_and_index_written():
    """The cache file must contain the header followed by every index byte."""
    buf = make_mapbuffer()
    idx = buf.index()

    with open(CACHE_PATH, "rb") as fh:
        on_disk = fh.read()

    expected_size = HEADER_LENGTH + idx.nbytes
    assert len(on_disk) == expected_size


def test_index_cache_is_loaded_from_disk():
    """With a populated cache, a new MapBuffer reads the index off disk, not the buffer."""
    first = make_mapbuffer()
    expected = first.index().copy()

    # Second instance: the cache file now exists, so the index
    # should be served from disk rather than the primary buffer.
    second = make_mapbuffer()
    second._index = None  # discard any precomputed index

    with patch.object(np, "frombuffer", wraps=np.frombuffer) as spy:
        loaded = second.index()
        # No frombuffer call may target the index region of the main buffer.
        for args, kwargs in spy.call_args_list:
            assert kwargs.get("offset") != HEADER_LENGTH, \
                "Index was re-read from buffer instead of cache"

    np.testing.assert_array_equal(loaded, expected)


def test_index_cache_values_correct():
    """Lookups through a cache-backed buffer must match an uncached one."""
    cached = make_mapbuffer()
    plain = MapBuffer({1: b"hello", 2: b"world"})

    for key in (1, 2):
        assert cached[key] == plain[key]


def test_crc_error_raised_despite_cache():
    """Corrupted payload bytes must still fail CRC validation when a cache exists."""
    payload = {1: b"hello", 2: b"world"}
    buf = make_mapbuffer(payload)
    buf.index()  # writes the cache file

    # Flip one byte inside the stored value region of the buffer.
    raw = bytearray(buf.buffer)
    pos = bytes(raw).index(b"hello")
    raw[pos] = ord(b"H")
    buf.buffer = bytes(raw)
    buf._index = None  # force a re-read: index comes from cache, data stays corrupt

    with pytest.raises(ValidationError):
        buf[1]


def test_index_cache_not_rewritten_if_already_complete():
    """Cache file should not be overwritten on second load.

    Compares nanosecond mtimes (``st_mtime_ns``) instead of the float
    ``os.path.getmtime``: float mtime granularity can be coarser than two
    back-to-back accesses on some filesystems, which would let a rewrite
    go undetected. The raw bytes are also compared as a
    resolution-independent check.
    """
    mbuf = make_mapbuffer()
    mbuf.index()

    mtime_after_first = os.stat(CACHE_PATH).st_mtime_ns
    with open(CACHE_PATH, "rb") as f:
        bytes_after_first = f.read()

    mbuf2 = make_mapbuffer()
    mbuf2.index()

    mtime_after_second = os.stat(CACHE_PATH).st_mtime_ns
    with open(CACHE_PATH, "rb") as f:
        bytes_after_second = f.read()

    assert mtime_after_first == mtime_after_second, \
        "Cache file was unexpectedly rewritten on second access"
    assert bytes_after_first == bytes_after_second, \
        "Cache file contents changed on second access"
52 changes: 50 additions & 2 deletions mapbuffer/mapbuffer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional, Any, Union, Literal
from collections.abc import Callable
import os

import mmap
import io
Expand All @@ -9,6 +10,7 @@
from . import compression

import crc32c
import fasteners
import numpy as np

import mapbufferaccel
Expand All @@ -21,8 +23,8 @@ class MapBuffer:
"""Represents a usable int->bytes dictionary as a byte string."""
__slots__ = (
"data", "tobytesfn", "frombytesfn",
"dtype", "buffer", "check_crc", "compute_crc",
"_header", "_index", "_compress"
"dtype", "buffer", "check_crc", "compute_crc", "index_cache",
"_header", "_index", "_compress", "_lock"
)
def __init__(
self,
Expand All @@ -32,6 +34,7 @@ def __init__(
frombytesfn:Optional[Callable[[bytes], Any]] = None,
check_crc:bool = True,
compute_crc:bool = True,
index_cache:Optional[str] = None,
):
"""
data: dict (int->byte serializable object) or bytes
Expand All @@ -52,10 +55,14 @@ def __init__(
self.buffer = None
self.check_crc = check_crc
self.compute_crc = compute_crc
self.index_cache = index_cache

self._header = None
self._index = None
self._compress = None
self._lock = None
if self.index_cache is not None:
self._lock = fasteners.InterProcessReaderWriterLock(self.index_cache)

if isinstance(data, dict):
self.buffer = self.dict2buf(data, compress)
Expand Down Expand Up @@ -102,9 +109,29 @@ def header(self):
if self._header is not None:
return self._header

if self.index_cache is not None:
if os.path.exists(self.index_cache):
with self._lock.read_lock():
with open(self.index_cache, "rb") as f:
self._header = f.read(HEADER_LENGTH)

if len(self._header) == HEADER_LENGTH:
return self._header

# seems dumb, but if self.buffer is an object that
# requires network access, this is a valuable cache
self._header = self.buffer[:HEADER_LENGTH]

if self.index_cache is not None:
with self._lock.write_lock():
try:
if os.path.getsize(self.index_cache) < HEADER_LENGTH:
with open(self.index_cache, "wb") as f:
f.write(self._header)
except FileNotFoundError:
with open(self.index_cache, "wb") as f:
f.write(self._header)

return self._header

def index(self):
Expand All @@ -115,6 +142,18 @@ def index(self):
N = len(self)
index_length = 2 * N

if self.index_cache is not None:
try:
if os.path.getsize(self.index_cache) > HEADER_LENGTH:
with self._lock.read_lock():
with open(self.index_cache, "rb") as f:
f.seek(HEADER_LENGTH)
index = f.read(index_length * 8)
self._index = np.frombuffer(index, dtype=np.uint64).reshape((N,2))
return self._index
except FileNotFoundError:
pass

if isinstance(self.buffer, (bytes,bytearray,np.ndarray,mmap.mmap)):
self._index = np.frombuffer(
self.buffer,
Expand All @@ -127,6 +166,15 @@ def index(self):
index = self.buffer[HEADER_LENGTH:index_length+HEADER_LENGTH]
self._index = np.frombuffer(index, dtype=np.uint64).reshape((N,2))

if self.index_cache is not None:
try:
if os.path.getsize(self.index_cache) == HEADER_LENGTH:
with self._lock.write_lock():
with open(self.index_cache, "ab") as f:
f.write(self._index.tobytes('C'))
except FileNotFoundError:
pass

return self._index

def keys(self):
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
brotli
crc32c
deflate>=0.2.0
fasteners
numpy
tqdm
zstandard