Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 152 additions & 21 deletions sdk/python/src/flamepy/core/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from collections import OrderedDict
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

import bson
import cloudpickle
Expand All @@ -26,6 +26,28 @@

from flamepy.core.types import FlameClientCache, FlameClientTls, FlameContext

if TYPE_CHECKING:
import numpy as np

# Every serialized payload starts with a 4-byte type marker: the 3-byte magic
# b"FLM" followed by a single format byte. The magic is chosen so it does not
# collide with pickle headers (\x80 for protocol 2+, or opcodes such as \x28
# and \x5d), which lets readers tell marked payloads from raw pickles.
_MAGIC_PREFIX = b"FLM"
_TYPE_CLOUDPICKLE = _MAGIC_PREFIX + b"\x00"
_TYPE_NUMPY = _MAGIC_PREFIX + b"\x01"
_TYPE_ARROW_TABLE = _MAGIC_PREFIX + b"\x02"
_TYPE_ARROW_ARRAY = _MAGIC_PREFIX + b"\x03"
_TYPE_ARROW_BATCH = _MAGIC_PREFIX + b"\x04"
# Length of a full marker (magic + format byte): 4 bytes.
_MAGIC_PREFIX_LEN = len(_MAGIC_PREFIX) + 1

try:
import numpy as np

_HAS_NUMPY = True
except ImportError:
np = None # type: ignore[assignment]
_HAS_NUMPY = False

logger = logging.getLogger(__name__)

Deserializer = Callable[[Any, List[Any]], Any]
Expand Down Expand Up @@ -195,50 +217,159 @@ def __str__(self) -> str:
return self.to_key() or self.to_prefix()


def _serialize_numpy(arr: "np.ndarray") -> bytes:
    """Encode a numpy array as the numpy type marker plus an Arrow IPC tensor.

    Args:
        arr: Array to encode; it is handed to ``pa.Tensor.from_numpy`` as is.

    Returns:
        ``_TYPE_NUMPY`` followed by the bytes written by ``pa.ipc.write_tensor``.
    """
    as_tensor = pa.Tensor.from_numpy(arr)
    out = pa.BufferOutputStream()
    # The marker goes first so the reader can pick the right decoder.
    out.write(_TYPE_NUMPY)
    pa.ipc.write_tensor(as_tensor, out)
    encoded = out.getvalue()
    return encoded.to_pybytes()


def _deserialize_numpy(data: bytes) -> "np.ndarray":
    """Decode an Arrow IPC tensor payload back into a numpy array.

    Args:
        data: Tensor payload with the type marker already stripped.

    Returns:
        The array produced by ``Tensor.to_numpy()``.
    """
    # NOTE(review): the result may be a read-only view backed by `data` rather
    # than a fresh writable array -- confirm callers never mutate it in place.
    return pa.ipc.read_tensor(pa.BufferReader(data)).to_numpy()


def _serialize_arrow_table(table: pa.Table) -> bytes:
    """Encode a PyArrow Table as the table type marker plus an IPC stream.

    Args:
        table: Table to encode.

    Returns:
        ``_TYPE_ARROW_TABLE`` followed by the IPC stream bytes of ``table``.
    """
    out = pa.BufferOutputStream()
    out.write(_TYPE_ARROW_TABLE)
    # Leaving the `with` block closes the writer before the buffer is read.
    with pa.ipc.new_stream(out, table.schema) as stream_writer:
        stream_writer.write_table(table)
    encoded = out.getvalue()
    return encoded.to_pybytes()


def _deserialize_arrow_table(data: bytes) -> pa.Table:
    """Decode an IPC stream payload into a PyArrow Table.

    Args:
        data: IPC stream bytes with the type marker already stripped.

    Returns:
        A Table holding every batch read from the stream.
    """
    source = pa.BufferReader(data)
    return pa.ipc.open_stream(source).read_all()


def _serialize_arrow_batch(batch: pa.RecordBatch) -> bytes:
    """Encode a PyArrow RecordBatch as the batch type marker plus an IPC stream.

    Args:
        batch: RecordBatch to encode.

    Returns:
        ``_TYPE_ARROW_BATCH`` followed by an IPC stream containing ``batch``.
    """
    out = pa.BufferOutputStream()
    out.write(_TYPE_ARROW_BATCH)
    # Leaving the `with` block closes the writer before the buffer is read.
    with pa.ipc.new_stream(out, batch.schema) as stream_writer:
        stream_writer.write_batch(batch)
    encoded = out.getvalue()
    return encoded.to_pybytes()


def _deserialize_arrow_batch(data: bytes) -> pa.RecordBatch:
    """Decode an IPC stream payload into a PyArrow RecordBatch.

    Args:
        data: IPC stream bytes with the type marker already stripped.

    Returns:
        The first batch read from the stream.
    """
    source = pa.BufferReader(data)
    return pa.ipc.open_stream(source).read_next_batch()


def _serialize_arrow_array(arr: pa.Array) -> bytes:
    """Encode a PyArrow Array as the array type marker plus an IPC stream.

    The array is wrapped in a one-column RecordBatch (column name ``"data"``)
    and that batch is written to the stream.

    Args:
        arr: Array to encode.

    Returns:
        ``_TYPE_ARROW_ARRAY`` followed by the IPC stream bytes of the wrapper.
    """
    wrapper = pa.RecordBatch.from_arrays([arr], names=["data"])
    out = pa.BufferOutputStream()
    out.write(_TYPE_ARROW_ARRAY)
    # Leaving the `with` block closes the writer before the buffer is read.
    with pa.ipc.new_stream(out, wrapper.schema) as stream_writer:
        stream_writer.write_batch(wrapper)
    encoded = out.getvalue()
    return encoded.to_pybytes()


def _deserialize_arrow_array(data: bytes) -> pa.Array:
    """Decode an IPC stream payload into a PyArrow Array.

    Args:
        data: IPC stream bytes with the type marker already stripped.

    Returns:
        The first column of the first batch read from the stream.
    """
    source = pa.BufferReader(data)
    wrapper = pa.ipc.open_stream(source).read_next_batch()
    return wrapper.column(0)


def _serialize_cloudpickle(obj: Any) -> bytes:
    """Encode an arbitrary Python object with cloudpickle (the fallback format).

    Args:
        obj: Object to pickle.

    Returns:
        ``_TYPE_CLOUDPICKLE`` followed by the cloudpickle bytes of ``obj``.
    """
    pickled = cloudpickle.dumps(obj, protocol=cloudpickle.DEFAULT_PROTOCOL)
    return b"".join((_TYPE_CLOUDPICKLE, pickled))


def _deserialize_cloudpickle(data: bytes) -> Any:
    """Decode a cloudpickle payload.

    Args:
        data: Pickle bytes with the type marker already stripped.

    Returns:
        The unpickled object.
    """
    # NOTE(review): unpickling can execute arbitrary code; this presumably
    # only ever sees payloads from a trusted cache -- confirm.
    restored = cloudpickle.loads(data)
    return restored


def _serialize_object_data(obj: Any) -> bytes:
    """Serialize an object using the most suitable format for its type.

    Numpy arrays and PyArrow tables, record batches and arrays take a
    fast path that avoids cloudpickle overhead. Everything else, including
    numpy arrays the fast path cannot represent faithfully, falls back to
    cloudpickle.

    The numpy fast path is only taken when all of the following hold:

    * ``obj`` is exactly ``np.ndarray``. Subclasses such as ``MaskedArray``,
      ``matrix`` or ``recarray`` carry extra state that an Arrow tensor
      would drop.
    * the dtype is a native-byte-order integer or float. Object, string,
      structured, datetime and complex dtypes are not tensor types, and bool
      is excluded as well to keep dtype round-trips exact.
    * the array is C- or Fortran-contiguous.

    Args:
        obj: The object to serialize.

    Returns:
        A 4-byte type marker followed by the format-specific payload.
    """
    if (
        _HAS_NUMPY
        and type(obj) is np.ndarray  # exact type on purpose, see docstring
        and obj.dtype.kind in "iuf"
        and obj.dtype.isnative
        and (obj.flags.c_contiguous or obj.flags.f_contiguous)
    ):
        try:
            return _serialize_numpy(obj)
        except pa.ArrowException:
            # Some dtypes pass the gate above but still have no tensor
            # equivalent (e.g. extended-precision floats). Fall back to
            # cloudpickle rather than failing the whole serialization.
            logger.debug(
                "Arrow tensor fast path unavailable for dtype %s; using cloudpickle",
                obj.dtype,
                exc_info=True,
            )

    if isinstance(obj, pa.Table):
        return _serialize_arrow_table(obj)

    if isinstance(obj, pa.RecordBatch):
        return _serialize_arrow_batch(obj)

    if isinstance(obj, pa.Array):
        return _serialize_arrow_array(obj)

    return _serialize_cloudpickle(obj)


def _deserialize_object_data(data: bytes) -> Any:
    """Deserialize an object, detecting its format from the magic prefix.

    Payloads that do not start with the ``FLM`` magic are treated as raw
    cloudpickle data, which keeps objects written without a marker readable.

    Args:
        data: Serialized bytes, with or without a 4-byte type marker.

    Returns:
        The deserialized object.

    Raises:
        ImportError: If the payload is a numpy tensor but numpy is not
            installed.
        ValueError: If the payload carries the ``FLM`` magic with a format
            byte this version does not know.
    """
    # Too short for a marker, or no magic at all: unmarked cloudpickle data.
    if len(data) < _MAGIC_PREFIX_LEN or not data.startswith(_MAGIC_PREFIX):
        return cloudpickle.loads(data)

    prefix = data[:_MAGIC_PREFIX_LEN]
    payload = data[_MAGIC_PREFIX_LEN:]

    if prefix == _TYPE_NUMPY:
        if not _HAS_NUMPY:
            raise ImportError("numpy is required to deserialize this object")
        return _deserialize_numpy(payload)

    if prefix == _TYPE_ARROW_TABLE:
        return _deserialize_arrow_table(payload)

    if prefix == _TYPE_ARROW_BATCH:
        return _deserialize_arrow_batch(payload)

    if prefix == _TYPE_ARROW_ARRAY:
        return _deserialize_arrow_array(payload)

    if prefix == _TYPE_CLOUDPICKLE:
        return _deserialize_cloudpickle(payload)

    # The magic matched but the format byte is unknown, e.g. written by a
    # newer client. Handing this to pickle would only yield a misleading
    # unpickling error, so fail with an explicit message instead.
    raise ValueError(
        f"Unsupported serialization format marker {prefix!r}; "
        "the object may have been written by a newer client version"
    )


def _serialize_object(obj: Any) -> pa.RecordBatch:
    """Serialize a Python object to a single-row Arrow RecordBatch.

    Uses fast-path serialization for numpy arrays and PyArrow types,
    falling back to cloudpickle for arbitrary Python objects.

    Args:
        obj: The object to serialize.

    Returns:
        RecordBatch with schema {version: uint64, data: binary}. ``version``
        is written as 0 and ``data`` holds the bytes produced by
        ``_serialize_object_data``.
    """
    # Serialize exactly once, through the format-aware path. The earlier
    # direct cloudpickle.dumps call was redundant: its result was overwritten.
    data_bytes = _serialize_object_data(obj)

    schema = pa.schema(
        [
            pa.field("version", pa.uint64()),
            pa.field("data", pa.binary()),
        ]
    )

    version_array = pa.array([0], type=pa.uint64())
    data_array = pa.array([data_bytes], type=pa.binary())

    return pa.RecordBatch.from_arrays([version_array, data_array], schema=schema)


def _deserialize_object(batch: pa.RecordBatch) -> Any:
    """Deserialize a Python object from an Arrow RecordBatch.

    The serialization format is detected automatically from the type marker
    at the start of the stored bytes.

    Args:
        batch: RecordBatch with schema {version: uint64, data: binary}; only
            the first row of the ``data`` column is read.

    Returns:
        The deserialized object.
    """
    data_bytes = batch.column("data")[0].as_py()

    # Always go through the format-aware decoder. Calling cloudpickle.loads
    # directly here would choke on the marker that prefixes every payload.
    return _deserialize_object_data(data_bytes)


_client_pool: Dict[str, flight.FlightClient] = {}
Expand Down
144 changes: 144 additions & 0 deletions sdk/python/tests/test_cache.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
import threading

import bson
import numpy as np
import pyarrow as pa

from flamepy.core.cache import (
Object,
ObjectKey,
ObjectRef,
_cache_lock,
_deserialize_object,
_deserialize_object_data,
_MAGIC_PREFIX_LEN,
_object_cache,
_serialize_object,
_serialize_object_data,
_TYPE_ARROW_ARRAY,
_TYPE_ARROW_BATCH,
_TYPE_ARROW_TABLE,
_TYPE_CLOUDPICKLE,
_TYPE_NUMPY,
delete_objects,
)

Expand Down Expand Up @@ -42,6 +52,140 @@ def test_serialize_handles_various_types(self):
assert result == original


class TestFastPathSerialization:
    """Round-trip and format-selection tests for the type-marked serializers."""

    def test_numpy_array_uses_fast_path(self):
        """A contiguous float array is tagged as numpy and round-trips."""
        arr = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float64)
        data = _serialize_object_data(arr)

        assert data[:_MAGIC_PREFIX_LEN] == _TYPE_NUMPY
        result = _deserialize_object_data(data)
        np.testing.assert_array_equal(result, arr)

    def test_numpy_array_various_dtypes(self):
        """Integer and float dtypes keep both their values and their dtype."""
        test_cases = [
            np.array([1, 2, 3], dtype=np.int32),
            np.array([1.5, 2.5, 3.5], dtype=np.float32),
            np.array([1, 2, 3], dtype=np.int64),
            np.array([[1, 2], [3, 4]], dtype=np.float64),
            np.zeros((10, 10, 3), dtype=np.uint8),
        ]

        for original in test_cases:
            data = _serialize_object_data(original)
            assert data[:_MAGIC_PREFIX_LEN] == _TYPE_NUMPY
            result = _deserialize_object_data(data)
            np.testing.assert_array_equal(result, original)
            assert result.dtype == original.dtype

    def test_numpy_large_array_performance(self):
        """A large array takes the fast path and round-trips exactly.

        Wall-clock thresholds are deliberately not asserted here: they depend
        on the machine and its load and make the suite flaky. Timing belongs
        in a dedicated benchmark.
        """
        # Seeded generator so any failure is reproducible.
        large_arr = np.random.default_rng(0).random((1000, 1000))

        data = _serialize_object_data(large_arr)
        result = _deserialize_object_data(data)

        assert data[:_MAGIC_PREFIX_LEN] == _TYPE_NUMPY
        # The format is lossless, so exact equality is the right check.
        np.testing.assert_array_equal(result, large_arr)
        assert result.dtype == large_arr.dtype

    def test_pyarrow_table_uses_fast_path(self):
        """A Table is tagged as an Arrow table and round-trips."""
        table = pa.table({"col1": [1, 2, 3], "col2": ["a", "b", "c"]})
        data = _serialize_object_data(table)

        assert data[:_MAGIC_PREFIX_LEN] == _TYPE_ARROW_TABLE
        result = _deserialize_object_data(data)
        assert result.equals(table)

    def test_pyarrow_record_batch_uses_fast_path(self):
        """A RecordBatch is tagged as an Arrow batch and round-trips."""
        batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3], "y": [4.0, 5.0, 6.0]})
        data = _serialize_object_data(batch)

        assert data[:_MAGIC_PREFIX_LEN] == _TYPE_ARROW_BATCH
        result = _deserialize_object_data(data)
        assert result.equals(batch)

    def test_pyarrow_array_uses_fast_path(self):
        """An Array is tagged as an Arrow array and round-trips."""
        arr = pa.array([1, 2, 3, 4, 5])
        data = _serialize_object_data(arr)

        assert data[:_MAGIC_PREFIX_LEN] == _TYPE_ARROW_ARRAY
        result = _deserialize_object_data(data)
        assert result.equals(arr)

    def test_pyarrow_chunked_array_uses_cloudpickle(self):
        """ChunkedArray has no fast path and falls back to cloudpickle."""
        chunked = pa.chunked_array([[1, 2], [3, 4]])
        data = _serialize_object_data(chunked)

        assert data[:_MAGIC_PREFIX_LEN] == _TYPE_CLOUDPICKLE
        result = _deserialize_object_data(data)
        assert result.equals(chunked)

    def test_dict_uses_cloudpickle(self):
        """A plain dict is serialized with cloudpickle."""
        obj = {"key": "value", "number": 42}
        data = _serialize_object_data(obj)

        assert data[:_MAGIC_PREFIX_LEN] == _TYPE_CLOUDPICKLE
        result = _deserialize_object_data(data)
        assert result == obj

    def test_list_uses_cloudpickle(self):
        """A plain list is serialized with cloudpickle."""
        obj = [1, 2, 3, "mixed", {"nested": True}]
        data = _serialize_object_data(obj)

        assert data[:_MAGIC_PREFIX_LEN] == _TYPE_CLOUDPICKLE
        result = _deserialize_object_data(data)
        assert result == obj

    def test_full_roundtrip_via_record_batch(self):
        """Every supported kind of object survives the RecordBatch envelope."""
        test_cases = [
            np.array([1.0, 2.0, 3.0]),
            pa.table({"a": [1, 2, 3]}),
            pa.RecordBatch.from_pydict({"b": [1, 2, 3]}),
            pa.array([10, 20, 30]),
            {"python": "dict"},
            [1, 2, 3],
        ]

        for original in test_cases:
            batch = _serialize_object(original)
            result = _deserialize_object(batch)

            if isinstance(original, np.ndarray):
                np.testing.assert_array_equal(result, original)
            elif isinstance(original, (pa.Table, pa.RecordBatch, pa.Array)):
                assert result.equals(original)
            else:
                assert result == original

    def test_non_contiguous_numpy_array_uses_cloudpickle(self):
        """A strided view is neither C- nor F-contiguous, so it is pickled."""
        arr = np.array([[1, 2, 3], [4, 5, 6]])
        non_contiguous = arr[:, ::2]
        assert not non_contiguous.flags.c_contiguous
        assert not non_contiguous.flags.f_contiguous

        data = _serialize_object_data(non_contiguous)
        assert data[:_MAGIC_PREFIX_LEN] == _TYPE_CLOUDPICKLE

        result = _deserialize_object_data(data)
        np.testing.assert_array_equal(result, non_contiguous)

    def test_fortran_contiguous_array_uses_fast_path(self):
        """Fortran-ordered arrays are contiguous and take the numpy path."""
        arr = np.asfortranarray([[1, 2], [3, 4], [5, 6]])
        assert arr.flags.f_contiguous

        data = _serialize_object_data(arr)
        assert data[:_MAGIC_PREFIX_LEN] == _TYPE_NUMPY

        result = _deserialize_object_data(data)
        np.testing.assert_array_equal(result, arr)


class TestClientSideCaching:
def setup_method(self):
with _cache_lock:
Expand Down
Loading