From 73ed2f899eb1593c6955940511cd7c78a3aa6d8a Mon Sep 17 00:00:00 2001 From: Klaus Ma Date: Mon, 18 May 2026 18:30:08 -0700 Subject: [PATCH 1/2] Implement native tabular cache path --- docs/designs/RFE318-cache/CHANGELOG.md | 4 + docs/designs/RFE318-cache/FS.md | 251 ++++++++++++---- docs/designs/RFE318-cache/STATUS.md | 17 +- e2e/tests/test_cache.py | 33 +++ object_cache/src/cache.rs | 384 +++++++++++++++++++++---- object_cache/src/storage/disk.rs | 149 ++++++++-- sdk/python/src/flamepy/core/cache.py | 203 ++++++++++++- sdk/python/tests/test_cache.py | 109 +++++++ 8 files changed, 1021 insertions(+), 129 deletions(-) diff --git a/docs/designs/RFE318-cache/CHANGELOG.md b/docs/designs/RFE318-cache/CHANGELOG.md index 8157e2e1..42e62328 100644 --- a/docs/designs/RFE318-cache/CHANGELOG.md +++ b/docs/designs/RFE318-cache/CHANGELOG.md @@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added +- Design for issue #318 item 3: native DataSet/DataFrame cache payloads stored as Arrow schemas and record batches instead of opaque binary rows. +- Native Arrow table payload support in object-cache Flight and disk storage paths. +- FlamePy native tabular payload classification for PyArrow tables/batches plus optional pandas/polars DataFrames. +- Unit and E2E coverage for native Arrow cache payloads. - `patch` operation support in `ObjectCache` and `FlightCacheServer` (PR #6) - `patch_object` function in Python SDK (PR #6) - Append-only semantics for object updates (PR #6) diff --git a/docs/designs/RFE318-cache/FS.md b/docs/designs/RFE318-cache/FS.md index 28b196f8..92d3956c 100644 --- a/docs/designs/RFE318-cache/FS.md +++ b/docs/designs/RFE318-cache/FS.md @@ -21,6 +21,7 @@ This design aims to improve the object cache implementation by leveraging Apache 4. **Scalability**: Support both local and remote cache access patterns to enable distributed caching scenarios. 5. **Standardization**: Use Arrow Flight as the communication protocol, which is a standard for high-performance data services. 6. **Incremental Updates**: Support `patch` operation for appending delta data to objects, enabling efficient distributed operations with multiple clients. +7. **Native Tabular Payloads**: Store DataSet/DataFrame-style payloads directly as Arrow data instead of wrapping them as pickled or opaque IPC bytes. ## 2. Function Specification @@ -97,12 +98,12 @@ The cache server implements the Arrow Flight protocol with the following operati - Behavior: Persists data to disk using Arrow IPC, returns ObjectRef 2. **do_get**: Retrieve an object from the cache - - Request: Ticket containing key (`ssn_id/object_id`) + - Request: Ticket containing key (`app_name/session_id/object_id`) and optional cached version suffix - Response: Streaming FlightData containing RecordBatch with object data - Behavior: Reads base object and all deltas from disk, returns combined data 3. **get_flight_info**: Get metadata about a flight - - Request: FlightDescriptor with path (`{session_id}/{object_id}`) + - Request: FlightDescriptor with path (`{app_name}/{session_id}/{object_id}`) - Response: FlightInfo with schema information 4. **list_flights**: List all cached objects @@ -117,6 +118,110 @@ The cache server implements the Arrow Flight protocol with the following operati - `{app}/*` - delete all objects across all sessions of an application (wildcard) - **PATCH**: Append delta data to an existing object (new) +### Native DataSet/DataFrame Cache Path + +Issue #318 item 3 scopes the next cache enhancement to: **For DataSet/DataFrame, put to cache directly**. In this design, "DataSet/DataFrame" means tabular Python payloads that can be represented losslessly as Arrow batches without cloudpickle: + +- `pyarrow.Table` +- `pyarrow.RecordBatch` +- `pandas.DataFrame` when pandas is installed +- `polars.DataFrame` and collected `polars.LazyFrame` when polars is installed +- Dataset-like objects that expose an Arrow table through `to_arrow()`, `__arrow_c_stream__`, or an adapter registered by FlamePy + +The current Python fast path avoids cloudpickle for PyArrow tables, but it still serializes the table into Arrow IPC bytes and stores those bytes in the opaque `{version, data}` cache row. The direct tabular path must avoid that wrapper. The cache should stream and persist the original Arrow schema and record batches as the cached object payload. + +**Public API Behavior:** + +- `put_object(key_prefix, obj)` remains the entry point. It classifies the payload before writing: + - Tabular payloads use the native Arrow path. + - All other objects use the existing opaque object path. +- `get_object(ref)` returns the original logical object type when the required optional dependency is installed: + - PyArrow inputs return `pyarrow.Table` or `pyarrow.RecordBatch`. + - pandas inputs return `pandas.DataFrame`. + - polars inputs return `polars.DataFrame`. + - Generic Dataset inputs return the registered adapter output; if no adapter is available, they return `pyarrow.Table`. +- If a consumer does not have the optional library needed to reconstruct the original type, FlamePy returns `pyarrow.Table` rather than failing for pandas/polars-compatible payloads. Adapter-backed Dataset types may raise `ImportError` if the adapter cannot be loaded. +- `ObjectRef` remains `{endpoint, key, version}`. Payload type is cache metadata, not part of the reference. +- `update_object(ref, new_obj)` supports the same payload classification as `put_object`; updating a direct tabular object rewrites the base Arrow object and clears deltas. +- `patch_object(ref, delta)` stays on the existing opaque delta path in the first implementation. Native tabular append/merge semantics are intentionally out of scope for this item. + +**Payload Metadata:** + +The cache reserves `flame.cache.*` schema metadata keys for native payloads: + +| Key | Value | +|-----|-------| +| `flame.cache.format` | `opaque-v1` or `arrow-table-v1` | +| `flame.cache.version` | Current object version as decimal text | +| `flame.cache.logical_type` | `pyarrow.table`, `pyarrow.record_batch`, `pandas.dataframe`, `polars.dataframe`, or adapter name | +| `flame.cache.adapter` | Optional adapter identifier for Dataset-like objects | + +User-provided Arrow schema metadata must be preserved. When user metadata collides with `flame.cache.*`, the cache-owned value wins for transport and persistence. + +**Flight Protocol Behavior:** + +- Native tabular `do_put` requests carry an Arrow schema with `flame.cache.format=arrow-table-v1` and stream the table's record batches directly. +- Opaque object `do_put` requests keep using the current wrapper schema with `version` and `data` fields. +- Native tabular `do_get` responses stream the stored Arrow schema and batches directly. They do not wrap rows in the opaque response schema. +- Opaque object `do_get` responses keep using the response schema with `version`, `kind`, and `data` fields so existing patch and incremental-read behavior remains unchanged. +- `get_flight_info` should return the native table schema for tabular objects when the object is known. For opaque objects it may keep returning an empty schema for backward compatibility. + +**Storage Behavior:** + +- Opaque objects remain compatible with existing files that use the `version/data` Arrow IPC schema. +- Native tabular objects are stored as Arrow IPC files using the original table schema plus the reserved cache metadata. +- Storage paths and object keys do not change: + - `{storage_path}/{app_name}/{session_id}/{object_id}.arrow` + - ObjectRef key: `{app_name}/{session_id}/{object_id}` +- Cache loading must detect both formats: + - `flame.cache.format=arrow-table-v1` means native tabular object. + - Existing `version/data` files without metadata are legacy opaque objects. +- Size accounting for eviction should use stored IPC file size when available, falling back to Arrow buffer sizes for in-memory-only storage. + +**Implementation Shape:** + +Rust cache objects should model payload type explicitly instead of assuming every object is a byte vector: + +```rust +pub enum ObjectPayload { + Opaque(Vec), + ArrowTable { + schema: Arc, + batches: Vec, + logical_type: Option, + adapter: Option, + }, +} + +pub struct Object { + pub version: u64, + pub payload: ObjectPayload, + pub deltas: Vec, +} +``` + +The existing `Object { version, data, deltas }` representation remains the legacy opaque representation during migration. New server code should convert legacy data into `ObjectPayload::Opaque` at load boundaries so the rest of the cache can dispatch by payload kind. + +**Python Payload Classification:** + +FlamePy should classify tabular payloads without adding mandatory pandas, polars, or datasets dependencies: + +1. Check exact PyArrow types first (`pa.Table`, `pa.RecordBatch`, `pa.RecordBatchReader`). +2. Check optional pandas/polars types only when those modules are importable. +3. Check adapter registry entries for Dataset-like objects. +4. Check Arrow protocol methods such as `__arrow_c_stream__` or `to_arrow()`. +5. Fall back to the existing opaque object serializer. + +The classifier should be conservative: if conversion to `pyarrow.Table` is lossy or ambiguous, use the opaque path. + +**Out of Scope for Item 3:** + +- Designing row-level tabular patch/merge semantics. +- Distributed Dataset partition placement or cache-side query execution. +- Changing `ObjectRef`. +- Making pandas, polars, or datasets required dependencies. +- Migrating existing opaque objects that contain pickled DataFrames. + ### Patch Operation Semantics @@ -161,8 +266,8 @@ The ObjectRef structure is updated to include: @dataclass class ObjectRef: endpoint: str # The endpoint of cache server (e.g., "grpc://127.0.0.1:9090") - key: str # The key of object (e.g., "ssn_id/object_id") - version: int # Version number (always 0 for now) + key: str # The key of object (e.g., "app/session/object") + version: int # Server-managed version; 0 forces a fresh read ``` **Error Handling:** @@ -199,22 +304,25 @@ To manage disk usage, the cache implements a Least Recently Used (LRU) eviction - Arrow Flight server implementation - Arrow IPC persistence to disk - Python SDK integration with Arrow Flight client +- Native Arrow storage and transport for DataSet/DataFrame-style payloads - Local and remote cache access patterns - ObjectRef structure updates -- Key-based storage organization (`ssn_id/object_id`) +- Key-based storage organization (`app_name/session_id/object_id`) - Support for both public IP and localhost binding - Implementation updates for common data in both RL and service modules **Out of Scope:** -- Version checking and conflict resolution (version always 0 for now) +- Client-side version conflict resolution - Distributed cache coordination - Cache replication - Authentication and authorization - Cache size limits and quotas - Cache statistics and monitoring (beyond basic logging) +- Native tabular patch/merge semantics +- Requiring pandas, polars, or datasets as core SDK dependencies **Limitations:** -- Version is always 0; no version conflict detection +- Object versions are server-managed; clients do not perform conflict resolution - No automatic cache cleanup or eviction - Single-node cache server (no distributed coordination) - No authentication/authorization mechanisms @@ -235,6 +343,8 @@ To manage disk usage, the cache implements a Least Recently Used (LRU) eviction 4. **Python SDK cache.py**: Replace HTTP-based implementation with Arrow Flight client 5. **RL Module**: Update to use new ObjectRef structure 6. **Agent Module**: Update to use new ObjectRef structure +7. **Python SDK tabular payload classifier**: Detect PyArrow, pandas, polars, and Dataset-like objects that can be represented as Arrow tables. +8. **Rust cache payload model**: Distinguish opaque and native Arrow table payloads in storage, Flight responses, and metadata. **Integration Points:** - Arrow Flight server integrates with gRPC/tonic @@ -279,8 +389,9 @@ The flame-object-cache component is a standalone Rust service that implements an └──► Disk Storage (Arrow IPC files) /storage_path/ - └── ssn_id/ - └── object_id.arrow + └── app_name/ + └── session_id/ + └── object_id.arrow ``` ### Components @@ -321,9 +432,20 @@ The flame-object-cache component is a standalone Rust service that implements an **Object (Rust):** ```rust +pub enum ObjectPayload { + Opaque(Vec), + ArrowTable { + schema: Arc, + batches: Vec, + logical_type: Option, + adapter: Option, + }, +} + pub struct Object { pub version: u64, - pub data: Vec, + pub payload: ObjectPayload, + pub deltas: Vec, } ``` @@ -365,15 +487,15 @@ class ObjectKey: @dataclass class ObjectRef: endpoint: str # Cache server endpoint - key: str # "ssn_id/object_id" - version: int # Always 0 for now + key: str # "app/session/object" + version: int # Server-managed version; 0 forces a fresh read ``` **ObjectMetadata (Rust):** ```rust pub struct ObjectMetadata { pub endpoint: String, - pub key: String, // New field: "ssn_id/object_id" + pub key: String, // "app/session/object" pub version: u64, pub size: u64, } @@ -388,8 +510,9 @@ pub struct ObjectMetadata { - Wildcard key format: `{app_name}/*` (for delete operations across all sessions) **Arrow IPC File Format:** -- Each object stored as a single RecordBatch in Arrow IPC file -- Schema: `{version: UInt64, data: Binary}` +- Opaque objects are stored as a single RecordBatch in Arrow IPC file +- Opaque schema: `{version: UInt64, data: Binary}` +- Native tabular objects are stored as Arrow IPC files with their original Arrow schema and reserved `flame.cache.*` schema metadata - File naming: `{object_id}.arrow` @@ -405,24 +528,34 @@ pub struct ObjectMetadata { **do_put Algorithm:** 1. Receive FlightData stream with RecordBatch -2. Extract session_id from app_metadata -3. Generate unique object_id (UUID) -4. Construct key: `{session_id}/{object_id}` -5. Create session directory if it doesn't exist -6. Write RecordBatch to Arrow IPC file: `{storage_path}/{session_id}/{object_id}.arrow` -7. Update in-memory index: `HashMap` -8. Construct ObjectRef: `{endpoint, key, version: 0}` (using public endpoint from ObjectCache) -9. Serialize ObjectRef to BSON -10. Return PutResult with ObjectRef in app_metadata +2. Extract key prefix and optional object ID from the Flight descriptor +3. Inspect schema metadata: + - `flame.cache.format=arrow-table-v1`: collect batches as native tabular payload + - No native metadata: decode the legacy opaque `{version, data}` payload +4. Generate unique object_id (UUID) when the descriptor only contains a prefix +5. Construct key: `{app_name}/{session_id}/{object_id}` +6. Create session directory if it doesn't exist +7. Write payload to Arrow IPC file: + - Opaque: `{version, data}` wrapper schema + - Native tabular: original schema and record batches +8. Update in-memory index: `HashMap` +9. Construct ObjectRef: `{endpoint, key, version}` (using public endpoint from ObjectCache) +10. Serialize ObjectRef to BSON +11. Return PutResult with ObjectRef in app_metadata **do_get Algorithm:** 1. Extract key from Ticket 2. Parse key to get session_id and object_id 3. Check in-memory index for key -4. If found, read Arrow IPC file: `{storage_path}/{session_id}/{object_id}.arrow` -5. Deserialize RecordBatch from file -6. Convert RecordBatch to FlightData -7. Stream FlightData to client +4. If found, read Arrow IPC file: `{storage_path}/{app_name}/{session_id}/{object_id}.arrow` +5. Detect stored payload kind +6. Opaque payload: + - Deserialize object wrapper and optional deltas + - Convert response rows to FlightData using `{version, kind, data}` +7. Native tabular payload: + - Stream original schema and record batches directly + - Preserve user metadata and reserved cache metadata +8. Stream FlightData to client **list_flights Algorithm:** 1. Get cache service's public endpoint from ObjectCache (obtained during server construction from flame-cluster.yaml) @@ -431,25 +564,28 @@ pub struct ObjectMetadata { - List all `.arrow` files - For each file: - Extract object_id from filename - - Construct key: `{session_id}/{object_id}` + - Construct key: `{app_name}/{session_id}/{object_id}` - Create FlightInfo with key as ticket and cache service's public endpoint - Stream FlightInfo to client **Python SDK put_object Algorithm:** -1. Check if `cache.storage` is set -2. If set: - - Serialize object to RecordBatch +1. Classify the payload: + - Native tabular: convert to Arrow schema and record batches with `flame.cache.format=arrow-table-v1` + - Opaque: serialize to RecordBatch with `{version, data}` +2. Check if `cache.storage` is set +3. If set: + - Write payload to local Arrow IPC storage using the selected format - Generate object_id (UUID) - - Write to local storage: `{cache.storage}/{session_id}/{object_id}.arrow` + - Write to local storage: `{cache.storage}/{app_name}/{session_id}/{object_id}.arrow` - Connect to cache server using `cache.endpoint` - - Get flight info using FlightDescriptor with path `{session_id}/{object_id}` - - Construct ObjectRef with cache server's endpoint from FlightInfo, key `{session_id}/{object_id}`, and version 0 -3. If not set: + - Get flight info using FlightDescriptor with path `{app_name}/{session_id}/{object_id}` + - Construct ObjectRef with cache server's endpoint from FlightInfo, key `{app_name}/{session_id}/{object_id}`, and server version from cache metadata +4. If not set: - Check if `cache.endpoint` is set, else raise exception - Connect to remote cache server via endpoint - - Call do_put to upload object + - Call do_put to upload the selected payload format - Extract ObjectRef from PutResult app_metadata -4. Return ObjectRef +5. Return ObjectRef **Key Construction:** - Format: `{app_name}/{session_id}/{object_id}` @@ -543,13 +679,13 @@ The special session_id value `*` (constant: `WILDCARD_SESSION`) indicates all se **Example 1: Python SDK Client Uploading Object to Remote Cache** - Description: A Python SDK client uploads an object to a remote cache server - Step-by-step workflow: - 1. Client calls `put_object(session_id="sess123", obj=my_data)` + 1. Client calls `put_object("app/sess123", my_data)` 2. SDK checks `cache.storage` - not set 3. SDK checks `cache.endpoint` - set to "grpc://cache.example.com:9090" - 4. SDK serializes object to RecordBatch + 4. SDK classifies the payload and serializes it as opaque or native tabular Arrow data 5. SDK connects to cache server via Arrow Flight - 6. SDK calls do_put with RecordBatch and session_id in metadata - 7. Cache server generates object_id, creates session directory, writes Arrow IPC file + 6. SDK calls do_put with FlightDescriptor path `app/sess123` + 7. Cache server generates object_id, creates session directory, writes Arrow IPC file under `app/sess123` 8. Cache server returns ObjectRef in PutResult 9. SDK deserializes ObjectRef and returns to client - Expected outcome: Object is stored on cache server, client receives ObjectRef @@ -557,21 +693,21 @@ The special session_id value `*` (constant: `WILDCARD_SESSION`) indicates all se **Example 2: Python SDK Client Uploading Object to Local Cache** - Description: A Python SDK client uploads an object using local storage - Step-by-step workflow: - 1. Client calls `put_object(session_id="sess123", obj=my_data)` + 1. Client calls `put_object("app/sess123", my_data)` 2. SDK checks `cache.storage` - set to "/tmp/flame_cache" - 3. SDK serializes object to RecordBatch + 3. SDK classifies the payload and prepares opaque or native tabular Arrow data 4. SDK generates object_id (UUID) - 5. SDK writes RecordBatch to `/tmp/flame_cache/sess123/{object_id}.arrow` + 5. SDK writes RecordBatch to `/tmp/flame_cache/app/sess123/{object_id}.arrow` 6. SDK connects to cache server using `cache.endpoint` - 7. SDK gets flight info using FlightDescriptor with path `sess123/{object_id}` - 8. SDK constructs ObjectRef with cache server's endpoint from FlightInfo, key `sess123/{object_id}`, and version 0 + 7. SDK gets flight info using FlightDescriptor with path `app/sess123/{object_id}` + 8. SDK constructs ObjectRef with cache server's endpoint from FlightInfo, key `app/sess123/{object_id}`, and server version from cache metadata 9. SDK returns ObjectRef to client - Expected outcome: Object is stored locally, client receives ObjectRef with remote endpoint from FlightInfo **Example 3: Retrieving Cached Object** - Description: A client retrieves a previously cached object - Step-by-step workflow: - 1. Client has ObjectRef: `{endpoint: "grpc://cache.example.com:9090", key: "sess123/obj456", version: 0}` + 1. Client has ObjectRef: `{endpoint: "grpc://cache.example.com:9090", key: "app/sess123/obj456", version: 1}` 2. Client calls `get_object(ref)` 3. SDK connects to cache server using ref.endpoint 4. SDK calls do_get with ticket = ref.key @@ -590,7 +726,7 @@ The special session_id value `*` (constant: `WILDCARD_SESSION`) indicates all se 3. SDK serializes new object to RecordBatch 4. SDK calls do_put with same key (overwrites existing) 5. Cache server writes new data to existing Arrow IPC file - 6. Cache server updates metadata (version remains 0) + 6. Cache server updates metadata and increments the object version 7. Cache server returns updated ObjectRef 8. SDK returns ObjectRef to client - Expected outcome: Object is updated, same ObjectRef returned @@ -607,9 +743,22 @@ The special session_id value `*` (constant: `WILDCARD_SESSION`) indicates all se 7. Client receives list of all cached objects with their keys and endpoints - Expected outcome: Complete list of all cached objects with their keys and cache service endpoints +**Example 6: Native DataFrame Upload and Retrieval** +- Description: A Python SDK client stores a DataFrame without wrapping it in pickled bytes +- Step-by-step workflow: + 1. Client calls `put_object("app/sess123", dataframe)` + 2. SDK recognizes the DataFrame as a tabular payload and converts it to `pyarrow.Table` + 3. SDK adds reserved schema metadata such as `flame.cache.format=arrow-table-v1` and `flame.cache.logical_type=pandas.dataframe` + 4. SDK streams the table batches directly with Arrow Flight `do_put` + 5. Cache server persists the original Arrow schema and batches in `{storage_path}/app/sess123/{object_id}.arrow` + 6. Client calls `get_object(ref)` + 7. Cache server streams the table schema and batches directly with Arrow Flight `do_get` + 8. SDK reconstructs the original DataFrame type when the dependency is available, otherwise returns `pyarrow.Table` +- Expected outcome: Tabular data avoids cloudpickle and avoids the opaque `{version, data}` wrapper. + ### Advanced Use Cases -**Example 6: RL Module Using Cache for RunnerContext** +**Example 7: RL Module Using Cache for RunnerContext** - Description: RL module stores RunnerContext in cache for remote execution - Step-by-step workflow: 1. RL module creates RunnerService with execution object @@ -623,7 +772,7 @@ The special session_id value `*` (constant: `WILDCARD_SESSION`) indicates all se 9. Remote executor deserializes and uses RunnerContext - Expected outcome: RunnerContext is cached and accessible to remote executors -**Example 7: Cache Server Restart Recovery** +**Example 8: Cache Server Restart Recovery** - Description: Cache server restarts and recovers persisted objects - Step-by-step workflow: 1. Cache server starts up diff --git a/docs/designs/RFE318-cache/STATUS.md b/docs/designs/RFE318-cache/STATUS.md index 662f2fb4..c0da4e3c 100644 --- a/docs/designs/RFE318-cache/STATUS.md +++ b/docs/designs/RFE318-cache/STATUS.md @@ -1,5 +1,20 @@ # Object Cache Implementation Status +## ✅ Implemented: Native DataSet/DataFrame Direct Cache Path + +Issue #318 item 3 is implemented for tabular payloads that can be represented as Arrow data directly in object cache, instead of wrapping them as pickled objects or Arrow IPC bytes inside the opaque `{version, data}` row. + +Scope: +- Detect PyArrow tables/batches, pandas DataFrames, polars DataFrames, and adapter-backed Dataset-like objects in FlamePy without adding mandatory pandas/polars/datasets dependencies. +- Stream and persist native tabular payloads as Arrow schemas and record batches with reserved `flame.cache.*` schema metadata. +- Preserve the existing opaque object path, existing `ObjectRef` shape, and legacy `version/data` cache files. +- Keep native tabular patch/merge semantics out of scope for this item. + +Verification: +- Rust object-cache unit coverage includes native Arrow payload storage, reload, and Flight `do_get` schema/batch streaming. +- Python unit coverage includes native payload classification, direct remote `do_put` schema emission, and native Arrow `do_get` parsing. +- E2E tests were added for native Arrow table put/get and update paths. + ## Migration from Naive Cache ✅ **Completed**: The naive HTTP-based cache in `flame-executor-manager` has been removed. The following changes were made: @@ -15,7 +30,7 @@ The object cache is now provided as an embedded library within the `flame-execut ### Core Implementation - ✅ Rust Arrow Flight server implemented (`object_cache/src/cache.rs`) - ✅ Disk persistence using Arrow IPC format -- ✅ Key-based storage organization (`session_id/object_id`) +- ✅ Key-based storage organization (`app_name/session_id/object_id`) - ✅ In-memory index with HashMap - ✅ Object loading from disk on startup - ✅ Configuration support (flame-cluster.yaml with storage path) diff --git a/e2e/tests/test_cache.py b/e2e/tests/test_cache.py index 36e1bfa4..0bd0792c 100644 --- a/e2e/tests/test_cache.py +++ b/e2e/tests/test_cache.py @@ -15,6 +15,7 @@ import uuid import flamepy.core.cache as cache_module +import pyarrow as pa import pyarrow.flight as flight import pytest from flamepy.core import FlameContext, ObjectRef, get_object, patch_object, put_object, update_object @@ -37,6 +38,38 @@ def test_cache_put_and_get(): assert result == test_data +def test_cache_put_and_get_native_arrow_table(): + """Test that Arrow tables round-trip through the native cache path.""" + key_prefix = f"test-app/test-native-arrow-{uuid.uuid4().hex[:8]}" + table = pa.table({"col1": [1, 2, 3], "col2": ["a", "b", "c"]}) + + ref = put_object(key_prefix, table) + + assert ref.version == 1 + + result = get_object(ref) + assert isinstance(result, pa.Table) + assert result.to_pydict() == table.to_pydict() + assert b"flame.cache.format" not in (result.schema.metadata or {}) + + +def test_cache_native_arrow_table_update(): + """Test that native Arrow table updates rewrite the base object.""" + key_prefix = f"test-app/test-native-arrow-update-{uuid.uuid4().hex[:8]}" + table = pa.table({"value": [1, 2, 3]}) + updated = pa.table({"value": [4, 5], "label": ["x", "y"]}) + + ref = put_object(key_prefix, table) + updated_ref = update_object(ref, updated) + + assert updated_ref.key == ref.key + assert updated_ref.version == ref.version + 1 + + result = get_object(updated_ref) + assert isinstance(result, pa.Table) + assert result.to_pydict() == updated.to_pydict() + + def test_cache_update(): """Test update operation.""" key_prefix = "test-app/test-session-002" diff --git a/object_cache/src/cache.rs b/object_cache/src/cache.rs index 7618182b..9f2f9ae2 100644 --- a/object_cache/src/cache.rs +++ b/object_cache/src/cache.rs @@ -209,38 +209,109 @@ impl From<&ObjectKey> for String { } } -/// Object with optional delta support -/// Per HLD: deltas field is empty for delta objects themselves -/// Note: This struct is immutable after construction - use with_deltas() to create -/// a new Object with deltas populated rather than mutating an existing one. +pub const CACHE_FORMAT_METADATA_KEY: &str = "flame.cache.format"; +pub const CACHE_VERSION_METADATA_KEY: &str = "flame.cache.version"; +pub const CACHE_FORMAT_ARROW_TABLE: &str = "arrow-table-v1"; + +/// Native payload stored by object cache. +/// +/// Opaque payloads preserve the original `version/data` cache behavior. +/// Arrow table payloads keep their original schema and batches so DataFrame +/// payloads do not get hidden inside a binary cell. +#[derive(Debug, Clone)] +pub enum ObjectPayload { + Opaque(Vec), + ArrowTable { + schema: Arc, + batches: Vec, + }, +} + +impl ObjectPayload { + pub fn size_bytes(&self) -> u64 { + match self { + Self::Opaque(data) => data.len() as u64, + Self::ArrowTable { batches, .. } => batches + .iter() + .map(|batch| batch.get_array_memory_size() as u64) + .sum(), + } + } + + pub fn is_arrow_table(&self) -> bool { + matches!(self, Self::ArrowTable { .. }) + } +} + +/// Object with optional delta support. +/// Deltas are currently opaque payloads; native tabular patch/merge semantics +/// are intentionally left out of RFE318 item 3. #[derive(Debug, Clone)] pub struct Object { pub version: u64, - pub data: Vec, + pub payload: ObjectPayload, pub deltas: Vec, } impl Object { - /// Create a new Object with no deltas + /// Create a new opaque Object with no deltas. pub fn new(version: u64, data: Vec) -> Self { Self { version, - data, + payload: ObjectPayload::Opaque(data), deltas: Vec::new(), } } - /// Create a new Object with deltas - /// This is the preferred way to create an Object with deltas rather than - /// mutating an existing Object's deltas field. + /// Create a native Arrow table Object with no deltas. + pub fn new_arrow_table(version: u64, schema: Arc, batches: Vec) -> Self { + Self { + version, + payload: ObjectPayload::ArrowTable { schema, batches }, + deltas: Vec::new(), + } + } + + /// Create a new opaque Object with deltas. + #[cfg(test)] pub fn with_deltas(version: u64, data: Vec, deltas: Vec) -> Self { Self { version, - data, + payload: ObjectPayload::Opaque(data), deltas, } } + pub fn with_payload(version: u64, payload: ObjectPayload, deltas: Vec) -> Self { + Self { + version, + payload, + deltas, + } + } + + pub fn opaque_data(&self) -> Result<&[u8], FlameError> { + match &self.payload { + ObjectPayload::Opaque(data) => Ok(data), + ObjectPayload::ArrowTable { .. } => Err(FlameError::InvalidState( + "expected opaque object payload".to_string(), + )), + } + } + + pub fn into_opaque_data(self) -> Result, FlameError> { + match self.payload { + ObjectPayload::Opaque(data) => Ok(data), + ObjectPayload::ArrowTable { .. } => Err(FlameError::InvalidState( + "expected opaque object payload".to_string(), + )), + } + } + + pub fn is_arrow_table(&self) -> bool { + self.payload.is_arrow_table() + } + pub fn current_version(&self) -> u64 { self.deltas .iter() @@ -248,7 +319,7 @@ impl Object { } pub fn size_bytes(&self) -> u64 { - self.data.len() as u64 + self.deltas.iter().map(Object::size_bytes).sum::() + self.payload.size_bytes() + self.deltas.iter().map(Object::size_bytes).sum::() } } @@ -476,7 +547,8 @@ impl ObjectCache { }; let new_version = current_version + 1; - let versioned_object = Object::new(new_version, object.data); + let versioned_payload = version_payload(object.payload, new_version)?; + let versioned_object = Object::with_payload(new_version, versioned_payload, Vec::new()); let size = versioned_object.size_bytes(); self.storage.write_object(&key, &versioned_object).await?; @@ -565,7 +637,7 @@ impl ObjectCache { let current_version = current_object.current_version(); let new_version = current_version + 1; - let versioned_delta = Object::new(new_version, delta.data); + let versioned_delta = Object::new(new_version, delta.into_opaque_data()?); let mut meta = self.storage.patch_object(key, &versioned_delta).await?; let mut patched_object = current_object; @@ -696,7 +768,16 @@ impl FlightCacheServer { async fn collect_batches_from_stream( mut stream: Streaming, - ) -> Result<(String, Option, Option, Vec), FlameError> { + ) -> Result< + ( + String, + Option, + Option, + Arc, + Vec, + ), + FlameError, + > { let mut batches = Vec::new(); let mut session_id: Option = None; let mut object_id: Option = None; @@ -764,8 +845,10 @@ impl FlightCacheServer { "key must be provided in descriptor path or command".to_string(), ) })?; + let schema = + schema.ok_or_else(|| FlameError::InvalidState("No schema received".to_string()))?; - Ok((session_id, object_id, command, batches)) + Ok((session_id, object_id, command, schema, batches)) } fn combine_batches(batches: Vec) -> Result { @@ -848,6 +931,43 @@ fn get_object_response_schema() -> Schema { ]) } +pub fn is_native_arrow_schema(schema: &Schema) -> bool { + schema + .metadata() + .get(CACHE_FORMAT_METADATA_KEY) + .map(|format| format == CACHE_FORMAT_ARROW_TABLE) + .unwrap_or(false) +} + +fn arrow_schema_with_cache_metadata(schema: &Schema, version: u64) -> Arc { + let mut metadata = schema.metadata().clone(); + metadata.insert( + CACHE_FORMAT_METADATA_KEY.to_string(), + CACHE_FORMAT_ARROW_TABLE.to_string(), + ); + metadata.insert(CACHE_VERSION_METADATA_KEY.to_string(), version.to_string()); + Arc::new(schema.clone().with_metadata(metadata)) +} + +fn batch_with_schema(batch: &RecordBatch, schema: Arc) -> Result { + RecordBatch::try_new(schema, batch.columns().to_vec()) + .map_err(|e| FlameError::Internal(format!("Failed to attach schema metadata: {}", e))) +} + +fn version_payload(payload: ObjectPayload, version: u64) -> Result { + match payload { + ObjectPayload::Opaque(data) => Ok(ObjectPayload::Opaque(data)), + ObjectPayload::ArrowTable { schema, batches } => { + let schema = arrow_schema_with_cache_metadata(schema.as_ref(), version); + let batches = batches + .iter() + .map(|batch| batch_with_schema(batch, schema.clone())) + .collect::, _>>()?; + Ok(ObjectPayload::ArrowTable { schema, batches }) + } + } +} + #[derive(Debug, Clone, Copy)] enum ObjectResponseKind { Base, @@ -870,7 +990,7 @@ fn object_to_batch(object: &Object) -> Result { let schema = get_object_schema(); let version_array = UInt64Array::from(vec![object.version]); - let data_array = BinaryArray::from(vec![object.data.as_slice()]); + let data_array = BinaryArray::from(vec![object.opaque_data()?]); RecordBatch::try_new( Arc::new(schema), @@ -933,7 +1053,7 @@ fn object_to_response_batch( let version_array = UInt64Array::from(vec![object.version]); let kind_array = StringArray::from(vec![kind.as_str()]); - let data_array = BinaryArray::from(vec![object.data.as_slice()]); + let data_array = BinaryArray::from(vec![object.opaque_data()?]); RecordBatch::try_new( Arc::new(schema), @@ -970,6 +1090,27 @@ fn object_patches_to_flight_data_vec(patches: Vec<&Object>) -> Result, +) -> Result, FlameError> { + let schema = Arc::new(get_object_response_schema()); + let batches = rows + .into_iter() + .map(|(kind, object)| object_to_response_batch(object, kind)) + .collect::, _>>()?; + record_batches_to_flight_data_vec(schema, &batches) +} + +fn object_to_native_flight_data_vec(obj: &Object) -> Result, FlameError> { + match &obj.payload { + ObjectPayload::ArrowTable { schema, batches } => { + record_batches_to_flight_data_vec(schema.clone(), batches) + } + ObjectPayload::Opaque(_) => object_to_flight_data_vec(obj), + } +} + +fn record_batches_to_flight_data_vec( + schema: Arc, + batches: &[RecordBatch], ) -> Result, FlameError> { let options = IpcWriteOptions::default() .try_with_compression(Some(CompressionType::ZSTD)) @@ -979,8 +1120,6 @@ fn object_rows_to_flight_data_vec( let mut dict_tracker = DictionaryTracker::new(false); let mut compression_ctx = CompressionContext::default(); - let schema = Arc::new(get_object_response_schema()); - let mut all_flight_data = Vec::new(); let encoded_schema = data_gen.schema_to_bytes_with_dictionary_tracker( @@ -995,15 +1134,9 @@ fn object_rows_to_flight_data_vec( data_body: vec![].into(), }); - for (kind, object) in rows { - let delta_batch = object_to_response_batch(object, kind)?; + for batch in batches { let (encoded_dicts, encoded_batch) = data_gen - .encode( - &delta_batch, - &mut dict_tracker, - &options, - &mut compression_ctx, - ) + .encode(batch, &mut dict_tracker, &options, &mut compression_ctx) .map_err(|e| FlameError::Internal(format!("Failed to encode response batch: {}", e)))?; for dict_batch in encoded_dicts { all_flight_data.push(dict_batch.into()); @@ -1051,10 +1184,22 @@ impl FlightService for FlightCacheServer { app_metadata: Bytes::new(), }; - // Return empty schema - schema will be discovered from FlightData - // This avoids compatibility issues with schema encoding between Rust and Python + let schema = if let Ok(object_key) = ObjectKey::try_from(key.as_str()) { + match self.cache.get(&object_key).await { + Ok(object) => match object.payload { + ObjectPayload::ArrowTable { schema, .. } => { + Bytes::from(encode_schema(&schema)?) + } + ObjectPayload::Opaque(_) => Bytes::new(), + }, + Err(_) => Bytes::new(), + } + } else { + Bytes::new() + }; + let flight_info = FlightInfo { - schema: Bytes::new(), + schema, flight_descriptor: Some(FlightDescriptor { r#type: descriptor.r#type, cmd: descriptor.cmd, @@ -1126,7 +1271,24 @@ impl FlightService for FlightCacheServer { client_version, server_version ); - } else if client_version != 0 && object.version <= client_version { + } + + if object.is_arrow_table() { + tracing::debug!( + "do_get: key={}, native_arrow_rows={}, server_version={}", + key_str, + object.size_bytes(), + server_version + ); + let flight_data_vec = object_to_native_flight_data_vec(&object)?; + let stream = futures::stream::iter(flight_data_vec.into_iter().map(Ok)); + return Ok(Response::new(Box::pin(stream))); + } + + if client_version != 0 + && client_version <= server_version + && object.version <= client_version + { let needed_patches: Vec<&Object> = object .deltas .iter() @@ -1163,7 +1325,7 @@ impl FlightService for FlightCacheServer { tracing::debug!( "do_get: key={}, base_size={}, delta_count={}", key_str, - object.data.len(), + object.size_bytes(), object.deltas.len() ); @@ -1179,7 +1341,7 @@ impl FlightService for FlightCacheServer { ) -> Result, Status> { let stream = request.into_inner(); - let (key_or_prefix, object_id, command, batches) = + let (key_or_prefix, object_id, command, schema, batches) = Self::collect_batches_from_stream(stream).await?; tracing::debug!( "do_put: key_or_prefix={}, object_id={:?}, command={:?}, batch_count={}", @@ -1189,10 +1351,9 @@ impl FlightService for FlightCacheServer { batches.len() ); - let combined_batch = Self::combine_batches(batches)?; - let object = batch_to_object(&combined_batch)?; - let metadata = if command.as_deref() == Some("PATCH") { + let combined_batch = Self::combine_batches(batches)?; + let object = batch_to_object(&combined_batch)?; let key_str = match object_id { Some(oid) => format!("{}/{}", key_or_prefix, oid), None => key_or_prefix, @@ -1201,6 +1362,12 @@ impl FlightService for FlightCacheServer { let key = ObjectKey::try_from(key_str.as_str())?; self.cache.patch(&key, object).await? } else { + let object = if is_native_arrow_schema(schema.as_ref()) { + Object::new_arrow_table(0, schema, batches) + } else { + let combined_batch = Self::combine_batches(batches)?; + batch_to_object(&combined_batch)? + }; let key = ObjectKey::from_path(&key_or_prefix)?; let key = match object_id { Some(oid) => { @@ -1422,6 +1589,7 @@ pub async fn run(cache_config: &FlameCache) -> Result<(), FlameError> { #[cfg(test)] mod tests { use super::*; + use arrow::array::Int32Array; mod validation { use super::*; @@ -1507,7 +1675,7 @@ mod tests { fn new_creates_object_without_deltas() { let obj = Object::new(1, vec![1, 2, 3]); assert_eq!(obj.version, 1); - assert_eq!(obj.data, vec![1, 2, 3]); + assert_eq!(obj.opaque_data().unwrap(), &[1, 2, 3]); assert!(obj.deltas.is_empty()); } @@ -1518,10 +1686,10 @@ mod tests { let obj = Object::with_deltas(0, vec![1, 2, 3], vec![delta1.clone(), delta2.clone()]); assert_eq!(obj.version, 0); - assert_eq!(obj.data, vec![1, 2, 3]); + assert_eq!(obj.opaque_data().unwrap(), &[1, 2, 3]); assert_eq!(obj.deltas.len(), 2); - assert_eq!(obj.deltas[0].data, vec![4, 5]); - assert_eq!(obj.deltas[1].data, vec![6, 7]); + assert_eq!(obj.deltas[0].opaque_data().unwrap(), &[4, 5]); + assert_eq!(obj.deltas[1].opaque_data().unwrap(), &[6, 7]); } #[test] @@ -1537,7 +1705,32 @@ mod tests { let obj = Object::new(42, vec![10, 20, 30]); let cloned = obj.clone(); assert_eq!(cloned.version, obj.version); - assert_eq!(cloned.data, obj.data); + assert_eq!(cloned.opaque_data().unwrap(), obj.opaque_data().unwrap()); + } + + #[test] + fn new_arrow_table_creates_native_payload() { + let schema = Arc::new( + Schema::new(vec![Field::new("id", DataType::Int32, false)]).with_metadata( + [( + CACHE_FORMAT_METADATA_KEY.to_string(), + CACHE_FORMAT_ARROW_TABLE.to_string(), + )] + .into_iter() + .collect(), + ), + ); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let obj = Object::new_arrow_table(3, schema, vec![batch]); + + assert!(obj.is_arrow_table()); + assert_eq!(obj.version, 3); + assert_eq!(obj.current_version(), 3); + assert!(obj.opaque_data().is_err()); } } @@ -1616,7 +1809,10 @@ mod tests { let recovered = batch_to_object(&batch).unwrap(); assert_eq!(recovered.version, original.version); - assert_eq!(recovered.data, original.data); + assert_eq!( + recovered.opaque_data().unwrap(), + original.opaque_data().unwrap() + ); assert!(recovered.deltas.is_empty()); } @@ -1626,7 +1822,7 @@ mod tests { let batch = object_to_batch(&obj).unwrap(); let recovered = batch_to_object(&batch).unwrap(); assert_eq!(recovered.version, 0); - assert!(recovered.data.is_empty()); + assert!(recovered.opaque_data().unwrap().is_empty()); } #[test] @@ -1636,7 +1832,7 @@ mod tests { let batch = object_to_batch(&obj).unwrap(); let recovered = batch_to_object(&batch).unwrap(); assert_eq!(recovered.version, 999); - assert_eq!(recovered.data, large_data); + assert_eq!(recovered.opaque_data().unwrap(), large_data.as_slice()); } #[test] @@ -1652,7 +1848,7 @@ mod tests { let recovered = batch_to_object(&batch).unwrap(); assert_eq!(recovered.version, 0); - assert_eq!(recovered.data, b"abcdefghi"); + assert_eq!(recovered.opaque_data().unwrap(), b"abcdefghi"); } #[test] @@ -1705,7 +1901,47 @@ mod tests { let key = ObjectKey::try_from(meta.key.as_str()).unwrap(); let retrieved = cache.get(&key).await.unwrap(); assert_eq!(retrieved.version, 1); - assert_eq!(retrieved.data, vec![1, 2, 3]); + assert_eq!(retrieved.opaque_data().unwrap(), &[1, 2, 3]); + } + + #[tokio::test] + async fn put_and_get_native_arrow_table() { + let cache = create_test_cache().await; + let schema = Arc::new( + Schema::new(vec![Field::new("id", DataType::Int32, false)]).with_metadata( + [( + CACHE_FORMAT_METADATA_KEY.to_string(), + CACHE_FORMAT_ARROW_TABLE.to_string(), + )] + .into_iter() + .collect(), + ), + ); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let obj = Object::new_arrow_table(0, schema, vec![batch]); + + let key = ObjectKey::from_path("test-app/native-session").unwrap(); + let meta = cache.put(key, obj).await.unwrap(); + assert_eq!(meta.version, 1); + + let key = ObjectKey::try_from(meta.key.as_str()).unwrap(); + let retrieved = cache.get(&key).await.unwrap(); + assert_eq!(retrieved.version, 1); + match retrieved.payload { + ObjectPayload::ArrowTable { schema, batches } => { + assert_eq!( + schema.metadata().get(CACHE_VERSION_METADATA_KEY).unwrap(), + "1" + ); + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].num_rows(), 3); + } + ObjectPayload::Opaque(_) => panic!("expected native Arrow payload"), + } } #[tokio::test] @@ -1731,10 +1967,10 @@ mod tests { let object = cache.get(&key).await.unwrap(); assert_eq!(object.version, 1); assert_eq!(object.current_version(), 2); - assert_eq!(object.data, b"base".to_vec()); + assert_eq!(object.opaque_data().unwrap(), b"base"); assert_eq!(object.deltas.len(), 1); assert_eq!(object.deltas[0].version, 2); - assert_eq!(object.deltas[0].data, b"patch".to_vec()); + assert_eq!(object.deltas[0].opaque_data().unwrap(), b"patch"); } #[tokio::test] @@ -2125,6 +2361,56 @@ mod tests { assert!(batches.is_empty()); } + + #[tokio::test] + async fn native_arrow_object_returns_original_schema_batches() { + let (server, _temp) = create_disk_test_server().await; + let schema = Arc::new( + Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("label", DataType::Utf8, false), + ]) + .with_metadata( + [( + CACHE_FORMAT_METADATA_KEY.to_string(), + CACHE_FORMAT_ARROW_TABLE.to_string(), + )] + .into_iter() + .collect(), + ), + ); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["a", "b", "c"])), + ], + ) + .unwrap(); + let meta = server + .cache + .put( + ObjectKey::from_path("app/native").unwrap(), + Object::new_arrow_table(0, schema, vec![batch]), + ) + .await + .unwrap(); + + let batches = get_batches(&server, &format!("{}:0", meta.key)).await; + + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].schema().field(0).name(), "id"); + assert_eq!(batches[0].schema().field(1).name(), "label"); + assert_eq!( + batches[0] + .schema() + .metadata() + .get(CACHE_FORMAT_METADATA_KEY) + .unwrap(), + CACHE_FORMAT_ARROW_TABLE + ); + assert_eq!(batches[0].num_rows(), 3); + } } mod flight_data_conversion { diff --git a/object_cache/src/storage/disk.rs b/object_cache/src/storage/disk.rs index 8f1559bd..db6b87be 100644 --- a/object_cache/src/storage/disk.rs +++ b/object_cache/src/storage/disk.rs @@ -24,7 +24,10 @@ use rayon::prelude::*; use common::FlameError; -use crate::cache::{Object, ObjectKey, ObjectMetadata}; +use crate::cache::{ + Object, ObjectKey, ObjectMetadata, ObjectPayload, CACHE_FORMAT_ARROW_TABLE, + CACHE_FORMAT_METADATA_KEY, CACHE_VERSION_METADATA_KEY, +}; use super::StorageEngine; @@ -78,8 +81,7 @@ impl StorageEngine for DiskStorage { tokio::task::spawn_blocking(move || { fs::create_dir_all(&session_dir)?; - let batch = object_to_batch(&object_clone)?; - write_batch_to_file(&object_path, &batch)?; + write_object_to_file(&object_path, &object_clone)?; if delta_dir.exists() { fs::remove_dir_all(&delta_dir)?; } @@ -100,7 +102,11 @@ impl StorageEngine for DiskStorage { } let base = load_object_from_file(&object_path)?; let deltas = read_deltas_sync(&delta_dir, base.version)?; - Ok(Some(Object::with_deltas(base.version, base.data, deltas))) + Ok(Some(Object::with_payload( + base.version, + base.payload, + deltas, + ))) }) .await .map_err(|e| FlameError::Internal(format!("Task join error: {}", e)))? @@ -280,7 +286,7 @@ impl StorageEngine for DiskStorage { let delta_dir = session_path.join(format!("{}.deltas", object_id)); let base = load_object_from_file(&object_path)?; let deltas = read_deltas_sync(&delta_dir, base.version)?; - let object = Object::with_deltas(base.version, base.data, deltas); + let object = Object::with_payload(base.version, base.payload, deltas); results.push((key, object)); } @@ -364,20 +370,63 @@ fn validate_delta_versions(deltas: &[Object], base_version: u64) -> Result<(), F Ok(()) } +fn write_object_to_file(path: &Path, object: &Object) -> Result<(), FlameError> { + match &object.payload { + ObjectPayload::Opaque(_) => { + let batch = object_to_batch(object)?; + write_batch_to_file(path, &batch) + } + ObjectPayload::ArrowTable { schema, batches } => { + let file = fs::File::create(path)?; + let schema = schema_with_cache_metadata(schema.as_ref(), object.version); + let batches = batches + .iter() + .map(|batch| batch_with_schema(batch, schema.clone())) + .collect::, _>>()?; + write_batches_to_writer(file, schema, &batches) + } + } +} + +fn schema_with_cache_metadata(schema: &Schema, version: u64) -> Arc { + let mut metadata = schema.metadata().clone(); + metadata.insert( + CACHE_FORMAT_METADATA_KEY.to_string(), + CACHE_FORMAT_ARROW_TABLE.to_string(), + ); + metadata.insert(CACHE_VERSION_METADATA_KEY.to_string(), version.to_string()); + Arc::new(schema.clone().with_metadata(metadata)) +} + +fn batch_with_schema(batch: &RecordBatch, schema: Arc) -> Result { + RecordBatch::try_new(schema, batch.columns().to_vec()) + .map_err(|e| FlameError::Internal(format!("Failed to attach schema metadata: {}", e))) +} + fn write_batch_to_file(path: &Path, batch: &RecordBatch) -> Result<(), FlameError> { let file = fs::File::create(path)?; - write_batch_to_writer(file, batch) + write_batches_to_writer(file, batch.schema(), std::slice::from_ref(batch)) } fn write_batch_to_writer(file: fs::File, batch: &RecordBatch) -> Result<(), FlameError> { + write_batches_to_writer(file, batch.schema(), std::slice::from_ref(batch)) +} + +fn write_batches_to_writer( + file: fs::File, + schema: Arc, + batches: &[RecordBatch], +) -> Result<(), FlameError> { let options = IpcWriteOptions::default() .try_with_compression(Some(CompressionType::ZSTD)) .map_err(|e| FlameError::Internal(format!("Failed to set compression: {}", e)))?; - let mut writer = FileWriter::try_new_with_options(file, &batch.schema(), options) + let mut writer = FileWriter::try_new_with_options(file, schema.as_ref(), options) .map_err(|e| FlameError::Internal(format!("Failed to create writer: {}", e)))?; - writer - .write(batch) - .map_err(|e| FlameError::Internal(format!("Failed to write batch: {}", e)))?; + for batch in batches { + writer + .write(batch) + .map_err(|e| FlameError::Internal(format!("Failed to write batch: {}", e)))?; + } writer .finish() .map_err(|e| FlameError::Internal(format!("Failed to finish writer: {}", e)))?; @@ -388,7 +437,7 @@ fn object_to_batch(object: &Object) -> Result { let schema = get_object_schema(); let version_array = UInt64Array::from(vec![object.version]); - let data_array = BinaryArray::from(vec![object.data.as_slice()]); + let data_array = BinaryArray::from(vec![object.opaque_data()?]); RecordBatch::try_new( Arc::new(schema), @@ -407,17 +456,35 @@ fn load_object_from_file(path: &Path) -> Result { })?; let reader = FileReader::try_new(file, None) .map_err(|e| FlameError::Internal(format!("Failed to create reader: {}", e)))?; + let schema = reader.schema(); // SAFETY: Skipping validation is safe because all data was written by this service let reader = unsafe { reader.with_skip_validation(true) }; - let batch = reader + let batches = reader .into_iter() - .next() - .ok_or_else(|| FlameError::Internal("No batches in file".to_string()))? + .collect::, _>>() .map_err(|e| FlameError::Internal(format!("Failed to read batch: {}", e)))?; - batch_to_object(&batch) + if schema + .metadata() + .get(CACHE_FORMAT_METADATA_KEY) + .map(|format| format == CACHE_FORMAT_ARROW_TABLE) + .unwrap_or(false) + { + let version = schema + .metadata() + .get(crate::cache::CACHE_VERSION_METADATA_KEY) + .and_then(|value| value.parse::().ok()) + .unwrap_or(0); + return Ok(Object::new_arrow_table(version, schema, batches)); + } + + let batch = batches + .first() + .ok_or_else(|| FlameError::Internal("No batches in file".to_string()))?; + + batch_to_object(batch) } fn batch_to_object(batch: &RecordBatch) -> Result { @@ -447,6 +514,7 @@ fn batch_to_object(batch: &RecordBatch) -> Result { #[cfg(test)] mod tests { use super::*; + use arrow::array::{Int32Array, StringArray}; use tempfile::tempdir; fn test_key(app: &str, session: &str, object: &str) -> ObjectKey { @@ -470,10 +538,57 @@ mod tests { assert!(result.is_some()); let loaded = result.unwrap(); assert_eq!(loaded.version, 1); - assert_eq!(loaded.data, vec![1, 2, 3, 4, 5]); + assert_eq!(loaded.opaque_data().unwrap(), &[1, 2, 3, 4, 5]); assert!(loaded.deltas.is_empty()); } + #[tokio::test] + async fn test_disk_storage_write_read_native_arrow_table() { + let temp_dir = tempdir().unwrap(); + let storage = DiskStorage::new(temp_dir.path().to_path_buf()).unwrap(); + + let schema = Arc::new( + Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("label", DataType::Utf8, false), + ]) + .with_metadata( + [( + CACHE_FORMAT_METADATA_KEY.to_string(), + CACHE_FORMAT_ARROW_TABLE.to_string(), + )] + .into_iter() + .collect(), + ), + ); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["a", "b", "c"])), + ], + ) + .unwrap(); + let key = test_key("test-app", "test-session", "native"); + let object = Object::new_arrow_table(7, schema, vec![batch]); + + storage.write_object(&key, &object).await.unwrap(); + + let loaded = storage.read_object(&key).await.unwrap().unwrap(); + assert_eq!(loaded.version, 7); + match loaded.payload { + ObjectPayload::ArrowTable { schema, batches } => { + assert_eq!( + schema.metadata().get(CACHE_FORMAT_METADATA_KEY).unwrap(), + CACHE_FORMAT_ARROW_TABLE + ); + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].num_rows(), 3); + } + ObjectPayload::Opaque(_) => panic!("expected native Arrow payload"), + } + } + #[tokio::test] async fn test_disk_storage_patch() { let temp_dir = tempdir().unwrap(); @@ -490,7 +605,7 @@ mod tests { let loaded = storage.read_object(&key).await.unwrap().unwrap(); assert_eq!(loaded.deltas.len(), 1); assert_eq!(loaded.deltas[0].version, 2); - assert_eq!(loaded.deltas[0].data, vec![4, 5, 6]); + assert_eq!(loaded.deltas[0].opaque_data().unwrap(), &[4, 5, 6]); } #[tokio::test] diff --git a/sdk/python/src/flamepy/core/cache.py b/sdk/python/src/flamepy/core/cache.py index 894d433e..c132c59c 100644 --- a/sdk/python/src/flamepy/core/cache.py +++ b/sdk/python/src/flamepy/core/cache.py @@ -41,6 +41,17 @@ _TYPE_ARROW_BATCH = b"FLM\x04" _MAGIC_PREFIX_LEN = len(_MAGIC_PREFIX) + 1 # 4 bytes total +CACHE_FORMAT_METADATA_KEY = b"flame.cache.format" +CACHE_VERSION_METADATA_KEY = b"flame.cache.version" +CACHE_LOGICAL_TYPE_METADATA_KEY = b"flame.cache.logical_type" +CACHE_FORMAT_OPAQUE = b"opaque-v1" +CACHE_FORMAT_ARROW_TABLE = b"arrow-table-v1" +LOGICAL_TYPE_PYARROW_TABLE = "pyarrow.table" +LOGICAL_TYPE_PYARROW_RECORD_BATCH = "pyarrow.record_batch" +LOGICAL_TYPE_PANDAS_DATAFRAME = "pandas.dataframe" +LOGICAL_TYPE_POLARS_DATAFRAME = "polars.dataframe" +LOGICAL_TYPE_DATASET = "dataset" + try: import numpy as np @@ -49,6 +60,22 @@ np = None # type: ignore[assignment] _HAS_NUMPY = False +try: + import pandas as pd + + _HAS_PANDAS = True +except ImportError: + pd = None # type: ignore[assignment] + _HAS_PANDAS = False + +try: + import polars as pl + + _HAS_POLARS = True +except ImportError: + pl = None # type: ignore[assignment] + _HAS_POLARS = False + logger = logging.getLogger(__name__) Deserializer = Callable[[Any, List[Any]], Any] @@ -121,6 +148,13 @@ class FetchResult: patches: List[Patch] = field(default_factory=list) +@dataclass +class CachePayload: + schema: pa.Schema + batches: List[pa.RecordBatch] + native_arrow: bool + + class _IdentityKey: __slots__ = ("value",) @@ -417,6 +451,131 @@ def _deserialize_cloudpickle(data: bytes) -> Any: return cloudpickle.loads(data) +def _cache_metadata_value(schema: pa.Schema, key: bytes) -> Optional[bytes]: + metadata = schema.metadata or {} + return metadata.get(key) + + +def _is_native_arrow_schema(schema: pa.Schema) -> bool: + return _cache_metadata_value(schema, CACHE_FORMAT_METADATA_KEY) == CACHE_FORMAT_ARROW_TABLE + + +def _cache_schema_version(schema: pa.Schema) -> int: + value = _cache_metadata_value(schema, CACHE_VERSION_METADATA_KEY) + if value is None: + return 0 + try: + return int(value.decode("utf-8")) + except ValueError: + return 0 + + +def _schema_with_cache_metadata(schema: pa.Schema, logical_type: str, version: int = 0) -> pa.Schema: + metadata = dict(schema.metadata or {}) + metadata[CACHE_FORMAT_METADATA_KEY] = CACHE_FORMAT_ARROW_TABLE + metadata[CACHE_VERSION_METADATA_KEY] = str(version).encode("utf-8") + metadata[CACHE_LOGICAL_TYPE_METADATA_KEY] = logical_type.encode("utf-8") + return schema.with_metadata(metadata) + + +def _strip_cache_metadata(schema: pa.Schema) -> pa.Schema: + metadata = {key: value for key, value in (schema.metadata or {}).items() if not key.startswith(b"flame.cache.")} + return schema.with_metadata(metadata or None) + + +def _table_with_cache_metadata(table: pa.Table, logical_type: str, version: int = 0) -> pa.Table: + return table.replace_schema_metadata(_schema_with_cache_metadata(table.schema, logical_type, version).metadata) + + +def _batch_with_cache_metadata(batch: pa.RecordBatch, logical_type: str, version: int = 0) -> pa.RecordBatch: + schema = _schema_with_cache_metadata(batch.schema, logical_type, version) + return pa.RecordBatch.from_arrays( + [batch.column(i) for i in range(batch.num_columns)], + schema=schema, + ) + + +def _prepare_native_arrow_payload(obj: Any) -> Optional[CachePayload]: + """Convert supported tabular payloads to native Arrow batches.""" + logical_type: Optional[str] = None + table: Optional[pa.Table] = None + batch: Optional[pa.RecordBatch] = None + + if isinstance(obj, pa.Table): + logical_type = LOGICAL_TYPE_PYARROW_TABLE + table = obj + elif isinstance(obj, pa.RecordBatch): + logical_type = LOGICAL_TYPE_PYARROW_RECORD_BATCH + batch = obj + elif _HAS_PANDAS and isinstance(obj, pd.DataFrame): + logical_type = LOGICAL_TYPE_PANDAS_DATAFRAME + table = pa.Table.from_pandas(obj) + elif _HAS_POLARS and isinstance(obj, pl.DataFrame): + logical_type = LOGICAL_TYPE_POLARS_DATAFRAME + table = obj.to_arrow() + elif _HAS_POLARS and isinstance(obj, pl.LazyFrame): + logical_type = LOGICAL_TYPE_POLARS_DATAFRAME + table = obj.collect().to_arrow() + elif hasattr(obj, "to_arrow"): + try: + arrow_obj = obj.to_arrow() + except TypeError: + arrow_obj = None + if isinstance(arrow_obj, pa.Table): + logical_type = LOGICAL_TYPE_DATASET + table = arrow_obj + elif isinstance(arrow_obj, pa.RecordBatch): + logical_type = LOGICAL_TYPE_DATASET + batch = arrow_obj + + if batch is not None and logical_type is not None: + batch = _batch_with_cache_metadata(batch, logical_type) + return CachePayload(schema=batch.schema, batches=[batch], native_arrow=True) + + if table is not None and logical_type is not None: + table = _table_with_cache_metadata(table, logical_type) + batches = table.to_batches() + if not batches: + batches = [ + pa.RecordBatch.from_arrays( + [table.column(i).combine_chunks() for i in range(table.num_columns)], + schema=table.schema, + ) + ] + return CachePayload(schema=table.schema, batches=batches, native_arrow=True) + + return None + + +def _prepare_cache_payload(obj: Any) -> CachePayload: + native_payload = _prepare_native_arrow_payload(obj) + if native_payload is not None: + return native_payload + batch = _serialize_object(obj) + return CachePayload(schema=batch.schema, batches=[batch], native_arrow=False) + + +def _deserialize_native_arrow_table(table: pa.Table) -> Any: + metadata = table.schema.metadata or {} + logical_type = metadata.get(CACHE_LOGICAL_TYPE_METADATA_KEY, LOGICAL_TYPE_PYARROW_TABLE.encode("utf-8")).decode("utf-8") + user_schema = _strip_cache_metadata(table.schema) + table = table.replace_schema_metadata(user_schema.metadata) + + if logical_type == LOGICAL_TYPE_PYARROW_RECORD_BATCH: + batches = table.to_batches() + if len(batches) == 1: + return batches[0] + return table + + if logical_type == LOGICAL_TYPE_PANDAS_DATAFRAME and _HAS_PANDAS: + return table.to_pandas() + + if logical_type == LOGICAL_TYPE_POLARS_DATAFRAME and _HAS_POLARS: + return pl.from_arrow(table) + + return table + + def _serialize_object_data(obj: Any) -> bytes: """Serialize object using the optimal format based on type. @@ -579,13 +738,20 @@ def _get_flight_client_with_retry(endpoint: str, tls_config: Optional[FlameClien return _get_flight_client(endpoint, tls_config) -def _do_put_remote(client: flight.FlightClient, descriptor: flight.FlightDescriptor, batch: pa.RecordBatch) -> "ObjectRef": +def _do_put_remote_batches( + client: flight.FlightClient, + descriptor: flight.FlightDescriptor, + schema: pa.Schema, + batches: List[pa.RecordBatch], + options: Optional[flight.FlightCallOptions] = None, +) -> "ObjectRef": """Perform a remote do_put operation and read the result metadata. Args: client: Arrow Flight client descriptor: Flight descriptor for the put operation - batch: RecordBatch to upload + schema: Arrow schema to upload + batches: RecordBatches to upload Returns: ObjectRef received from the server @@ -593,10 +759,13 @@ def _do_put_remote(client: flight.FlightClient, descriptor: flight.FlightDescrip Raises: ValueError: If metadata cannot be read from server """ - writer, reader = client.do_put(descriptor, batch.schema) + if options is None: + writer, reader = client.do_put(descriptor, schema) + else: + writer, reader = client.do_put(descriptor, schema, options) - # Write batch - writer.write_batch(batch) + for batch in batches: + writer.write_batch(batch) # Signal we're done writing writer.done_writing() @@ -625,6 +794,10 @@ def _do_put_remote(client: flight.FlightClient, descriptor: flight.FlightDescrip raise ValueError("No result metadata received from cache server") +def _do_put_remote(client: flight.FlightClient, descriptor: flight.FlightDescriptor, batch: pa.RecordBatch) -> "ObjectRef": + return _do_put_remote_batches(client, descriptor, batch.schema, [batch]) + + def _get_cache_tls_config() -> Optional[FlameClientTls]: """Get TLS configuration for cache from FlameContext. @@ -683,7 +856,7 @@ def put_object(key_prefix: str, obj: Any) -> "ObjectRef": if not cache_endpoint: raise ValueError("Cache endpoint not configured") - batch = _serialize_object(obj) + payload = _prepare_cache_payload(obj) storage_path: Optional[Path] = None use_local_storage = False @@ -706,8 +879,9 @@ def put_object(key_prefix: str, obj: Any) -> "ObjectRef": app_session_dir.mkdir(parents=True, exist_ok=True) object_path = app_session_dir / f"{object_key_with_id.object_id}.arrow" - writer = pa.ipc.new_file(object_path, batch.schema) - writer.write_batch(batch) + writer = pa.ipc.new_file(object_path, payload.schema) + for batch in payload.batches: + writer.write_batch(batch) writer.close() client = _get_flight_client(cache_endpoint, cache_tls) @@ -725,7 +899,7 @@ def put_object(key_prefix: str, obj: Any) -> "ObjectRef": else: client = _get_flight_client(cache_endpoint, cache_tls) upload_descriptor = flight.FlightDescriptor.for_path(object_key.to_prefix()) - ref = _do_put_remote(client, upload_descriptor, batch) + ref = _do_put_remote_batches(client, upload_descriptor, payload.schema, payload.batches) logger.debug(f"put_object remote: key={ref.key}, version={ref.version}") return ref @@ -810,6 +984,13 @@ def _fetch_object_data(ref: ObjectRef, cached_version: int) -> Optional[FetchRes reader = client.do_get(ticket) table = reader.read_all() + if _is_native_arrow_schema(table.schema): + return FetchResult( + mode=FetchMode.FULL, + version=_cache_schema_version(table.schema), + base=_deserialize_native_arrow_table(table), + ) + if table.num_rows == 0: return None @@ -887,13 +1068,13 @@ def update_object(ref: ObjectRef, new_obj: Any) -> "ObjectRef": """ ObjectKey.from_key(ref.key) - batch = _serialize_object(new_obj) + payload = _prepare_cache_payload(new_obj) tls_config = _get_cache_tls_config() client = _get_flight_client(ref.endpoint, tls_config) upload_descriptor = flight.FlightDescriptor.for_path(ref.key) - new_ref = _do_put_remote(client, upload_descriptor, batch) + new_ref = _do_put_remote_batches(client, upload_descriptor, payload.schema, payload.batches) _cache_remove((ref.endpoint, ref.key)) diff --git a/sdk/python/tests/test_cache.py b/sdk/python/tests/test_cache.py index af0a5b4e..f67b6fb0 100644 --- a/sdk/python/tests/test_cache.py +++ b/sdk/python/tests/test_cache.py @@ -25,7 +25,9 @@ _cache_lock, _deserialize_object, _deserialize_object_data, + _is_native_arrow_schema, _object_cache, + _prepare_cache_payload, _serialize_object, _serialize_object_data, delete_objects, @@ -111,6 +113,15 @@ def test_pyarrow_table_uses_fast_path(self): result = _deserialize_object_data(data) assert result.equals(table) + def test_pyarrow_table_uses_native_cache_payload(self): + table = pa.table({"col1": [1, 2, 3], "col2": ["a", "b", "c"]}) + payload = _prepare_cache_payload(table) + + assert payload.native_arrow is True + assert _is_native_arrow_schema(payload.schema) + assert payload.batches + assert payload.schema.metadata[b"flame.cache.logical_type"] == b"pyarrow.table" + def test_pyarrow_record_batch_uses_fast_path(self): batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3], "y": [4.0, 5.0, 6.0]}) data = _serialize_object_data(batch) @@ -119,6 +130,33 @@ def test_pyarrow_record_batch_uses_fast_path(self): result = _deserialize_object_data(data) assert result.equals(batch) + def test_pyarrow_record_batch_uses_native_cache_payload(self): + batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3], "y": [4.0, 5.0, 6.0]}) + payload = _prepare_cache_payload(batch) + + assert payload.native_arrow is True + assert _is_native_arrow_schema(payload.schema) + assert payload.batches[0].to_pydict() == batch.to_pydict() + assert payload.schema.metadata[b"flame.cache.logical_type"] == b"pyarrow.record_batch" + + def test_pandas_dataframe_uses_native_cache_payload_when_available(self): + pd = pytest.importorskip("pandas") + dataframe = pd.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"]}) + payload = _prepare_cache_payload(dataframe) + + assert payload.native_arrow is True + assert _is_native_arrow_schema(payload.schema) + assert payload.schema.metadata[b"flame.cache.logical_type"] == b"pandas.dataframe" + + def test_polars_dataframe_uses_native_cache_payload_when_available(self): + pl = pytest.importorskip("polars") + dataframe = pl.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"]}) + payload = _prepare_cache_payload(dataframe) + + assert payload.native_arrow is True + assert _is_native_arrow_schema(payload.schema) + assert payload.schema.metadata[b"flame.cache.logical_type"] == b"polars.dataframe" + def test_pyarrow_array_uses_fast_path(self): arr = pa.array([1, 2, 3, 4, 5]) data = _serialize_object_data(arr) @@ -590,6 +628,30 @@ def test_fetch_object_data_parses_full_response_rows(self, monkeypatch): assert result.base == [1] assert [patch.data for patch in result.patches] == [[2], [3]] + def test_fetch_object_data_parses_native_arrow_table(self, monkeypatch): + from flamepy.core import cache as cache_module + + table = pa.table({"x": [1, 2, 3], "y": ["a", "b", "c"]}) + table = table.replace_schema_metadata( + { + b"flame.cache.format": b"arrow-table-v1", + b"flame.cache.version": b"7", + b"flame.cache.logical_type": b"pyarrow.table", + } + ) + self._patch_fetch_client(monkeypatch, table) + + result = cache_module._fetch_object_data( + ObjectRef(endpoint="grpc://host:9090", key="app/session/native-table", version=1), + 0, + ) + + assert result.mode == FetchMode.FULL + assert result.version == 7 + assert isinstance(result.base, pa.Table) + assert result.base.to_pydict() == {"x": [1, 2, 3], "y": ["a", "b", "c"]} + assert b"flame.cache.format" not in (result.base.schema.metadata or {}) + def test_fetch_object_data_rejects_base_after_patch(self, monkeypatch): from flamepy.core import cache as cache_module @@ -643,6 +705,52 @@ def worker(i): class TestPutObject: + def test_remote_put_sends_native_arrow_payload(self, monkeypatch): + from flamepy.core import cache as cache_module + from flamepy.core.types import FlameClientCache + + captured = {"schema": None, "batches": []} + + class MockWriter: + def write_batch(self, batch): + captured["batches"].append(batch) + + def done_writing(self): + pass + + def close(self): + pass + + class MockReader: + def __init__(self): + self._read_count = 0 + + def read(self): + if self._read_count == 0: + self._read_count += 1 + return bson.encode({"endpoint": "grpc://host:9090", "key": "app/session/native", "version": 1}) + return None + + class MockFlightClient: + def do_put(self, descriptor, schema): + captured["schema"] = schema + return MockWriter(), MockReader() + + class MockContext: + cache = FlameClientCache(endpoint="grpc://host:9090") + + monkeypatch.setattr(cache_module, "_context_cache", None) + monkeypatch.setattr(cache_module, "FlameContext", lambda: MockContext()) + monkeypatch.setattr(cache_module, "_get_flight_client", lambda ep, tls=None: MockFlightClient()) + + ref = cache_module.put_object("app/session", pa.table({"x": [1, 2, 3]})) + + assert ref.key == "app/session/native" + assert captured["schema"].metadata[b"flame.cache.format"] == b"arrow-table-v1" + assert captured["schema"].metadata[b"flame.cache.logical_type"] == b"pyarrow.table" + assert len(captured["batches"]) == 1 + assert captured["batches"][0].to_pydict() == {"x": [1, 2, 3]} + def test_local_storage_ref_uses_cacheable_version(self, monkeypatch, tmp_path): from flamepy.core import cache as cache_module from flamepy.core.types import FlameClientCache @@ -666,6 +774,7 @@ class MockContext: storage=str(tmp_path), ) + monkeypatch.setattr(cache_module, "_context_cache", None) monkeypatch.setattr(cache_module, "FlameContext", lambda: MockContext()) monkeypatch.setattr( cache_module, From 5b7752b2f6ec9a98f563f0220d90f5def610868d Mon Sep 17 00:00:00 2001 From: Klaus Ma Date: Mon, 18 May 2026 18:47:24 -0700 Subject: [PATCH 2/2] Address native cache review feedback --- object_cache/src/cache.rs | 129 +++++++++++++++++++++++++-- object_cache/src/storage/disk.rs | 59 ++++++++++++ object_cache/src/storage/mod.rs | 5 ++ object_cache/src/storage/none.rs | 7 ++ sdk/python/src/flamepy/core/cache.py | 7 +- sdk/python/tests/test_cache.py | 23 +++++ 6 files changed, 219 insertions(+), 11 deletions(-) diff --git a/object_cache/src/cache.rs b/object_cache/src/cache.rs index 9f2f9ae2..7a993c1d 100644 --- a/object_cache/src/cache.rs +++ b/object_cache/src/cache.rs @@ -614,6 +614,24 @@ impl ObjectCache { ))) } + async fn schema(&self, key: &ObjectKey) -> Result>, FlameError> { + let key_str = key.to_key().ok_or_else(|| { + FlameError::InvalidConfig("ObjectKey requires object_id for schema lookup".to_string()) + })?; + + { + let objects = lock_ptr!(self.objects)?; + if let Some(object) = objects.get(&key_str) { + return Ok(match &object.payload { + ObjectPayload::ArrowTable { schema, .. } => Some(schema.clone()), + ObjectPayload::Opaque(_) => None, + }); + } + } + + self.storage.read_schema(key).await + } + async fn patch(&self, key: &ObjectKey, delta: Object) -> Result { let key_str = key.to_key().ok_or_else(|| { FlameError::InvalidConfig("ObjectKey requires object_id for patch".to_string()) @@ -1185,13 +1203,9 @@ impl FlightService for FlightCacheServer { }; let schema = if let Ok(object_key) = ObjectKey::try_from(key.as_str()) { - match self.cache.get(&object_key).await { - Ok(object) => match object.payload { - ObjectPayload::ArrowTable { schema, .. } => { - Bytes::from(encode_schema(&schema)?) - } - ObjectPayload::Opaque(_) => Bytes::new(), - }, + match self.cache.schema(&object_key).await { + Ok(Some(schema)) => Bytes::from(encode_schema(&schema)?), + Ok(None) => Bytes::new(), Err(_) => Bytes::new(), } } else { @@ -2161,8 +2175,61 @@ mod tests { use super::*; use futures::StreamExt; use std::path::Path; + use std::sync::atomic::{AtomicUsize, Ordering}; use tempfile::tempdir; + struct SchemaOnlyStorage { + schema: Arc, + read_object_calls: Arc, + } + + #[async_trait] + impl crate::storage::StorageEngine for SchemaOnlyStorage { + async fn write_object( + &self, + _key: &ObjectKey, + _object: &Object, + ) -> Result<(), FlameError> { + Ok(()) + } + + async fn read_object(&self, _key: &ObjectKey) -> Result, FlameError> { + self.read_object_calls.fetch_add(1, Ordering::SeqCst); + Err(FlameError::Internal( + "read_object should not be used for schema lookup".to_string(), + )) + } + + async fn read_schema( + &self, + _key: &ObjectKey, + ) -> Result>, FlameError> { + Ok(Some(self.schema.clone())) + } + + async fn patch_object( + &self, + key: &ObjectKey, + _delta: &Object, + ) -> Result { + Ok(ObjectMetadata { + endpoint: String::new(), + key: key.to_key().unwrap_or_default(), + version: 0, + size: 0, + delta_count: 0, + }) + } + + async fn delete_objects(&self, _key: &ObjectKey) -> Result<(), FlameError> { + Ok(()) + } + + async fn load_objects(&self) -> Result, FlameError> { + Ok(Vec::new()) + } + } + async fn create_disk_test_server() -> (FlightCacheServer, tempfile::TempDir) { let endpoint = test_endpoint(); let temp = tempdir().unwrap(); @@ -2207,6 +2274,54 @@ mod tests { .unwrap() } + #[tokio::test] + async fn get_flight_info_reads_native_schema_without_loading_object() { + let schema = Arc::new( + Schema::new(vec![Field::new("id", DataType::Int32, false)]).with_metadata( + [( + CACHE_FORMAT_METADATA_KEY.to_string(), + CACHE_FORMAT_ARROW_TABLE.to_string(), + )] + .into_iter() + .collect(), + ), + ); + let read_object_calls = Arc::new(AtomicUsize::new(0)); + let storage = Box::new(SchemaOnlyStorage { + schema, + read_object_calls: read_object_calls.clone(), + }); + let cache = Arc::new(ObjectCache::new(test_endpoint(), storage, None).unwrap()); + let server = FlightCacheServer::new(cache); + + let response = server + .get_flight_info(Request::new(FlightDescriptor::new_path(vec![ + "app/session/native".to_string(), + ]))) + .await + .unwrap(); + let info = response.into_inner(); + + assert_eq!(read_object_calls.load(Ordering::SeqCst), 0); + assert!(!info.schema.is_empty()); + + let decoded_schema = FlightCacheServer::extract_schema_from_flight_data(&FlightData { + flight_descriptor: None, + app_metadata: Bytes::new(), + data_header: info.schema, + data_body: Bytes::new(), + }) + .unwrap(); + assert_eq!(decoded_schema.field(0).name(), "id"); + assert_eq!( + decoded_schema + .metadata() + .get(CACHE_FORMAT_METADATA_KEY) + .unwrap(), + CACHE_FORMAT_ARROW_TABLE + ); + } + async fn get_batches(server: &FlightCacheServer, ticket: &str) -> Vec { let response = server .do_get(Request::new(Ticket { diff --git a/object_cache/src/storage/disk.rs b/object_cache/src/storage/disk.rs index db6b87be..7bdbf341 100644 --- a/object_cache/src/storage/disk.rs +++ b/object_cache/src/storage/disk.rs @@ -112,6 +112,19 @@ impl StorageEngine for DiskStorage { .map_err(|e| FlameError::Internal(format!("Task join error: {}", e)))? } + async fn read_schema(&self, key: &ObjectKey) -> Result>, FlameError> { + let object_path = self.object_path(key); + + tokio::task::spawn_blocking(move || { + if !object_path.exists() { + return Ok(None); + } + load_native_schema_from_file(&object_path) + }) + .await + .map_err(|e| FlameError::Internal(format!("Task join error: {}", e)))? + } + async fn patch_object( &self, key: &ObjectKey, @@ -487,6 +500,30 @@ fn load_object_from_file(path: &Path) -> Result { batch_to_object(batch) } +fn load_native_schema_from_file(path: &Path) -> Result>, FlameError> { + let file = fs::File::open(path).map_err(|e| { + if e.kind() == std::io::ErrorKind::NotFound { + FlameError::NotFound(format!("Object file not found: {}", path.display())) + } else { + FlameError::Internal(format!("Failed to open object file: {}", e)) + } + })?; + let reader = FileReader::try_new(file, None) + .map_err(|e| FlameError::Internal(format!("Failed to create reader: {}", e)))?; + let schema = reader.schema(); + + if schema + .metadata() + .get(CACHE_FORMAT_METADATA_KEY) + .map(|format| format == CACHE_FORMAT_ARROW_TABLE) + .unwrap_or(false) + { + Ok(Some(schema)) + } else { + Ok(None) + } +} + fn batch_to_object(batch: &RecordBatch) -> Result { if batch.num_rows() != 1 { return Err(FlameError::InvalidState( @@ -574,6 +611,16 @@ mod tests { storage.write_object(&key, &object).await.unwrap(); + let stored_schema = storage.read_schema(&key).await.unwrap().unwrap(); + assert_eq!(stored_schema.field(0).name(), "id"); + assert_eq!( + stored_schema + .metadata() + .get(CACHE_FORMAT_METADATA_KEY) + .unwrap(), + CACHE_FORMAT_ARROW_TABLE + ); + let loaded = storage.read_object(&key).await.unwrap().unwrap(); assert_eq!(loaded.version, 7); match loaded.payload { @@ -589,6 +636,18 @@ mod tests { } } + #[tokio::test] + async fn test_disk_storage_read_schema_returns_none_for_opaque_object() { + let temp_dir = tempdir().unwrap(); + let storage = DiskStorage::new(temp_dir.path().to_path_buf()).unwrap(); + + let key = test_key("test-app", "test-session", "opaque"); + let object = Object::new(1, vec![1, 2, 3, 4, 5]); + storage.write_object(&key, &object).await.unwrap(); + + assert!(storage.read_schema(&key).await.unwrap().is_none()); + } + #[tokio::test] async fn test_disk_storage_patch() { let temp_dir = tempdir().unwrap(); diff --git a/object_cache/src/storage/mod.rs b/object_cache/src/storage/mod.rs index 0525d5d5..4ccaaf8c 100644 --- a/object_cache/src/storage/mod.rs +++ b/object_cache/src/storage/mod.rs @@ -18,7 +18,9 @@ pub use disk::DiskStorage; pub use none::NoneStorage; use std::path::PathBuf; +use std::sync::Arc; +use arrow::datatypes::Schema; use async_trait::async_trait; use common::FlameError; @@ -36,6 +38,9 @@ pub trait StorageEngine: Send + Sync + 'static { /// Read an object with all deltas. Returns None if not found. async fn read_object(&self, key: &ObjectKey) -> Result, FlameError>; + /// Read a native Arrow object schema without loading record batches. + async fn read_schema(&self, key: &ObjectKey) -> Result>, FlameError>; + /// Append a delta to an existing object. Returns NotFound if base doesn't exist. async fn patch_object( &self, diff --git a/object_cache/src/storage/none.rs b/object_cache/src/storage/none.rs index 5508b035..6faa60ce 100644 --- a/object_cache/src/storage/none.rs +++ b/object_cache/src/storage/none.rs @@ -44,6 +44,13 @@ impl StorageEngine for NoneStorage { Ok(None) } + async fn read_schema( + &self, + _key: &ObjectKey, + ) -> Result>, FlameError> { + Ok(None) + } + async fn patch_object( &self, key: &ObjectKey, diff --git a/sdk/python/src/flamepy/core/cache.py b/sdk/python/src/flamepy/core/cache.py index c132c59c..0a6eb6dc 100644 --- a/sdk/python/src/flamepy/core/cache.py +++ b/sdk/python/src/flamepy/core/cache.py @@ -879,10 +879,9 @@ def put_object(key_prefix: str, obj: Any) -> "ObjectRef": app_session_dir.mkdir(parents=True, exist_ok=True) object_path = app_session_dir / f"{object_key_with_id.object_id}.arrow" - writer = pa.ipc.new_file(object_path, payload.schema) - for batch in payload.batches: - writer.write_batch(batch) - writer.close() + with pa.ipc.new_file(object_path, payload.schema) as writer: + for batch in payload.batches: + writer.write_batch(batch) client = _get_flight_client(cache_endpoint, cache_tls) descriptor = flight.FlightDescriptor.for_path(key) diff --git a/sdk/python/tests/test_cache.py b/sdk/python/tests/test_cache.py index f67b6fb0..cafff31a 100644 --- a/sdk/python/tests/test_cache.py +++ b/sdk/python/tests/test_cache.py @@ -755,6 +755,24 @@ def test_local_storage_ref_uses_cacheable_version(self, monkeypatch, tmp_path): from flamepy.core import cache as cache_module from flamepy.core.types import FlameClientCache + writer_calls = {"entered": 0, "exited": 0, "batches": 0, "path": None} + + class MockFileWriter: + def __enter__(self): + writer_calls["entered"] += 1 + return self + + def __exit__(self, exc_type, exc, traceback): + writer_calls["exited"] += 1 + return False + + def write_batch(self, batch): + writer_calls["batches"] += 1 + + def mock_new_file(path, schema): + writer_calls["path"] = path + return MockFileWriter() + class MockLocation: uri = "grpc://local-cache:9090" @@ -781,11 +799,16 @@ class MockContext: "_get_flight_client", lambda endpoint, tls=None: MockFlightClient(), ) + monkeypatch.setattr(cache_module.pa.ipc, "new_file", mock_new_file) ref = cache_module.put_object("app/session", {"value": 1}) assert ref.endpoint == "grpc://local-cache:9090" assert ref.version == 1 + assert writer_calls["entered"] == 1 + assert writer_calls["exited"] == 1 + assert writer_calls["batches"] == 1 + assert writer_calls["path"].parent == tmp_path / "app" / "session" class TestDeleteObjects: