From b654f964f4a1e3217effe88b4444ca09f812d47e Mon Sep 17 00:00:00 2001 From: Klaus Ma Date: Tue, 12 May 2026 07:21:49 +0800 Subject: [PATCH 1/5] feat: add incremental object get --- .../RFE445-incremental-object-get/FS.md | 600 ++++++++++++++++++ e2e/tests/test_cache.py | 95 +++ examples/rl/replay_buffer/main.py | 93 ++- examples/rl/replay_buffer/replay_buffer.py | 11 +- object_cache/src/cache.rs | 393 ++++++++++-- object_cache/src/storage/disk.rs | 45 +- sdk/python/src/flamepy/core/cache.py | 159 ++++- sdk/python/tests/test_cache.py | 231 ++++++- 8 files changed, 1537 insertions(+), 90 deletions(-) create mode 100644 docs/designs/RFE445-incremental-object-get/FS.md diff --git a/docs/designs/RFE445-incremental-object-get/FS.md b/docs/designs/RFE445-incremental-object-get/FS.md new file mode 100644 index 00000000..9433f91a --- /dev/null +++ b/docs/designs/RFE445-incremental-object-get/FS.md @@ -0,0 +1,600 @@ +# RFE445: Incremental Object Retrieval by Patch Version + +## 1. Motivation + +**Background:** + +Flame object cache already supports versioned objects and append-only patches. Today, `get_object(ref)` passes a client-side version to the cache server. If the server version matches, the server returns an empty response and flamepy serves the local cached object. If the server version differs, the server returns the full base object plus all patches. + +That behavior is correct but expensive for patch-heavy objects. A replay buffer is the clearest example: collectors append transition batches with `patch_object`, while a buffer service repeatedly calls `get_object` for state and sampling. Once the buffer service has version `N`, the next read after versions `N+1..M` only needs those new patches if the base object has not changed. Returning the base object and all historical patches every time makes network transfer and deserialization grow with total buffer history instead of with new work since the previous read. + +**Target:** + +Add an incremental retrieval path for versioned `get_object`: + +- If the client version equals the current server version, return not-modified. +- If the client version is older but the server base snapshot is still valid, return only patches after the client version. +- If the server base object version is newer than the client version, or the patch history cannot bridge from the client version, return the full base object plus patches. + +The Python SDK should hide this behind the existing `get_object(ref, deserializer=None)` API and update its local cache so callers do not need a new method. + +Success is measured by replay-buffer read-path improvements: fewer bytes downloaded, lower `get_object` latency for state/sample calls, and better or unchanged end-to-end transitions per second. + +## 2. Function Specification + +### Configuration + +No configuration changes are required. + +Incremental retrieval is controlled only by the client version sent to `get_object`: + +| Request Version | Behavior | +|-----------------|----------| +| `0` | Return the full base object plus all patches. | +| `> 0` | Return not-modified, patches after that version, or a full response when patch-only retrieval is not safe. | + +### API + +No public Python API change: + +```python +def get_object(ref: ObjectRef, deserializer: Optional[Deserializer] = None) -> Any: + ... +``` + +`ObjectRef.version` remains the latest server object version known at the time the ref was produced. `version=0` keeps its current force-refresh behavior. 
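
For illustration, a minimal sketch of the unchanged call pattern (the key and the reducer below are examples, not part of the API):

```python
from flamepy.core import get_object, patch_object, put_object

# Example deserializer: fold append-only deltas into the base object.
def merge_counts(base: dict, deltas: list) -> dict:
    merged = dict(base)
    for delta in deltas:
        for name, count in delta.items():
            merged[name] = merged.get(name, 0) + count
    return merged

ref = put_object("my-app/my-session", {"a": 1})
patch_object(ref, {"a": 2, "b": 3})

# Same public API before and after this design; only the wire traffic
# behind this call becomes incremental.
assert get_object(ref, deserializer=merge_counts) == {"a": 3, "b": 3}
```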
For normal reads, flamepy should choose the effective request version this way:

- Send `0` when `ref.version == 0`.
- Send `0` when there is no local cached object, even if `ref.version > 0`.
- Send the local cached object version when a cached object exists and `ref.version > 0`.

The caller does not need to mutate an `ObjectRef` after every read.

### CLI

No production CLI changes are required for incremental object retrieval.

Replay-buffer performance evaluation should add example-only flags to `examples/rl/replay_buffer/main.py`:

- `--metrics-json <path>`: write per-iteration performance metrics.
- `--merge-every <N>`: configure compaction cadence.
- `--no-merge`: stress patch-only reads by keeping a long patch history.
- `--force-full-get`: force replay-buffer reads to send request version `0` for baseline measurement.

These flags should preserve the existing exit-status semantics: nonzero on uncaught workload failure, zero when all configured iterations complete.

### Other Interfaces

#### Version Model

The design keeps one monotonically increasing object version. There is no new public version field and no new object API.

That single version is used consistently:

| Location | Meaning |
|----------|---------|
| `ObjectRef.version` | Current object version returned to Python callers. |
| `ObjectMetadata.version` | Current object version tracked by the cache server. |
| Base row `Object.version` | Version when the current base snapshot was written. |
| Patch row `Object.version` | Version assigned when that patch was appended. |

Mutation rules:

| Operation | Version Behavior |
|-----------|------------------|
| Existing `put_object` for a new key | Current version becomes `1`; base row version is `1`. |
| Existing `update_object` / replacing put | Current version increments, base row is written with that new version, patches are cleared. |
| Existing `patch_object` | Current version increments, patch row is written with that new version, base row is unchanged. |

The object cache must persist patch versions. On storage reload, metadata can be reconstructed from the base row and ordered patch rows:

```text
current_version = max(base.version, last_patch.version)
delta_count = number of patches
```

#### Retrieval Protocol

Object retrieval stays on Arrow Flight `do_get`. The cache server currently lists only `DELETE` as a Flight action, so this design does not add a `GET` action or an alternate action.

The existing ticket format remains the protocol:

```text
<key>:<client_version>
```

Keys are already restricted to `<app>/<session>/<object_id>`, so `:` is safe as a ticket separator.

The `do_get` response uses a versioned Arrow schema that lets flamepy distinguish full responses from patch-only responses. This schema is for the read response stream; persistent storage can keep its existing `version,data` row shape.

```text
version: uint64
kind: utf8     # "base" or "patch"
data: binary
```

Implementation should centralize these schema field names and row-kind values as constants or enums on both the Rust server and Python client.

Response framing:

- Not modified: empty stream with this schema and zero rows.
- Full response: exactly one `kind="base"` row first, followed by zero or more `kind="patch"` rows in increasing version order.
- Patch-only response: one or more `kind="patch"` rows in increasing version order and no base row.
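For concreteness, a minimal pyarrow sketch of that stream shape (payload bytes here are placeholders; real rows carry serialized object data):

```python
import pyarrow as pa

# Mirrors the read-response schema described above.
response_schema = pa.schema(
    [
        pa.field("version", pa.uint64()),
        pa.field("kind", pa.string()),
        pa.field("data", pa.binary()),
    ]
)

# A patch-only response bridging a client at version 10 to server version 13.
patch_only = pa.table(
    {
        "version": pa.array([11, 12, 13], type=pa.uint64()),
        "kind": pa.array(["patch", "patch", "patch"], type=pa.string()),
        "data": pa.array([b"p11", b"p12", b"p13"], type=pa.binary()),
    }
)
assert patch_only.schema.equals(response_schema)
```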
+ +Response modes: + +| Server State | Response | +|--------------|----------| +| `client_version == current_version` and `client_version != 0` | Empty stream with schema: not modified. | +| `base.version <= client_version < current_version` and all needed patch versions exist | Patch rows where `patch.version > client_version`. | +| `client_version == 0` | Full response: base row plus all patch rows. | +| `client_version < base.version` | Full response: base row plus all patch rows. | +| Patch history is missing, compacted, or non-contiguous | Full response: base row plus all patch rows. | +| `client_version > current_version` | Full response and warning log. Treat the client cache as suspect instead of returning not-modified. | + +The client computes the returned current version as the maximum row version. For a full response, the base row and patch rows replace the cached entry. For a patch-only response, the returned patch rows advance the cached current version. + +Patch-only responses require a valid local cached base. flamepy enforces this by sending request version `0` whenever no local cache entry exists. + +### Scope + +**In Scope:** + +- Single monotonically increasing object version across base writes and patches. +- Assigning server versions to persisted patch rows. +- Version-driven incremental behavior for the existing `do_get` ticket format. +- Full-object retrieval when the request version is `0`. +- flamepy cache changes to store base data, patch data, and materialized results. +- Unit tests and focused E2E coverage for full, patch-only, and not-modified paths. +- Replay buffer performance evaluation plan. + +**Out of Scope:** + +- Public Python API changes. +- Cross-process client cache sharing. +- Client cache persistence to disk. +- Arbitrary patch compaction policy changes beyond returning a full response when history cannot bridge. +- Incremental deserializer/reducer APIs. Existing deserializers still receive base plus the full cached patch list. +- Optimistic write conflict detection. + +**Limitations:** + +- Patch-only retrieval only helps clients that already have a valid local cache entry. +- If the base object is replaced or compacted past the client version, the server must send a full response. +- Existing deserializers still define how base plus patches become user data. Incremental fetch reduces network and repeated deserialization of old patch payloads; it does not automatically make every deserializer incrementally composable. + +### Feature Interaction + +**Related Features:** + +- RFE318 object cache and patch semantics. +- RFE426 object versioning and client-side cache. +- RFE423 app/session cache key validation. +- `examples/rl/replay_buffer`, which is the target workload for performance evaluation. + +**Updates Required:** + +- `object_cache/src/cache.rs`: update existing `do_get` version handling with current-version and base-row-version decisions, plus patch-only response streaming. +- `object_cache/src/storage/disk.rs`: persist server-assigned patch versions and reconstruct current version from base plus patch history. +- `sdk/python/src/flamepy/core/cache.py`: cache base, versioned patches, and materialized outputs; send request version `0` for full fetches and cached nonzero current versions for incremental reads. +- `sdk/python/tests/test_cache.py` and E2E cache tests: add full, patch-only, not-modified, and update-after-cache coverage. +- `examples/rl/replay_buffer/main.py`: add benchmark flags and metrics export for the performance evaluation. 
**Integration Points:**

- Arrow Flight `do_get` remains the transport for object reads.
- `do_put` with path descriptors remains the transport for put/update writes.
- `do_put` with `PATCH:<key>` command descriptors remains the transport for patch writes.
- `ObjectRef.version` continues to expose the server object version to Python callers.
- Replay-buffer `state()`, `sample()`, and `merge()` continue to call `get_object(..., deserializer=...)`.

**Compatibility:**

- No backward compatibility with old flamepy/cache-server wire behavior is required.
- The public Python `get_object` API remains unchanged.
- `version=0` always requests a full object.
- `version>0` always enables the version-aware behavior described in this document.
- flamepy sends `0` when it has no local cache entry, so first reads and forced refreshes remain full-object downloads.

**Breaking Changes:**

The existing `do_get` response format can change as needed so updated flamepy can distinguish full responses from patch-only responses.

## 3. Implementation Detail

### Architecture

```text
flamepy get_object
    |
    | local cached entry?
    |   no  -> <key>:0       -> full base + patches
    |   yes -> <key>:version -> not_modified, patch rows, or full rows
    v
object cache
    |
    | compares client_version with base row version and current version
    v
returns minimal safe stream
```

The important invariant is that patch-only rows are returned only when the client's cached base is still the same base snapshot the server is extending.

### Components

**`object_cache/src/cache.rs`**

- Track `ObjectMetadata.version` as the current object version.
- Use the existing base object's `version` field to detect whether the client's cached base is still valid.
- Set patch object version to the new current version before appending.
- Load metadata from storage before comparing versions in `do_get`, so process restarts do not make version checks look like `0`.
- Parse the existing `<key>:<client_version>` tickets.
- Return full, patch-only, or empty streams according to the retrieval protocol.
- Produce a consistent per-object read snapshot. Implementation should hold the per-key lock while choosing the current version and loading base/patch rows, or verify after loading that the selected current version still matches the loaded row set.

**`object_cache/src/storage/disk.rs`**

- Persist patch row versions by writing the server-assigned patch object, not the client-uploaded version `0`.
- Read patches in order and preserve their versions.
- Reconstruct current version and `delta_count` after reload.
- Treat old stored patch rows with version `0` as pre-RFE445 rows. On load, synthesize contiguous patch versions in filename order from `base.version + 1` through `base.version + delta_count`; future writes persist server-assigned versions.

**`sdk/python/src/flamepy/core/cache.py`**

- Replace the cached `Object(version, data)` shape with a richer cache entry:

```python
@dataclass
class Patch:
    version: int
    data: Any

@dataclass
class CachedObject:
    version: int
    base: Any
    patches: list[Patch]
    materialized: dict[int | None, Any]
```

- Request full or incremental data by choosing the effective request version.
- Parse response rows by `kind`.
- On full response, replace `base`, `patches`, and `version`.
- On patch-only response, require a cached entry, append patches, update `version`, and invalidate materialized values, as sketched after this list.
- On not-modified, return the materialized cached result.
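A hedged sketch of that patch-only apply step, using the `CachedObject` and `Patch` shapes above (validation and locking elided):

```python
def apply_patch_rows(cached: "CachedObject", patches: "list[Patch]") -> None:
    # Patch rows arrive in increasing version order, so the last row
    # carries the new current version.
    cached.patches.extend(patches)
    cached.version = patches[-1].version
    # Previously materialized views are stale once new patches land.
    cached.materialized.clear()
```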
Materialization:

- `deserializer is None`: return the base object, preserving the public API behavior.
- `deserializer is not None`: compute `deserializer(base, [patch.data for patch in patches])`.
- Cache materialized values by deserializer identity within the process. New patches invalidate those values.

Mutations:

- `put_object`, `update_object`, `patch_object`, and `delete_objects` continue to invalidate affected local cache entries.
- `patch_object` should return the new current version from server metadata. It should not mutate a cached object optimistically; the next `get_object` remains the consistency point.

### Data Structures

Server-side object rows:

```rust
pub struct Object {
    pub version: u64, // base row version or patch row version
    pub data: Vec<u8>,
    pub deltas: Vec<Object>,
}

pub struct ObjectMetadata {
    pub endpoint: String,
    pub key: String,
    pub version: u64, // current object version
    pub size: u64,
    pub delta_count: u64,
}
```

Patch history must be ordered by version, not just by file index. File index can remain the storage layout as long as reads validate monotonic patch versions.

### Algorithms

**Server mutation path:**

```text
put/update:
    current_version = metadata.version or max loaded row version or 0
    new_version = current_version + 1
    write base object with version = new_version
    clear patches
    metadata.version = new_version

patch:
    current_version = metadata.version or max loaded row version
    new_version = current_version + 1
    write patch object with version = new_version
    metadata.version = new_version
```

**Server Versioned Get:**

```text
acquire per-key read/write coordination

load base + patches
current_version = metadata.version or max loaded row version

if client_version != 0 and client_version == current_version:
    return empty stream

needed_patches = patches with version > client_version
patch_suffix_is_contiguous =
    needed_patches versions are exactly client_version + 1 through current_version

if client_version != 0
   and base.version <= client_version
   and patch_suffix_is_contiguous:
    return needed_patches

return base row + all patch rows
```

**Client `get_object`:**

```text
cache_key = (endpoint, key)
cached = local_cache.get(cache_key)

if ref.version == 0 or cached is None:
    client_version = 0
else:
    client_version = cached.version

response = fetch_object_data(ref, client_version)

if response.not_modified:
    return materialize(cached, deserializer)

if response.full:
    cached = replace_from_full(response)
    local_cache[cache_key] = cached
    return materialize(cached, deserializer)

if response.patches:
    if cached is None:
        retry the fetch with client_version = 0 and handle as full
    append patches to cached
    cached.version = max_patch_version
    cached.materialized.clear()
    return materialize(cached, deserializer)
```

### System Considerations

**Performance:**

Patch-only get changes read transfer cost from:

```text
O(size(base) + size(all_patches))
```

to:

```text
O(size(new_patches_since_last_read))
```

for clients whose base snapshot is still valid.

**Scalability:**

This benefits workloads with one or more long-lived readers and frequent append-only writers. It does not change write concurrency semantics.

**Reliability:**

Returning a full response is the safety mechanism. If the server cannot prove the requested patch suffix is valid, it sends the full object.
`do_get` must return rows from one consistent per-object snapshot; concurrent `patch_object` or `update_object` calls must not produce a response whose base row, patch rows, and advertised current version disagree. + +**Resource Usage:** + +Client memory may increase because flamepy stores base and deserialized patches instead of only one materialized value. Keep the existing LRU entry limit and count each object entry once. Future work can add byte-based client cache limits. + +**Security:** + +No new auth surface. Ticket parsing must continue to validate `ObjectKey` before reading storage. + +**Observability:** + +Add debug logs and optional counters for: + +- `get_object_not_modified_total` +- `get_object_patch_response_total` +- `get_object_full_response_total` +- response bytes and row counts +- client cache hit/miss counts +- patch upload latency and bytes +- deserializer/materialization latency + +These counters are also useful for the replay-buffer evaluation. + +**Operational:** + +Deploy the cache server and flamepy changes together because the internal `do_get` response schema changes. No runtime switch is required: `version=0` requests full data, and nonzero versions use the incremental behavior. Operators can force full-object reads for debugging or A/B measurement through a benchmark-only path that sends request version `0`. + +### Dependencies + +No new external dependencies are required. + +Internal dependencies: + +- Arrow Flight `do_get` streaming in `object_cache/src/cache.rs`. +- Existing flamepy serialization/deserialization helpers in `sdk/python/src/flamepy/core/cache.py`. +- Existing replay buffer example in `examples/rl/replay_buffer`. + +Version requirements: + +- No package version bump is required for the design itself. +- Implementation should include the change in the next flamepy release because client/server wire behavior changes internally. + +### Verification Plan + +#### Rust Unit Tests + +- Existing `put_object` for a new key creates current object version `1`. +- Patch assigns patch row version `current_version + 1`. +- Update replaces base, clears patches, and advances current object version. +- Nonzero versioned get returns empty for matching current version. +- Nonzero versioned get returns only patches after client version when base is valid. +- Nonzero versioned get returns full object when `client_version < base.version`. +- Nonzero versioned get returns full object when patch history has a gap. +- `client_version=0` always returns the full object. +- Full response framing is base row first, then patches in increasing version order. +- Patch-only response framing contains only patches in increasing version order. +- Concurrent patch/update during get cannot return an inconsistent row set. +- Metadata reconstructed from disk preserves current object version and patch versions. +- Pre-RFE445 stored patch rows with version `0` are assigned synthetic contiguous versions during load. + +#### Python Unit Tests + +- `get_object` replaces cache on full response. +- `get_object` appends patch-only response to an existing cache entry. +- Patch-only response without local cache retries full. +- Not-modified response returns cached materialized data. +- `version=0` bypasses cache and requests full response. +- Materialized cache invalidates after new patches. +- Existing `deserializer=None` behavior still returns base object only. 
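As a sketch of the not-modified case, assuming the monkeypatch seam the existing SDK tests already use (names and internals are illustrative):

```python
def test_not_modified_returns_cached_materialized(monkeypatch):
    from flamepy.core import cache as cache_module

    ref = cache_module.ObjectRef(
        endpoint="grpc://localhost:9090", key="app/session/obj", version=3
    )
    with cache_module._cache_lock:
        cache_module._object_cache[(ref.endpoint, ref.key)] = cache_module.Object(
            version=3, data={"n": 1}
        )

    # A None fetch result is the SDK's internal "not modified" signal.
    monkeypatch.setattr(cache_module, "_fetch_object_data", lambda ref, version: None)

    assert cache_module.get_object(ref) == {"n": 1}
```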
#### E2E Tests

- Put base, patch twice, read from one Python process twice, verify the second read fetches only the second patch.
- Put base, read, update base, read again, verify full response and correct final data.
- Replay buffer smoke test with metrics enabled verifies full/patch/not-modified counters are internally consistent.

## 4. Use Cases

### Basic Use Cases

#### Example 1: Repeated Read with New Patches

1. Buffer service reads replay buffer at version `10`; flamepy caches the base row and patches through `10`.
2. Collectors append patches `11`, `12`, and `13`.
3. Buffer service calls `get_object` again.
4. flamepy sends `<key>:10`.
5. Server sees `base.version=1`, current version `13`, and returns only patches `11..13`.
6. flamepy appends those patches locally and materializes the replay buffer.

Expected outcome: network transfer is proportional to the three new patch batches, not the entire replay buffer history.

#### Example 2: Base Replaced

1. Client caches object at version `10`.
2. Another actor calls `update_object`, creating a new base at version `11`.
3. Client calls `get_object`.
4. Server sees `base.version=11 > client_version=10`.
5. Server returns full base plus any patches after version `11`.

Expected outcome: the client never applies patches to an obsolete base.

#### Example 3: Not Modified

1. Client caches current version `20`.
2. Client calls `get_object` before any mutation.
3. Server returns empty stream.
4. flamepy returns the cached materialized value.

Expected outcome: same behavior as current RFE426 cache hits.

### Advanced Use Cases

#### Replay Buffer Performance Evaluation

**Metrics:**

Primary read-path metrics:

- Total bytes downloaded by `get_object`.
- `get_object` latency for `ReplayBuffer.state()` and `ReplayBuffer.sample()`.
- Deserializer/materialization CPU time inside `_fetch()`.
- Base size, patch count, and patch rows downloaded per read.
- Number of full, patch-only, and not-modified responses.

Write-path metrics:

- `ReplayBuffer.push()` latency, including `patch_object`.
- Bytes uploaded per collector patch.
- Patch failure count.
- Server `delta_count` before and after collection.

End-to-end metrics:

- Total runtime.
- Transitions per second.
- Collection latency per iteration.
- Failed collector calls.
- Merge latency when `ReplayBuffer.merge()` runs.

**Benchmark Method:**

Primary comparison: run the same replay-buffer workload twice on the same cluster and commit:

```shell
uv run main.py --force-full-get --iterations 50 --collections 20 --steps-per-collection 500 --batch-size 64
uv run main.py --iterations 50 --collections 20 --steps-per-collection 500 --batch-size 64
```

The first run forces request version `0` on replay-buffer reads. The second run uses normal nonzero cached versions after the first full fetch. This isolates the proposed `get_object` behavior without a runtime switch.

Add a replay-buffer benchmark mode before collecting final numbers:

- `--metrics-json <path>`: write per-iteration metrics.
- `--merge-every <N>`: make the current hard-coded merge interval configurable.
- `--no-merge`: stress the patch-only path by keeping a long patch history.
- `--force-full-get`: force replay-buffer reads to send request version `0` for baseline measurement.

Run at least these cases:

| Case | Purpose |
|------|---------|
| Default merge every 5 iterations | Measures realistic current example behavior. |
| +| No merge | Maximizes accumulated patches and should show largest read-path improvement. | +| Merge every iteration | Confirms no regression when full responses are naturally small. | + +Secondary comparison: keep the collector work and iteration shape identical, then compare the patch-based replay buffer with a controlled non-patch variant: + +- Service-only push path: collectors send transitions to the `ReplayBuffer` service, which owns an in-memory list. +- Whole-object update path: writers fetch, append, and `update_object` the complete replay buffer. + +This secondary comparison is not required to prove version-driven incremental get, but it shows where patch-based replay-buffer design sits against simpler alternatives. + +**Success Criteria:** + +- Patch-heavy replay-buffer reads download at least 70% fewer bytes after the first cached read. +- Median `state()`/`sample()` read latency improves by at least 30% in the no-merge case. +- End-to-end transitions/sec improves in the no-merge case or remains within 5% in cases dominated by environment stepping. +- `patch_object` write latency does not regress by more than 5%. +- No correctness difference in final `size`, `total_added`, sampled batch sizes, failed collector count, or episode counts. + +## 5. References + +**Related Documents:** + +- GitHub issue: https://github.com/xflops/flame/issues/445 +- `docs/designs/RFE318-cache/FS.md` +- `docs/designs/RFE426-cache-versioning/FS.md` +- `docs/designs/RFE423-app-cache-key/FS.md` + +**External References:** + +- Apache Arrow Flight protocol: https://arrow.apache.org/docs/format/Flight.html +- PyArrow Flight API: https://arrow.apache.org/docs/python/api/flight.html + +**Implementation References:** + +- `object_cache/src/cache.rs` +- `object_cache/src/storage/disk.rs` +- `sdk/python/src/flamepy/core/cache.py` +- `sdk/python/tests/test_cache.py` +- `examples/rl/replay_buffer/replay_buffer.py` +- `examples/rl/replay_buffer/main.py` +- `examples/rl/replay_buffer/README.md` diff --git a/e2e/tests/test_cache.py b/e2e/tests/test_cache.py index f7866c0b..ff56e210 100644 --- a/e2e/tests/test_cache.py +++ b/e2e/tests/test_cache.py @@ -14,6 +14,8 @@ import time import uuid +import flamepy.core.cache as cache_module +import pyarrow.flight as flight import pytest from flamepy.core import FlameContext, ObjectRef, get_object, patch_object, put_object, update_object @@ -98,6 +100,27 @@ def _raw_deserializer(base, deltas): return {"base": base, "deltas": deltas} +def _remote_patch_without_local_cache_invalidation(ref: ObjectRef, delta): + """Patch through Flight directly to emulate another client process.""" + batch = cache_module._serialize_object(delta) + client = cache_module._get_flight_client(ref.endpoint, cache_module._get_cache_tls_config()) + descriptor = flight.FlightDescriptor.for_command(f"PATCH:{ref.key}".encode()) + return cache_module._do_put_remote(client, descriptor, batch) + + +def _remote_update_without_local_cache_invalidation(ref: ObjectRef, new_obj): + """Update through Flight directly to emulate another client process.""" + batch = cache_module._serialize_object(new_obj) + client = cache_module._get_flight_client(ref.endpoint, cache_module._get_cache_tls_config()) + descriptor = flight.FlightDescriptor.for_path(ref.key) + return cache_module._do_put_remote(client, descriptor, batch) + + +def _cached_object(ref: ObjectRef): + with cache_module._cache_lock: + return cache_module._object_cache[(ref.endpoint, ref.key)] + + def test_patch_single_delta(): """Test patching an object with 
a single delta.""" key_prefix = "test-app/test-patch-001" @@ -141,6 +164,78 @@ def test_patch_multiple_deltas(): assert result["deltas"][2] == delta3 +def test_incremental_get_applies_remote_patch_only_response(): + """Test a cached client applies only patches appended by another client.""" + key_prefix = f"test-app/test-incremental-patch-{uuid.uuid4().hex[:8]}" + base_data = {"items": ["base"]} + delta_data = {"items": ["patch-1"]} + + ref = put_object(key_prefix, base_data) + assert get_object(ref, deserializer=_raw_deserializer) == {"base": base_data, "deltas": []} + + patched_ref = _remote_patch_without_local_cache_invalidation(ref, delta_data) + fetch_result = cache_module._fetch_object_data(ref, ref.version) + + assert fetch_result.mode == cache_module.FetchMode.PATCHES + assert fetch_result.version == patched_ref.version + assert [patch.data for patch in fetch_result.patches] == [delta_data] + + result = get_object(ref, deserializer=_raw_deserializer) + assert result == {"base": base_data, "deltas": [delta_data]} + + cached = _cached_object(ref) + assert cached.version == patched_ref.version + assert [patch.data for patch in cached.patches] == [delta_data] + + +def test_version_zero_forces_full_response_with_cached_object(): + """Test version=0 gets the full base plus patches even with a local cache.""" + key_prefix = f"test-app/test-incremental-full-{uuid.uuid4().hex[:8]}" + base_data = {"items": ["base"]} + delta_data = {"items": ["patch-1"]} + + ref = put_object(key_prefix, base_data) + assert get_object(ref, deserializer=_raw_deserializer) == {"base": base_data, "deltas": []} + patched_ref = _remote_patch_without_local_cache_invalidation(ref, delta_data) + + forced_ref = ObjectRef(endpoint=ref.endpoint, key=ref.key, version=0) + fetch_result = cache_module._fetch_object_data(forced_ref, 0) + + assert fetch_result.mode == cache_module.FetchMode.FULL + assert fetch_result.version == patched_ref.version + assert fetch_result.base == base_data + assert [patch.data for patch in fetch_result.patches] == [delta_data] + + result = get_object(forced_ref, deserializer=_raw_deserializer) + assert result == {"base": base_data, "deltas": [delta_data]} + + +def test_incremental_get_falls_back_to_full_after_remote_update(): + """Test stale cached base is replaced by a full response after update.""" + key_prefix = f"test-app/test-incremental-update-{uuid.uuid4().hex[:8]}" + base_data = {"version": 1} + updated_data = {"version": 2} + + ref = put_object(key_prefix, base_data) + assert get_object(ref, deserializer=_raw_deserializer) == {"base": base_data, "deltas": []} + + updated_ref = _remote_update_without_local_cache_invalidation(ref, updated_data) + fetch_result = cache_module._fetch_object_data(ref, ref.version) + + assert fetch_result.mode == cache_module.FetchMode.FULL + assert fetch_result.version == updated_ref.version + assert fetch_result.base == updated_data + assert fetch_result.patches == [] + + result = get_object(ref, deserializer=_raw_deserializer) + assert result == {"base": updated_data, "deltas": []} + + cached = _cached_object(ref) + assert cached.version == updated_ref.version + assert cached.data == updated_data + assert cached.patches == [] + + def test_patch_preserves_delta_order(): """Test that deltas are returned in the order they were appended.""" key_prefix = "test-app/test-patch-003" diff --git a/examples/rl/replay_buffer/main.py b/examples/rl/replay_buffer/main.py index 86a24a71..af711ba0 100644 --- a/examples/rl/replay_buffer/main.py +++ 
b/examples/rl/replay_buffer/main.py @@ -17,7 +17,11 @@ def run_distributed( num_collections: int = 20, steps_per_collection: int = 500, batch_size: int = 64, + merge_every: int | None = 5, + metrics_json: str | None = None, + force_full_get: bool = False, ): + import json import time from flamepy.runner import Runner @@ -31,30 +35,43 @@ def run_distributed( print(f" Steps per collection: {steps_per_collection}") print(f" Iterations: {num_iterations}") print(f" Batch size: {batch_size}") + print(f" Merge every: {merge_every if merge_every else 'disabled'}") + print(f" Force full get: {force_full_get}") print("\nStarting distributed collection...") start_time = time.time() + metrics = [] + total_added = 0 with Runner(f"replay-buffer-{env_name.lower()}") as rr: - buffer = ReplayBuffer(rr) + buffer = ReplayBuffer(rr, force_full_get=force_full_get) buffer_svc = rr.service(buffer, autoscale=False, warmup=1) collector = rr.service(Collector(env_name), autoscale=True) for iteration in range(num_iterations): + iteration_start = time.time() collect_futures = [ collector.collect(buffer, steps_per_collection) for _ in range(num_collections) ] collect_results = rr.get(collect_futures) + collect_elapsed = time.time() - iteration_start - if iteration % 5 == 4: + merge_elapsed = 0.0 + if merge_every and iteration % merge_every == merge_every - 1: + merge_start = time.time() buffer_svc.merge().wait() + merge_elapsed = time.time() - merge_start + state_start = time.time() stats = buffer_svc.state().get() + state_elapsed = time.time() - state_start total_size = stats["size"] total_added = stats["total_added"] total_episodes = sum(r["episode_count"] for r in collect_results) - avg_reward = sum(r["avg_reward"] * r["episode_count"] for r in collect_results) / max(1, total_episodes) + avg_reward = sum( + r["avg_reward"] * r["episode_count"] for r in collect_results + ) / max(1, total_episodes) print( f"Iteration {iteration:2d} | " @@ -63,10 +80,30 @@ def run_distributed( f"Avg Reward: {avg_reward:7.1f}" ) + sample_elapsed = 0.0 + sampled = 0 if total_size >= batch_size: + sample_start = time.time() batch = buffer_svc.sample(batch_size).get() + sample_elapsed = time.time() - sample_start + sampled = len(batch) print(f" | Sampled batch of {len(batch)} transitions") + metrics.append( + { + "iteration": iteration, + "collect_secs": collect_elapsed, + "merge_secs": merge_elapsed, + "state_secs": state_elapsed, + "sample_secs": sample_elapsed, + "buffer_size": total_size, + "total_added": total_added, + "total_episodes": total_episodes, + "avg_reward": avg_reward, + "sampled": sampled, + } + ) + elapsed = time.time() - start_time print("\n" + "=" * 60) print("Collection Complete!") @@ -75,6 +112,30 @@ def run_distributed( print(f" Throughput: {total_added / elapsed:.1f} transitions/sec") print("=" * 60) + if metrics_json: + with open(metrics_json, "w") as f: + json.dump( + { + "configuration": { + "env": env_name, + "iterations": num_iterations, + "collections": num_collections, + "steps_per_collection": steps_per_collection, + "batch_size": batch_size, + "merge_every": merge_every, + "force_full_get": force_full_get, + }, + "summary": { + "total_time_secs": elapsed, + "total_transitions": total_added, + "throughput": total_added / elapsed, + }, + "iterations": metrics, + }, + f, + indent=2, + ) + def run_local( env_name: str = "CartPole-v1", @@ -187,8 +248,31 @@ def main(): parser.add_argument( "--batch-size", type=int, default=64, help="Batch size for sampling" ) + parser.add_argument( + "--metrics-json", + 
type=str, + default=None, + help="Write distributed-mode metrics to a JSON file", + ) + parser.add_argument( + "--merge-every", + type=int, + default=5, + help="Merge replay-buffer patches every N iterations", + ) + parser.add_argument( + "--no-merge", + action="store_true", + help="Disable replay-buffer patch merging", + ) + parser.add_argument( + "--force-full-get", + action="store_true", + help="Force replay-buffer reads to request full objects with version 0", + ) args = parser.parse_args() + merge_every = None if args.no_merge else args.merge_every if args.local: run_local( @@ -204,6 +288,9 @@ def main(): num_collections=args.collections, steps_per_collection=args.steps_per_collection, batch_size=args.batch_size, + merge_every=merge_every, + metrics_json=args.metrics_json, + force_full_get=args.force_full_get, ) diff --git a/examples/rl/replay_buffer/replay_buffer.py b/examples/rl/replay_buffer/replay_buffer.py index ebecb46b..efe35de3 100644 --- a/examples/rl/replay_buffer/replay_buffer.py +++ b/examples/rl/replay_buffer/replay_buffer.py @@ -6,10 +6,12 @@ class ReplayBuffer: - def __init__(self, rr: "Runner"): - from flamepy.core import get_object, patch_object, update_object + def __init__(self, rr: "Runner", force_full_get: bool = False): + from flamepy.core import ObjectRef, get_object, patch_object, update_object self.buffer_ref = rr.put_object({"transitions": [], "total_added": 0}) + self.force_full_get = force_full_get + self._object_ref = ObjectRef self._get_object = get_object self._update_object = update_object self._patch_object = patch_object @@ -24,7 +26,10 @@ def _deserializer(self, base: dict, deltas: List) -> dict: } def _fetch(self) -> dict: - return self._get_object(self.buffer_ref, deserializer=self._deserializer) + ref = self.buffer_ref + if self.force_full_get: + ref = self._object_ref(endpoint=ref.endpoint, key=ref.key, version=0) + return self._get_object(ref, deserializer=self._deserializer) def push(self, transitions: List[dict]) -> None: self._patch_object(self.buffer_ref, transitions) diff --git a/object_cache/src/cache.rs b/object_cache/src/cache.rs index 2da63afa..e65a7338 100644 --- a/object_cache/src/cache.rs +++ b/object_cache/src/cache.rs @@ -15,7 +15,7 @@ use std::collections::HashMap; use std::pin::Pin; use std::sync::Arc; -use arrow::array::{BinaryArray, RecordBatch, UInt64Array}; +use arrow::array::{BinaryArray, RecordBatch, StringArray, UInt64Array}; use arrow::compute::concat_batches; use arrow::datatypes::{DataType, Field, Schema}; use arrow::ipc::writer::{ @@ -237,6 +237,15 @@ impl Object { deltas, } } + + pub fn current_version(&self) -> u64 { + self.deltas + .iter() + .map(|delta| delta.version) + .max() + .unwrap_or(self.version) + .max(self.version) + } } #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] @@ -368,7 +377,7 @@ impl ObjectCache { for (key, object) in items { let key_str = key.to_key().expect("loaded key must have object_id"); let size = object.data.len() as u64; - let version = object.version; + let version = object.current_version(); let delta_count = object.deltas.len() as u64; let meta = self.create_metadata(key_str.clone(), version, size, delta_count); @@ -454,7 +463,7 @@ impl ObjectCache { .storage .read_object(&key) .await? - .map(|obj| obj.version) + .map(|obj| obj.current_version()) .unwrap_or(0), }; let new_version = current_version + 1; @@ -499,7 +508,7 @@ impl ObjectCache { if let Some(object) = self.storage.read_object(key).await? 
{
             let size = object.data.len() as u64;
             let delta_count = object.deltas.len() as u64;
-            let version = object.version;
+            let version = object.current_version();
 
             {
                 let mut objects = lock_ptr!(self.objects)?;
@@ -546,14 +555,15 @@ impl ObjectCache {
                 .storage
                 .read_object(key)
                 .await?
-                .map(|obj| obj.version)
+                .map(|obj| obj.current_version())
                 .ok_or_else(|| {
                     FlameError::NotFound(format!("object <{}> not found for patch", key_str))
                 })?,
         };
         let new_version = current_version + 1;
 
-        let mut meta = self.storage.patch_object(key, &delta).await?;
+        let versioned_delta = Object::new(new_version, delta.data);
+        let mut meta = self.storage.patch_object(key, &versioned_delta).await?;
         meta.endpoint = self.endpoint.to_uri();
         meta.version = new_version;
 
@@ -807,15 +817,46 @@ fn encode_schema(schema: &Schema) -> Result<Vec<u8>, FlameError> {
     Ok(encoded.ipc_message)
 }
 
+#[cfg(test)]
 fn get_object_schema() -> Schema {
     Schema::new(vec![
-        Field::new("version", DataType::UInt64, false),
-        Field::new("data", DataType::Binary, false),
+        Field::new(OBJECT_RESPONSE_FIELD_VERSION, DataType::UInt64, false),
+        Field::new(OBJECT_RESPONSE_FIELD_DATA, DataType::Binary, false),
     ])
 }
 
+const OBJECT_RESPONSE_FIELD_VERSION: &str = "version";
+const OBJECT_RESPONSE_FIELD_KIND: &str = "kind";
+const OBJECT_RESPONSE_FIELD_DATA: &str = "data";
+const OBJECT_RESPONSE_KIND_BASE: &str = "base";
+const OBJECT_RESPONSE_KIND_PATCH: &str = "patch";
+
+fn get_object_response_schema() -> Schema {
+    Schema::new(vec![
+        Field::new(OBJECT_RESPONSE_FIELD_VERSION, DataType::UInt64, false),
+        Field::new(OBJECT_RESPONSE_FIELD_KIND, DataType::Utf8, false),
+        Field::new(OBJECT_RESPONSE_FIELD_DATA, DataType::Binary, false),
+    ])
+}
+
+#[derive(Debug, Clone, Copy)]
+enum ObjectResponseKind {
+    Base,
+    Patch,
+}
+
+impl ObjectResponseKind {
+    fn as_str(self) -> &'static str {
+        match self {
+            Self::Base => OBJECT_RESPONSE_KIND_BASE,
+            Self::Patch => OBJECT_RESPONSE_KIND_PATCH,
+        }
+    }
+}
+
 // Helper function to create a RecordBatch from object data
 // Note: Only serializes version and data; deltas are stored separately
+#[cfg(test)]
 fn object_to_batch(object: &Object) -> Result<RecordBatch, FlameError> {
     let schema = get_object_schema();
 
@@ -856,7 +897,7 @@ fn batch_to_object(batch: &RecordBatch) -> Result<Object, FlameError> {
 }
 
 fn create_empty_flight_data() -> Result<Vec<FlightData>, FlameError> {
-    let schema = get_object_schema();
+    let schema = get_object_response_schema();
     let options = IpcWriteOptions::default();
     let data_gen = IpcDataGenerator::default();
     let mut dict_tracker = DictionaryTracker::new(false);
@@ -872,10 +913,52 @@ fn create_empty_flight_data() -> Result<Vec<FlightData>, FlameError> {
     }])
 }
 
+fn object_to_response_batch(
+    object: &Object,
+    kind: ObjectResponseKind,
+) -> Result<RecordBatch, FlameError> {
+    let schema = get_object_response_schema();
+
+    let version_array = UInt64Array::from(vec![object.version]);
+    let kind_array = StringArray::from(vec![kind.as_str()]);
+    let data_array = BinaryArray::from(vec![object.data.as_slice()]);
+
+    RecordBatch::try_new(
+        Arc::new(schema),
+        vec![
+            Arc::new(version_array),
+            Arc::new(kind_array),
+            Arc::new(data_array),
+        ],
+    )
+    .map_err(|e| FlameError::Internal(format!("Failed to create response RecordBatch: {}", e)))
+}
+
 /// Convert Object (with deltas) to FlightData stream
 /// Sends schema once, followed by base batch, then delta batches
 /// Uses ZSTD compression for ~54% faster encoding (Arrow 58+)
 fn object_to_flight_data_vec(obj: &Object) -> Result<Vec<FlightData>, FlameError> {
+    let mut rows = Vec::with_capacity(obj.deltas.len() + 1);
+    rows.push((ObjectResponseKind::Base, obj));
+    rows.extend(
+        obj.deltas
+            .iter()
+            .map(|delta| (ObjectResponseKind::Patch, delta)),
+    );
+    object_rows_to_flight_data_vec(rows)
+}
+
+fn object_patches_to_flight_data_vec(patches: &[Object]) -> Result<Vec<FlightData>, FlameError> {
+    let rows = patches
+        .iter()
+        .map(|delta| (ObjectResponseKind::Patch, delta))
+        .collect();
+    object_rows_to_flight_data_vec(rows)
+}
+
+fn object_rows_to_flight_data_vec(
+    rows: Vec<(ObjectResponseKind, &Object)>,
+) -> Result<Vec<FlightData>, FlameError> {
     let options = IpcWriteOptions::default()
         .try_with_compression(Some(CompressionType::ZSTD))
         .map_err(|e| FlameError::Internal(format!("Failed to set compression: {}", e)))?;
@@ -884,8 +967,7 @@ fn object_to_flight_data_vec(obj: &Object) -> Result<Vec<FlightData>, FlameError
     let mut dict_tracker = DictionaryTracker::new(false);
     let mut compression_ctx = CompressionContext::default();
 
-    let base_batch = object_to_batch(obj)?;
-    let schema = base_batch.schema();
+    let schema = Arc::new(get_object_response_schema());
 
     let mut all_flight_data = Vec::new();
 
@@ -901,21 +983,8 @@ fn object_to_flight_data_vec(obj: &Object) -> Result<Vec<FlightData>, FlameError
         data_body: vec![].into(),
     });
 
-    let (encoded_dicts, encoded_batch) = data_gen
-        .encode(
-            &base_batch,
-            &mut dict_tracker,
-            &options,
-            &mut compression_ctx,
-        )
-        .map_err(|e| FlameError::Internal(format!("Failed to encode base batch: {}", e)))?;
-    for dict_batch in encoded_dicts {
-        all_flight_data.push(dict_batch.into());
-    }
-    all_flight_data.push(encoded_batch.into());
-
-    for delta in &obj.deltas {
-        let delta_batch = object_to_batch(delta)?;
+    for (kind, object) in rows {
+        let delta_batch = object_to_response_batch(object, kind)?;
         let (encoded_dicts, encoded_batch) = data_gen
             .encode(
                 &delta_batch,
@@ -923,7 +992,7 @@ fn object_to_flight_data_vec(obj: &Object) -> Result<Vec<FlightData>, FlameError
                 &options,
                 &mut compression_ctx,
             )
-            .map_err(|e| FlameError::Internal(format!("Failed to encode delta batch: {}", e)))?;
+            .map_err(|e| FlameError::Internal(format!("Failed to encode response batch: {}", e)))?;
         for dict_batch in encoded_dicts {
             all_flight_data.push(dict_batch.into());
         }
@@ -1010,19 +1079,21 @@ impl FlightService for FlightCacheServer {
 
         let key = ObjectKey::try_from(key_str.as_str())?;
 
-        let (server_version, metadata_keys) = {
-            let metadata = lock_ptr!(self.cache.metadata)
-                .map_err(|e| Status::internal(format!("Lock error: {}", e)))?;
-            let version = metadata.get(&key_str).map(|m| m.version).unwrap_or(0);
-            let keys: Vec<String> = metadata.keys().cloned().collect();
-            (version, keys)
-        };
+        let write_lock = self
+            .cache
+            .get_write_lock(&key_str)
+            .map_err(|e| Status::internal(format!("Lock error: {}", e)))?;
+        let _guard = write_lock.lock().await;
+
+        let object = self.cache.get(&key).await?;
+        let server_version = object.current_version();
 
         tracing::debug!(
-            "do_get: key={}, server_version={}, metadata_keys_count={}",
+            "do_get: key={}, server_version={}, base_version={}, delta_count={}",
             key_str,
             server_version,
-            metadata_keys.len()
+            object.version,
+            object.deltas.len()
         );
 
         if client_version != 0 && server_version == client_version {
@@ -1036,7 +1107,47 @@ impl FlightService for FlightCacheServer {
             return Ok(Response::new(Box::pin(stream)));
         }
 
-        let object = self.cache.get(&key).await?;
+        if client_version > server_version {
+            tracing::warn!(
+                "do_get: key={}, client_version={} is greater than server_version={}, returning full object",
+                key_str,
+                client_version,
+                server_version
+            );
+        } else if client_version != 0 && object.version <= client_version {
+            let needed_patches: Vec<Object> = object
+                .deltas
+                .iter()
+                .filter(|delta| delta.version > client_version)
+                .cloned()
+                .collect();
+            let expected_patch_count = server_version.saturating_sub(client_version) as usize;
+            let patch_suffix_is_contiguous = needed_patches.len() == expected_patch_count
+                && needed_patches
+                    .iter()
+                    .enumerate()
+                    .all(|(idx, delta)| delta.version == client_version + idx as u64 + 1);
+
+            if patch_suffix_is_contiguous {
+                tracing::debug!(
+                    "do_get: key={}, patch_only_count={}, client_version={}, server_version={}",
+                    key_str,
+                    needed_patches.len(),
+                    client_version,
+                    server_version
+                );
+                let flight_data_vec = object_patches_to_flight_data_vec(&needed_patches)?;
+                let stream = futures::stream::iter(flight_data_vec.into_iter().map(Ok));
+                return Ok(Response::new(Box::pin(stream)));
+            }
+
+            tracing::debug!(
+                "do_get: key={}, patch suffix unavailable (client_version={}, server_version={}), returning full object",
+                key_str,
+                client_version,
+                server_version
+            );
+        }
 
         tracing::debug!(
             "do_get: key={}, base_size={}, delta_count={}",
@@ -1150,7 +1261,7 @@ impl FlightService for FlightCacheServer {
         &self,
         _request: Request<FlightDescriptor>,
     ) -> Result<Response<SchemaResult>, Status> {
-        let schema = get_object_schema();
+        let schema = get_object_response_schema();
 
         let schema_result = SchemaResult {
             schema: Bytes::from(encode_schema(&schema)?),
@@ -1663,6 +1774,212 @@ mod tests {
         }
     }
 
+    mod versioned_get {
+        use super::*;
+        use futures::StreamExt;
+        use std::path::Path;
+        use tempfile::tempdir;
+
+        async fn create_disk_test_server() -> (FlightCacheServer, tempfile::TempDir) {
+            let endpoint = test_endpoint();
+            let temp = tempdir().unwrap();
+            let storage =
+                Box::new(crate::storage::DiskStorage::new(temp.path().to_path_buf()).unwrap());
+            let cache = Arc::new(ObjectCache::new(endpoint, storage, None).unwrap());
+            (FlightCacheServer::new(cache), temp)
+        }
+
+        fn test_endpoint() -> CacheEndpoint {
+            CacheEndpoint {
+                scheme: "grpc".to_string(),
+                host: "localhost".to_string(),
+                port: 9090,
+            }
+        }
+
+        async fn create_disk_test_server_from_path(path: &Path) -> FlightCacheServer {
+            let storage = Box::new(crate::storage::DiskStorage::new(path.to_path_buf()).unwrap());
+            let cache = Arc::new(ObjectCache::new(test_endpoint(), storage, None).unwrap());
+            cache.load_from_storage().await.unwrap();
+            FlightCacheServer::new(cache)
+        }
+
+        async fn put_and_patch(server: &FlightCacheServer) -> ObjectMetadata {
+            let key = ObjectKey::from_path("app/session").unwrap();
+            let meta = server
+                .cache
+                .put(key, Object::new(0, b"base".to_vec()))
+                .await
+                .unwrap();
+            let key = ObjectKey::try_from(meta.key.as_str()).unwrap();
+            server
+                .cache
+                .patch(&key, Object::new(0, b"patch-1".to_vec()))
+                .await
+                .unwrap();
+            server
+                .cache
+                .patch(&key, Object::new(0, b"patch-2".to_vec()))
+                .await
+                .unwrap()
+        }
+
+        async fn get_batches(server: &FlightCacheServer, ticket: &str) -> Vec<RecordBatch> {
+            let response = server
+                .do_get(Request::new(Ticket {
+                    ticket: Bytes::from(ticket.as_bytes().to_vec()),
+                }))
+                .await
+                .unwrap();
+            let mut stream = response.into_inner();
+            let mut schema: Option<Arc<Schema>> = None;
+            let mut batches = Vec::new();
+
+            while let Some(item) = stream.next().await {
+                let flight_data = item.unwrap();
+                if schema.is_none() && !flight_data.data_header.is_empty() {
+                    schema = Some(
+                        FlightCacheServer::extract_schema_from_flight_data(&flight_data).unwrap(),
+                    );
+                }
+                if !flight_data.data_body.is_empty() {
+                    let schema_ref = schema.as_ref().unwrap();
+                    batches.push(
+                        FlightCacheServer::decode_batch_from_flight_data(&flight_data, schema_ref)
+                            .unwrap(),
+                    );
+                }
+            }
+
+            batches
+        }
+
+        fn row_kind(batch: &RecordBatch) -> String {
+            batch
+                .column_by_name(OBJECT_RESPONSE_FIELD_KIND)
+                .unwrap()
+                .as_any()
+                .downcast_ref::<StringArray>()
+                .unwrap()
+                .value(0)
+                .to_string()
+        }
+
+        fn row_version(batch: &RecordBatch) -> u64 {
+            batch
+                .column_by_name(OBJECT_RESPONSE_FIELD_VERSION)
+                .unwrap()
+                .as_any()
+                .downcast_ref::<UInt64Array>()
+                .unwrap()
+                .value(0)
+        }
+
+        fn row_data(batch: &RecordBatch) -> Vec<u8> {
+            batch
+                .column_by_name(OBJECT_RESPONSE_FIELD_DATA)
+                .unwrap()
+                .as_any()
+                .downcast_ref::<BinaryArray>()
+                .unwrap()
+                .value(0)
+                .to_vec()
+        }
+
+        #[tokio::test]
+        async fn client_version_zero_returns_full_response() {
+            let (server, _temp) = create_disk_test_server().await;
+            let meta = put_and_patch(&server).await;
+
+            let batches = get_batches(&server, &format!("{}:0", meta.key)).await;
+
+            assert_eq!(batches.len(), 3);
+            assert_eq!(row_kind(&batches[0]), ObjectResponseKind::Base.as_str());
+            assert_eq!(row_kind(&batches[1]), ObjectResponseKind::Patch.as_str());
+            assert_eq!(row_kind(&batches[2]), ObjectResponseKind::Patch.as_str());
+            assert_eq!(row_version(&batches[0]), 1);
+            assert_eq!(row_version(&batches[1]), 2);
+            assert_eq!(row_version(&batches[2]), 3);
+        }
+
+        #[tokio::test]
+        async fn stale_client_version_returns_patch_only_response() {
+            let (server, _temp) = create_disk_test_server().await;
+            let meta = put_and_patch(&server).await;
+
+            let batches = get_batches(&server, &format!("{}:1", meta.key)).await;
+
+            assert_eq!(batches.len(), 2);
+            assert_eq!(row_kind(&batches[0]), ObjectResponseKind::Patch.as_str());
+            assert_eq!(row_kind(&batches[1]), ObjectResponseKind::Patch.as_str());
+            assert_eq!(row_version(&batches[0]), 2);
+            assert_eq!(row_version(&batches[1]), 3);
+        }
+
+        #[tokio::test]
+        async fn stale_client_version_after_reload_returns_patch_only_response() {
+            let (server, temp) = create_disk_test_server().await;
+            let meta = put_and_patch(&server).await;
+            let reloaded_server = create_disk_test_server_from_path(temp.path()).await;
+
+            let batches = get_batches(&reloaded_server, &format!("{}:1", meta.key)).await;
+
+            assert_eq!(batches.len(), 2);
+            assert_eq!(row_kind(&batches[0]), ObjectResponseKind::Patch.as_str());
+            assert_eq!(row_kind(&batches[1]), ObjectResponseKind::Patch.as_str());
+            assert_eq!(row_version(&batches[0]), 2);
+            assert_eq!(row_version(&batches[1]), 3);
+            assert_eq!(row_data(&batches[0]), b"patch-1".to_vec());
+            assert_eq!(row_data(&batches[1]), b"patch-2".to_vec());
+        }
+
+        #[tokio::test]
+        async fn client_version_before_updated_base_returns_full_response() {
+            let (server, _temp) = create_disk_test_server().await;
+            let meta = put_and_patch(&server).await;
+            let key = ObjectKey::try_from(meta.key.as_str()).unwrap();
+            let updated_meta = server
+                .cache
+                .put(key, Object::new(0, b"updated-base".to_vec()))
+                .await
+                .unwrap();
+
+            let batches = get_batches(&server, &format!("{}:1", meta.key)).await;
+
+            assert_eq!(updated_meta.version, 4);
+            assert_eq!(batches.len(), 1);
+            assert_eq!(row_kind(&batches[0]), ObjectResponseKind::Base.as_str());
+            assert_eq!(row_version(&batches[0]), 4);
+            assert_eq!(row_data(&batches[0]), b"updated-base".to_vec());
+        }
+
+        #[tokio::test]
+        async fn client_version_ahead_of_server_returns_full_response() {
+            let (server, _temp) = create_disk_test_server().await;
+            let meta = put_and_patch(&server).await;
+
+            let batches = get_batches(&server, &format!("{}:99", meta.key)).await;
+
+            assert_eq!(batches.len(), 3);
+            assert_eq!(row_kind(&batches[0]), ObjectResponseKind::Base.as_str());
+            assert_eq!(row_kind(&batches[1]), ObjectResponseKind::Patch.as_str());
+            assert_eq!(row_kind(&batches[2]), ObjectResponseKind::Patch.as_str());
+            assert_eq!(row_version(&batches[0]), 1);
+            assert_eq!(row_version(&batches[1]), 2);
+            assert_eq!(row_version(&batches[2]), 3);
+        }
+
+        #[tokio::test]
+        async fn matching_client_version_returns_empty_response() {
+            let (server, _temp) = create_disk_test_server().await;
+            let meta = put_and_patch(&server).await;
+
+            let batches = get_batches(&server, &format!("{}:{}", meta.key, meta.version)).await;
+
+            assert!(batches.is_empty());
+        }
+    }
+
     mod flight_data_conversion {
         use super::*;
 
diff --git a/object_cache/src/storage/disk.rs b/object_cache/src/storage/disk.rs
index 0d9131a1..e574bcff 100644
--- a/object_cache/src/storage/disk.rs
+++ b/object_cache/src/storage/disk.rs
@@ -99,7 +99,7 @@ impl StorageEngine for DiskStorage {
                 return Ok(None);
             }
             let base = load_object_from_file(&object_path)?;
-            let deltas = read_deltas_sync(&delta_dir)?;
+            let deltas = read_deltas_sync(&delta_dir, base.version)?;
             Ok(Some(Object::with_deltas(base.version, base.data, deltas)))
         })
         .await
@@ -261,7 +261,7 @@ impl StorageEngine for DiskStorage {
                 let delta_dir = session_path.join(format!("{}.deltas", object_id));
 
                 let base = load_object_from_file(&object_path)?;
-                let deltas = read_deltas_sync(&delta_dir)?;
+                let deltas = read_deltas_sync(&delta_dir, base.version)?;
                 let object = Object::with_deltas(base.version, base.data, deltas);
 
                 results.push((key, object));
@@ -299,7 +299,7 @@ fn count_deltas_sync(delta_dir: &Path) -> u64 {
         .unwrap_or(0)
 }
 
-fn read_deltas_sync(delta_dir: &Path) -> Result<Vec<Object>, FlameError> {
+fn read_deltas_sync(delta_dir: &Path, base_version: u64) -> Result<Vec<Object>, FlameError> {
     if !delta_dir.exists() {
         return Ok(Vec::new());
     }
@@ -317,10 +317,18 @@ fn read_deltas_sync(delta_dir: &Path) -> Result<Vec<Object>, FlameError> {
             .unwrap_or(u64::MAX)
     });
 
-    delta_files
+    let mut deltas: Vec<Object> = delta_files
         .into_par_iter()
         .map(|entry| load_object_from_file(&entry.path()))
-        .collect()
+        .collect::<Result<Vec<_>, _>>()?;
+
+    for (idx, delta) in deltas.iter_mut().enumerate() {
+        if delta.version == 0 {
+            delta.version = base_version + idx as u64 + 1;
+        }
+    }
+
+    Ok(deltas)
 }
 
 fn write_batch_to_file(path: &Path, batch: &RecordBatch) -> Result<(), FlameError> {
@@ -442,15 +450,40 @@ mod tests {
         let object = Object::new(1, vec![1, 2, 3]);
         storage.write_object(&key, &object).await.unwrap();
 
-        let delta = Object::new(0, vec![4, 5, 6]);
+        let delta = Object::new(2, vec![4, 5, 6]);
         let meta = storage.patch_object(&key, &delta).await.unwrap();
         assert_eq!(meta.delta_count, 1);
 
         let loaded = storage.read_object(&key).await.unwrap().unwrap();
         assert_eq!(loaded.deltas.len(), 1);
+        assert_eq!(loaded.deltas[0].version, 2);
         assert_eq!(loaded.deltas[0].data, vec![4, 5, 6]);
     }
 
+    #[tokio::test]
+    async fn test_disk_storage_synthesizes_old_zero_delta_versions() {
+        let temp_dir = tempdir().unwrap();
+        let storage = DiskStorage::new(temp_dir.path().to_path_buf()).unwrap();
+
+        let key = test_key("test-app", "test-session", "obj1");
+        let object = Object::new(5, vec![1, 2, 3]);
+        storage.write_object(&key, &object).await.unwrap();
+
+        storage
+            .patch_object(&key, &Object::new(0, vec![4]))
+            .await
+            .unwrap();
+        storage
+            .patch_object(&key, &Object::new(0, vec![5]))
+            .await
+            .unwrap();
+
+        let loaded = storage.read_object(&key).await.unwrap().unwrap();
+        assert_eq!(loaded.deltas.len(), 2);
+        assert_eq!(loaded.deltas[0].version, 6);
+        assert_eq!(loaded.deltas[1].version, 7);
+    }
+
     #[tokio::test]
     async fn test_disk_storage_delete_objects() {
         let temp_dir = tempdir().unwrap();
diff --git a/sdk/python/src/flamepy/core/cache.py b/sdk/python/src/flamepy/core/cache.py
index 551631e2..3ec834ab 100644
--- a/sdk/python/src/flamepy/core/cache.py
+++ b/sdk/python/src/flamepy/core/cache.py
@@ -15,7 +15,8 @@
 import threading
 import uuid
 from collections import OrderedDict
-from dataclasses import asdict, dataclass
+from dataclasses import asdict, dataclass, field
+from enum import Enum
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
 
@@ -53,6 +54,19 @@ Deserializer = Callable[[Any, List[Any]], Any]
 WILDCARD_SESSION = "*"
 
+OBJECT_FIELD_VERSION = "version"
+OBJECT_FIELD_DATA = "data"
+OBJECT_RESPONSE_FIELD_KIND = "kind"
+
+
+class FetchMode(str, Enum):
+    FULL = "full"
+    PATCHES = "patches"
+
+
+class ObjectResponseKind(str, Enum):
+    BASE = "base"
+    PATCH = "patch"
 
 
 @dataclass
@@ -79,16 +93,32 @@ def decode(cls, json_data: bytes) -> "ObjectRef":
         return cls(**data)
 
 
+@dataclass
+class Patch:
+    version: int
+    data: Any
+
+
 @dataclass
 class Object:
-    """Cached object with version and deserialized data.
+    """Cached object with base data, versioned patches, and materialized views.
 
-    Note: Stores deserialized data to avoid repeated deserialization.
-    Different deserializers on the same cached object are not supported.
+    The `data` field stores the base object to preserve the existing public
+    behavior where get_object(..., deserializer=None) returns only the base.
     """
 
     version: int
    data: Any
+    patches: List[Patch] = field(default_factory=list)
+    materialized: Dict[Optional[int], Any] = field(default_factory=dict)
+
+
+@dataclass
+class FetchResult:
+    mode: FetchMode
+    version: int
+    base: Any = None
+    patches: List[Patch] = field(default_factory=list)
 
 
 # Client-side LRU cache with max size limit (O(1) operations using OrderedDict)
@@ -131,6 +161,20 @@ def _cache_remove_prefix(prefix: str) -> None:
             _object_cache.pop(key, None)
 
 
+def _materialize_object(obj: Object, deserializer: Optional[Deserializer] = None) -> Any:
+    materialized_key = None if deserializer is None else id(deserializer)
+    if materialized_key in obj.materialized:
+        return obj.materialized[materialized_key]
+
+    if deserializer is None:
+        data = obj.data
+    else:
+        data = deserializer(obj.data, [patch.data for patch in obj.patches])
+
+    obj.materialized[materialized_key] = data
+    return data
+
+
 @dataclass(frozen=True)
 class ObjectKey:
     """Parsed object key: <app>/<session>/<object_id>
@@ -350,8 +394,8 @@ def _serialize_object(obj: Any) -> pa.RecordBatch:
 
     schema = pa.schema(
         [
-            pa.field("version", pa.uint64()),
-            pa.field("data", pa.binary()),
+            pa.field(OBJECT_FIELD_VERSION, pa.uint64()),
+            pa.field(OBJECT_FIELD_DATA, pa.binary()),
         ]
     )
 
@@ -366,7 +410,7 @@ def _deserialize_object(batch: pa.RecordBatch) -> Any:
 
     Automatically detects the serialization format from the type marker.
""" - data_array = batch.column("data") + data_array = batch.column(OBJECT_FIELD_DATA) data_bytes = data_array[0].as_py() return _deserialize_object_data(data_bytes) @@ -632,25 +676,48 @@ def get_object(ref: ObjectRef, deserializer: Optional[Deserializer] = None) -> A cached_version = cached.version if cached else 0 logger.debug(f"get_object: key={ref.key}, cached_version={cached_version}") - result = _fetch_object_data(ref, cached_version, deserializer) + result = _fetch_object_data(ref, cached_version) if result is None: if cached_version > 0: cached = _cache_get(cache_key) if cached is not None: logger.debug(f"get_object: not_modified, returning cached for key={ref.key}") - return cached.data + return _materialize_object(cached, deserializer) logger.error(f"get_object: cache miss after not_modified! key={ref.key}, cached_version={cached_version}") raise ValueError(f"Object not found: {ref.key}") - data, version = result - _cache_put(cache_key, Object(version=version, data=data)) + if result.mode == FetchMode.FULL: + cached = Object( + version=result.version, + data=result.base, + patches=result.patches, + ) + _cache_put(cache_key, cached) + elif result.mode == FetchMode.PATCHES: + cached = _cache_get(cache_key) + if cached is None: + full_result = _fetch_object_data(ref, 0) + if full_result is None or full_result.mode != FetchMode.FULL: + raise ValueError(f"Object not found: {ref.key}") + cached = Object( + version=full_result.version, + data=full_result.base, + patches=full_result.patches, + ) + else: + cached.patches.extend(result.patches) + cached.version = result.version + cached.materialized.clear() + _cache_put(cache_key, cached) + else: + raise ValueError(f"Unexpected object fetch mode: {result.mode}") - logger.debug(f"get_object: key={ref.key}, version={version}") - return data + logger.debug(f"get_object: key={ref.key}, version={cached.version}") + return _materialize_object(cached, deserializer) -def _fetch_object_data(ref: ObjectRef, cached_version: int, deserializer: Optional[Deserializer] = None) -> Optional[tuple[Any, int]]: +def _fetch_object_data(ref: ObjectRef, cached_version: int) -> Optional[FetchResult]: tls_config = _get_cache_tls_config() client = _get_flight_client(ref.endpoint, tls_config) @@ -662,17 +729,61 @@ def _fetch_object_data(ref: ObjectRef, cached_version: int, deserializer: Option if table.num_rows == 0: return None - batches = table.to_batches() - base = _deserialize_object(batches[0]) - version = batches[0].column("version")[0].as_py() + if OBJECT_RESPONSE_FIELD_KIND not in table.column_names: + batches = table.to_batches() + base = _deserialize_object(batches[0]) + version = batches[0].column(OBJECT_FIELD_VERSION)[0].as_py() + patches = [ + Patch( + version=batch.column(OBJECT_FIELD_VERSION)[0].as_py(), + data=_deserialize_object(batch), + ) + for batch in batches[1:] + ] + return FetchResult( + mode=FetchMode.FULL, + version=max([version] + [patch.version for patch in patches]), + base=base, + patches=patches, + ) - if deserializer is not None: - deltas = [_deserialize_object(batch) for batch in batches[1:]] - data = deserializer(base, deltas) - else: - data = base + versions = table.column(OBJECT_FIELD_VERSION) + kinds = table.column(OBJECT_RESPONSE_FIELD_KIND) + data_values = table.column(OBJECT_FIELD_DATA) + rows: list[tuple[int, ObjectResponseKind, Any]] = [] + for idx in range(table.num_rows): + version = versions[idx].as_py() + kind_value = kinds[idx].as_py() + try: + kind = ObjectResponseKind(kind_value) + except ValueError as exc: + raise 
+            raise ValueError(f"Invalid object response row kind: {kind_value}") from exc
+        data = _deserialize_object_data(data_values[idx].as_py())
+        rows.append((version, kind, data))
+
+    base_rows = [row for row in rows if row[1] == ObjectResponseKind.BASE]
+    patch_rows = [row for row in rows if row[1] == ObjectResponseKind.PATCH]
+    max_version = max(row[0] for row in rows)
+    patch_versions = [row[0] for row in patch_rows]
+    if patch_versions != sorted(patch_versions) or len(patch_versions) != len(set(patch_versions)):
+        raise ValueError("Patch response rows must have unique increasing versions")
+
+    if base_rows:
+        first_base_index = next(idx for idx, row in enumerate(rows) if row[1] == ObjectResponseKind.BASE)
+        if len(base_rows) != 1 or first_base_index != 0:
+            raise ValueError("Full object response must start with exactly one base row")
+        patches = [Patch(version=row[0], data=row[2]) for row in patch_rows]
+        return FetchResult(
+            mode=FetchMode.FULL,
+            version=max_version,
+            base=base_rows[0][2],
+            patches=patches,
+        )

-    return data, version
+    patches = [Patch(version=row[0], data=row[2]) for row in patch_rows]
+    if not patches:
+        return None
+    return FetchResult(mode=FetchMode.PATCHES, version=max_version, patches=patches)


 def update_object(ref: ObjectRef, new_obj: Any) -> "ObjectRef":
@@ -911,7 +1022,7 @@ def download_object(ref: ObjectRef, dest_path: str) -> None:
     try:
         with open(dest_path, "wb") as f:
             for batch in reader:
-                data_array = batch.column("data")
+                data_array = batch.column(OBJECT_FIELD_DATA)
                 for i in range(len(data_array)):
                     chunk = data_array[i].as_py()
                     if chunk:
diff --git a/sdk/python/tests/test_cache.py b/sdk/python/tests/test_cache.py
index 7b86b704..161f1378 100644
--- a/sdk/python/tests/test_cache.py
+++ b/sdk/python/tests/test_cache.py
@@ -11,9 +11,16 @@
     _TYPE_ARROW_TABLE,
     _TYPE_CLOUDPICKLE,
     _TYPE_NUMPY,
+    OBJECT_FIELD_DATA,
+    OBJECT_FIELD_VERSION,
+    OBJECT_RESPONSE_FIELD_KIND,
+    FetchMode,
+    FetchResult,
     Object,
     ObjectKey,
     ObjectRef,
+    ObjectResponseKind,
+    Patch,
     _cache_lock,
     _deserialize_object,
     _deserialize_object_data,
@@ -195,6 +202,42 @@
     def teardown_method(self):
         with _cache_lock:
             _object_cache.clear()

+    def _response_table(self, rows):
+        return pa.table(
+            {
+                OBJECT_FIELD_VERSION: pa.array([row[0] for row in rows], type=pa.uint64()),
+                OBJECT_RESPONSE_FIELD_KIND: pa.array(
+                    [row[1].value for row in rows],
+                    type=pa.string(),
+                ),
+                OBJECT_FIELD_DATA: pa.array(
+                    [_serialize_object_data(row[2]) for row in rows],
+                    type=pa.binary(),
+                ),
+            }
+        )
+
+    def _patch_fetch_client(self, monkeypatch, table):
+        from flamepy.core import cache as cache_module
+
+        class FakeReader:
+            def read_all(self):
+                return table
+
+        class FakeClient:
+            def do_get(self, ticket):
+                self.ticket = ticket
+                return FakeReader()
+
+        fake_client = FakeClient()
+        monkeypatch.setattr(cache_module, "_get_cache_tls_config", lambda: None)
+        monkeypatch.setattr(
+            cache_module,
+            "_get_flight_client",
+            lambda endpoint, tls_config: fake_client,
+        )
+        return fake_client
+
     def test_cache_hit_returns_cached_data(self, monkeypatch):
         from flamepy.core import cache as cache_module
@@ -207,7 +250,7 @@
         call_count = {"server": 0}

-        def mock_fetch_object_data(ref, cached_version, deserializer=None):
+        def mock_fetch_object_data(ref, cached_version):
             call_count["server"] += 1
             return None
@@ -224,8 +267,8 @@ def test_cache_miss_fetches_from_server(self, monkeypatch):
         server_data = {"from": "server"}
-        def mock_fetch_object_data(ref, cached_version, deserializer=None):
-            return server_data, 1
+        def mock_fetch_object_data(ref, cached_version):
+            return FetchResult(mode=FetchMode.FULL, version=1, base=server_data)

         monkeypatch.setattr(cache_module, "_fetch_object_data", mock_fetch_object_data)
@@ -252,8 +295,8 @@ def test_version_mismatch_triggers_download(self, monkeypatch):
         new_data = {"new": "data"}

-        def mock_fetch_object_data(ref, cached_version, deserializer=None):
-            return new_data, 2
+        def mock_fetch_object_data(ref, cached_version):
+            return FetchResult(mode=FetchMode.FULL, version=2, base=new_data)

         monkeypatch.setattr(cache_module, "_fetch_object_data", mock_fetch_object_data)
@@ -265,6 +308,118 @@ def mock_fetch_object_data(ref, cached_version, deserializer=None):
         assert _object_cache[cache_key].version == 2
         assert _object_cache[cache_key].data == new_data

+    def test_patch_only_response_appends_to_cached_data(self, monkeypatch):
+        from flamepy.core import cache as cache_module
+
+        cache_key = ("grpc://host:9090", "app/session/obj-patch")
+        cached_obj = Object(version=1, data=[1])
+
+        with _cache_lock:
+            _object_cache[cache_key] = cached_obj
+
+        def mock_fetch_object_data(ref, cached_version):
+            assert cached_version == 1
+            return FetchResult(
+                mode=FetchMode.PATCHES,
+                version=3,
+                patches=[
+                    Patch(version=2, data=[2]),
+                    Patch(version=3, data=[3]),
+                ],
+            )
+
+        monkeypatch.setattr(cache_module, "_fetch_object_data", mock_fetch_object_data)
+
+        def merge_lists(base_data, deltas):
+            result = list(base_data)
+            for delta in deltas:
+                result.extend(delta)
+            return result
+
+        ref = ObjectRef(endpoint="grpc://host:9090", key="app/session/obj-patch", version=1)
+        result = cache_module.get_object(ref, deserializer=merge_lists)
+
+        assert result == [1, 2, 3]
+        with _cache_lock:
+            cached = _object_cache[cache_key]
+            assert cached.version == 3
+            assert [patch.version for patch in cached.patches] == [2, 3]
+
+    def test_patch_only_response_without_cache_falls_back_to_full_fetch(self, monkeypatch):
+        from flamepy.core import cache as cache_module
+
+        calls = []
+
+        def mock_fetch_object_data(ref, cached_version):
+            calls.append(cached_version)
+            if len(calls) == 1:
+                return FetchResult(
+                    mode=FetchMode.PATCHES,
+                    version=2,
+                    patches=[Patch(version=2, data=[2])],
+                )
+            return FetchResult(
+                mode=FetchMode.FULL,
+                version=2,
+                base=[1],
+                patches=[Patch(version=2, data=[2])],
+            )
+
+        monkeypatch.setattr(cache_module, "_fetch_object_data", mock_fetch_object_data)
+
+        def merge_lists(base_data, deltas):
+            result = list(base_data)
+            for delta in deltas:
+                result.extend(delta)
+            return result
+
+        ref = ObjectRef(endpoint="grpc://host:9090", key="app/session/obj-patch-miss", version=2)
+        result = cache_module.get_object(ref, deserializer=merge_lists)
+
+        assert calls == [0, 0]
+        assert result == [1, 2]
+        with _cache_lock:
+            cached = _object_cache[("grpc://host:9090", "app/session/obj-patch-miss")]
+            assert cached.version == 2
+            assert [patch.data for patch in cached.patches] == [[2]]
+
+    def test_not_modified_reuses_materialized_deserializer_result(self, monkeypatch):
+        from flamepy.core import cache as cache_module
+
+        cache_key = ("grpc://host:9090", "app/session/obj-not-modified")
+        cached_obj = Object(
+            version=2,
+            data=[1],
+            patches=[Patch(version=2, data=[2])],
+        )
+
+        with _cache_lock:
+            _object_cache[cache_key] = cached_obj
+
+        fetch_calls = {"count": 0}
+        deserializer_calls = {"count": 0}
+
+        def mock_fetch_object_data(ref, cached_version):
+            fetch_calls["count"] += 1
+            assert cached_version == 2
+            return None
+
+        monkeypatch.setattr(cache_module, "_fetch_object_data", mock_fetch_object_data)
+
+        def merge_lists(base_data, deltas):
+            deserializer_calls["count"] += 1
+            result = list(base_data)
+            for delta in deltas:
+                result.extend(delta)
+            return result
+
+        ref = ObjectRef(endpoint="grpc://host:9090", key="app/session/obj-not-modified", version=2)
+
+        assert cache_module.get_object(ref, deserializer=merge_lists) == [1, 2]
+        assert cache_module.get_object(ref, deserializer=merge_lists) == [1, 2]
+        assert fetch_calls["count"] == 2
+        assert deserializer_calls["count"] == 1
+
     def test_version_zero_bypasses_cache(self, monkeypatch):
         from flamepy.core import cache as cache_module
@@ -277,9 +432,9 @@
         server_data = {"fresh": "data"}

-        def mock_fetch_object_data(ref, cached_version, deserializer=None):
+        def mock_fetch_object_data(ref, cached_version):
             assert cached_version == 0
-            return server_data, 6
+            return FetchResult(mode=FetchMode.FULL, version=6, base=server_data)

         monkeypatch.setattr(cache_module, "_fetch_object_data", mock_fetch_object_data)
@@ -291,13 +446,16 @@
     def test_deserializer_combines_base_and_deltas(self, monkeypatch):
         from flamepy.core import cache as cache_module

-        def mock_fetch_object_data(ref, cached_version, deserializer=None):
-            base = [1, 2, 3]
-            delta1 = [4, 5]
-            delta2 = [6]
-            if deserializer is not None:
-                return deserializer(base, [delta1, delta2]), 1
-            return base, 1
+        def mock_fetch_object_data(ref, cached_version):
+            return FetchResult(
+                mode=FetchMode.FULL,
+                version=3,
+                base=[1, 2, 3],
+                patches=[
+                    Patch(version=2, data=[4, 5]),
+                    Patch(version=3, data=[6]),
+                ],
+            )

         monkeypatch.setattr(cache_module, "_fetch_object_data", mock_fetch_object_data)
@@ -312,14 +470,55 @@ def merge_lists(base_data, deltas):

         assert result == [1, 2, 3, 4, 5, 6]

+    def test_fetch_object_data_parses_full_response_rows(self, monkeypatch):
+        from flamepy.core import cache as cache_module
+
+        table = self._response_table(
+            [
+                (1, ObjectResponseKind.BASE, [1]),
+                (2, ObjectResponseKind.PATCH, [2]),
+                (3, ObjectResponseKind.PATCH, [3]),
+            ]
+        )
+        self._patch_fetch_client(monkeypatch, table)
+
+        result = cache_module._fetch_object_data(
+            ObjectRef(endpoint="grpc://host:9090", key="app/session/obj6", version=1),
+            0,
+        )
+
+        assert result.mode == FetchMode.FULL
+        assert result.version == 3
+        assert result.base == [1]
+        assert [patch.data for patch in result.patches] == [[2], [3]]
+
+    def test_fetch_object_data_rejects_base_after_patch(self, monkeypatch):
+        from flamepy.core import cache as cache_module
+
+        table = self._response_table(
+            [
+                (2, ObjectResponseKind.PATCH, [2]),
+                (1, ObjectResponseKind.BASE, [1]),
+            ]
+        )
+        self._patch_fetch_client(monkeypatch, table)
+
+        ref = ObjectRef(endpoint="grpc://host:9090", key="app/session/obj7", version=1)
+        try:
+            cache_module._fetch_object_data(ref, 1)
+        except ValueError as exc:
+            assert "base row" in str(exc)
+        else:
+            raise AssertionError("expected ValueError for malformed full response")
+
     def test_thread_safety(self, monkeypatch):
         from flamepy.core import cache as cache_module

         results = []
         errors = []

-        def mock_fetch_object_data(ref, cached_version, deserializer=None):
-            return {"thread": ref.key}, 1
+        def mock_fetch_object_data(ref, cached_version):
+            return FetchResult(mode=FetchMode.FULL, version=1, base={"thread": ref.key})

         monkeypatch.setattr(cache_module, "_fetch_object_data", mock_fetch_object_data)

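All of the tests above share one deserializer shape. Per the SDK's `Deserializer = Callable[[Any, List[Any]], Any]` alias, a deserializer receives the base payload plus the patch payloads in version order; the `merge_lists` helper the tests use is representative of that contract:

```python
from typing import Any, List


def merge_lists(base_data: Any, deltas: List[Any]) -> Any:
    # Replay each patch payload, in version order, on top of the base.
    result = list(base_data)
    for delta in deltas:
        result.extend(delta)
    return result


assert merge_lists([1], [[2], [3]]) == [1, 2, 3]
```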
From f848b086488d7e3d0e35bfa962c56907ee30a37a Mon Sep 17 00:00:00 2001
From: Klaus Ma
Date: Tue, 12 May 2026 07:35:37 +0800
Subject: [PATCH 2/5] fix: address incremental get review comments

---
 object_cache/src/cache.rs            | 36 +++++++++++++-------------
 sdk/python/src/flamepy/core/cache.py |  4 ++--
 sdk/python/tests/test_cache.py       | 28 +++++++++++++---------
 3 files changed, 36 insertions(+), 32 deletions(-)

diff --git a/object_cache/src/cache.rs b/object_cache/src/cache.rs
index e65a7338..4866f1f7 100644
--- a/object_cache/src/cache.rs
+++ b/object_cache/src/cache.rs
@@ -244,7 +244,6 @@ impl Object {
             .map(|delta| delta.version)
             .max()
             .unwrap_or(self.version)
-            .max(self.version)
     }
 }
@@ -346,8 +345,8 @@ pub struct ObjectCache {
     objects: MutexPtr<HashMap<String, Object>>,
     metadata: MutexPtr<HashMap<String, ObjectMetadata>>,
     eviction_policy: EvictionPolicyPtr,
-    /// Per-key write locks to prevent concurrent PUT/PATCH race conditions
-    write_locks: MutexPtr<HashMap<String, Arc<tokio::sync::Mutex<()>>>>,
+    /// Per-key locks coordinate concurrent PUT/PATCH writes with GET snapshots.
+    key_locks: MutexPtr<HashMap<String, Arc<tokio::sync::RwLock<()>>>>,
 }

 impl ObjectCache {
@@ -364,7 +363,7 @@
             objects: new_ptr(HashMap::new()),
             metadata: new_ptr(HashMap::new()),
             eviction_policy,
-            write_locks: new_ptr(HashMap::new()),
+            key_locks: new_ptr(HashMap::new()),
         })
     }
@@ -394,11 +393,11 @@
         Ok(())
     }

-    fn get_write_lock(&self, key: &str) -> Result<Arc<tokio::sync::Mutex<()>>, FlameError> {
-        let mut locks = lock_ptr!(self.write_locks)?;
+    fn get_key_lock(&self, key: &str) -> Result<Arc<tokio::sync::RwLock<()>>, FlameError> {
+        let mut locks = lock_ptr!(self.key_locks)?;
         Ok(locks
             .entry(key.to_string())
-            .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
+            .or_insert_with(|| Arc::new(tokio::sync::RwLock::new(())))
             .clone())
     }
@@ -449,8 +448,8 @@
         let size = object.data.len() as u64;

         // Acquire per-key lock to prevent concurrent version increments
-        let write_lock = self.get_write_lock(&key_str)?;
-        let _guard = write_lock.lock().await;
+        let key_lock = self.get_key_lock(&key_str)?;
+        let _guard = key_lock.write().await;

         let version_from_memory = {
             let metadata = lock_ptr!(self.metadata)?;
@@ -539,8 +538,8 @@
         })?;

         // Acquire per-key lock to prevent concurrent version increments
-        let write_lock = self.get_write_lock(&key_str)?;
-        let _guard = write_lock.lock().await;
+        let key_lock = self.get_key_lock(&key_str)?;
+        let _guard = key_lock.write().await;

         self.eviction_policy.on_access(&key_str);
@@ -948,9 +947,9 @@ fn object_to_flight_data_vec(obj: &Object) -> Result<Vec<FlightData>, FlameError> {
     object_rows_to_flight_data_vec(rows)
 }

-fn object_patches_to_flight_data_vec(patches: &[Object]) -> Result<Vec<FlightData>, FlameError> {
+fn object_patches_to_flight_data_vec(patches: Vec<&Object>) -> Result<Vec<FlightData>, FlameError> {
     let rows = patches
-        .iter()
+        .into_iter()
         .map(|delta| (ObjectResponseKind::Patch, delta))
         .collect();
     object_rows_to_flight_data_vec(rows)
 }
@@ -1079,11 +1078,11 @@ impl FlightService for FlightCacheServer {
         let key = ObjectKey::try_from(key_str.as_str())?;

-        let write_lock = self
+        let key_lock = self
             .cache
-            .get_write_lock(&key_str)
+            .get_key_lock(&key_str)
             .map_err(|e| Status::internal(format!("Lock error: {}", e)))?;
-        let _guard = write_lock.lock().await;
+        let _guard = key_lock.read().await;

         let object = self.cache.get(&key).await?;
         let server_version = object.current_version();
@@ -1115,11 +1114,10 @@
                 server_version
             );
         } else if client_version != 0 && object.version <= client_version {
-            let needed_patches: Vec<Object> = object
+            let needed_patches: Vec<&Object> = object
                 .deltas
                 .iter()
                 .filter(|delta| delta.version > client_version)
-                .cloned()
                 .collect();
             let expected_patch_count = server_version.saturating_sub(client_version) as usize;
             let patch_suffix_is_contiguous = needed_patches.len() == expected_patch_count
@@ -1136,7 +1134,7 @@
                     client_version,
                     server_version
                 );
-                let flight_data_vec = object_patches_to_flight_data_vec(&needed_patches)?;
+                let flight_data_vec = object_patches_to_flight_data_vec(needed_patches)?;
                 let stream = futures::stream::iter(flight_data_vec.into_iter().map(Ok));
                 return Ok(Response::new(Box::pin(stream)));
             }
diff --git a/sdk/python/src/flamepy/core/cache.py b/sdk/python/src/flamepy/core/cache.py
index 3ec834ab..0d103fe1 100644
--- a/sdk/python/src/flamepy/core/cache.py
+++ b/sdk/python/src/flamepy/core/cache.py
@@ -110,7 +110,7 @@ class Object:
     version: int
     data: Any
     patches: List[Patch] = field(default_factory=list)
-    materialized: Dict[Optional[int], Any] = field(default_factory=dict)
+    materialized: Dict[Optional[Deserializer], Any] = field(default_factory=dict)

@@ -162,7 +162,7 @@
 def _materialize_object(obj: Object, deserializer: Optional[Deserializer] = None) -> Any:
-    materialized_key = None if deserializer is None else id(deserializer)
+    materialized_key = deserializer
     if materialized_key in obj.materialized:
         return obj.materialized[materialized_key]
diff --git a/sdk/python/tests/test_cache.py b/sdk/python/tests/test_cache.py
index 161f1378..3769a067 100644
--- a/sdk/python/tests/test_cache.py
+++ b/sdk/python/tests/test_cache.py
@@ -383,7 +383,7 @@
-    def test_not_modified_reuses_materialized_deserializer_result(self, monkeypatch):
+    def test_not_modified_reuses_bound_method_materialized_result(self, monkeypatch):
         from flamepy.core import cache as cache_module

         cache_key = ("grpc://host:9090", "app/session/obj-not-modified")
@@ -397,7 +397,6 @@
             _object_cache[cache_key] = cached_obj

         fetch_calls = {"count": 0}
-        deserializer_calls = {"count": 0}
@@ -406,19 +405,26 @@
         monkeypatch.setattr(cache_module, "_fetch_object_data", mock_fetch_object_data)

-        def merge_lists(base_data, deltas):
-            deserializer_calls["count"] += 1
-            result = list(base_data)
-            for delta in deltas:
-                result.extend(delta)
-            return result
+        class Merger:
+            def __init__(self):
+                self.calls = 0
+
+            def merge_lists(self, base_data, deltas):
+                self.calls += 1
+                result = list(base_data)
+                for delta in deltas:
+                    result.extend(delta)
+                return result
+
+        merger = Merger()

         ref = ObjectRef(endpoint="grpc://host:9090", key="app/session/obj-not-modified", version=2)

-        assert cache_module.get_object(ref, deserializer=merge_lists) == [1, 2]
-        assert cache_module.get_object(ref, deserializer=merge_lists) == [1, 2]
+        assert cache_module.get_object(ref, deserializer=merger.merge_lists) == [1, 2]
+        assert cache_module.get_object(ref, deserializer=merger.merge_lists) == [1, 2]
         assert fetch_calls["count"] == 2
-        assert deserializer_calls["count"] == 1
+        assert merger.calls == 1
+        assert len(cached_obj.materialized) == 1

     def test_version_zero_bypasses_cache(self, monkeypatch):
         from flamepy.core import cache as cache_module
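The move from `id(deserializer)` to the deserializer itself as the materialized-cache key is what makes the renamed bound-method test pass: `id()` of a bound method changes on every attribute access, while equality and hashing do not. A standalone illustration of the Python semantics the fix relies on:

```python
class Merger:
    def merge(self, base, deltas):
        return list(base) + [item for delta in deltas for item in delta]


merger = Merger()

# Each attribute access creates a fresh bound-method object, so an
# id()-based cache key would miss the materialized cache on every call...
assert merger.merge is not merger.merge

# ...but bound methods wrapping the same function on the same instance
# compare and hash equal, so keying by the callable itself hits.
assert merger.merge == merger.merge
assert hash(merger.merge) == hash(merger.merge)
```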
From 0f27b8b15c4ab95711d3b069de0497b0b9457927 Mon Sep 17 00:00:00 2001
From: Klaus Ma
Date: Tue, 12 May 2026 07:44:57 +0800
Subject: [PATCH 3/5] fix: harden incremental get client cache

---
 .../RFE445-incremental-object-get/FS.md |  8 +-
 sdk/python/src/flamepy/core/cache.py    | 66 ++++++++++++--
 sdk/python/tests/test_cache.py          | 91 +++++++++++++++++++
 3 files changed, 153 insertions(+), 12 deletions(-)

diff --git a/docs/designs/RFE445-incremental-object-get/FS.md b/docs/designs/RFE445-incremental-object-get/FS.md
index 9433f91a..9bdec7ab 100644
--- a/docs/designs/RFE445-incremental-object-get/FS.md
+++ b/docs/designs/RFE445-incremental-object-get/FS.md
@@ -243,7 +243,7 @@

 **`sdk/python/src/flamepy/core/cache.py`**

-- Replace the cached `Object(version, data)` shape with a richer cache entry:
+- Replace the cached `Object(version, data)` shape with a richer `Object` cache entry:

 ```python
 @dataclass
@@ -252,11 +252,11 @@
 class Patch:
     version: int
     data: Any

 @dataclass
-class CachedObject:
+class Object:
     version: int
-    base: Any
+    data: Any
     patches: list[Patch]
-    materialized: dict[int | None, Any]
+    materialized: dict[Any, Any]
 ```

 - Request full or incremental data by choosing the effective request version.
diff --git a/sdk/python/src/flamepy/core/cache.py b/sdk/python/src/flamepy/core/cache.py
index 0d103fe1..2960e387 100644
--- a/sdk/python/src/flamepy/core/cache.py
+++ b/sdk/python/src/flamepy/core/cache.py
@@ -110,7 +110,7 @@ class Object:
     version: int
     data: Any
     patches: List[Patch] = field(default_factory=list)
-    materialized: Dict[Optional[Deserializer], Any] = field(default_factory=dict)
+    materialized: Dict[Any, Any] = field(default_factory=dict)

@@ -121,6 +121,19 @@ class FetchResult:
     patches: List[Patch] = field(default_factory=list)

+class _IdentityKey:
+    __slots__ = ("value",)
+
+    def __init__(self, value: Any):
+        self.value = value
+
+    def __hash__(self) -> int:
+        return id(self.value)
+
+    def __eq__(self, other: Any) -> bool:
+        return isinstance(other, _IdentityKey) and self.value is other.value
+
+
 # Client-side LRU cache with max size limit (O(1) operations using OrderedDict)
 _CACHE_MAX_SIZE = 1000
 _object_cache: OrderedDict[tuple, Object] = OrderedDict()
@@ -161,8 +174,19 @@ def _cache_remove_prefix(prefix: str) -> None:
     _object_cache.pop(key, None)

+def _materialized_cache_key(deserializer: Optional[Deserializer]) -> Any:
+    if deserializer is None:
+        return None
+
+    try:
+        hash(deserializer)
+    except TypeError:
+        return _IdentityKey(deserializer)
+    return deserializer
+
+
 def _materialize_object(obj: Object, deserializer: Optional[Deserializer] = None) -> Any:
-    materialized_key = deserializer
+    materialized_key = _materialized_cache_key(deserializer)
     if materialized_key in obj.materialized:
         return obj.materialized[materialized_key]
@@ -175,6 +199,31 @@
     return data

+def _cache_apply_patches(
+    key: tuple,
+    expected_version: int,
+    new_version: int,
+    patches: List[Patch],
+) -> Optional[Object]:
+    """Apply patch rows only if the cache is still at the requested version."""
+    with _cache_lock:
+        cached = _object_cache.get(key)
+        if cached is None:
+            return None
+
+        if cached.version == expected_version:
+            if new_version <= cached.version:
+                return None
+            cached.patches.extend(patches)
+            cached.version = new_version
+            cached.materialized.clear()
+        elif cached.version < new_version:
+            return None
+
+        _object_cache.move_to_end(key)
+        return cached
+
+
 @dataclass(frozen=True)
 class ObjectKey:
     """Parsed object key: //
@@ -695,7 +744,12 @@ def get_object(ref: ObjectRef, deserializer: Optional[Deserializer] = None) -> Any:
         )
         _cache_put(cache_key, cached)
     elif result.mode == FetchMode.PATCHES:
-        cached = _cache_get(cache_key)
+        cached = _cache_apply_patches(
+            cache_key,
+            expected_version=cached_version,
+            new_version=result.version,
+            patches=result.patches,
+        )
         if cached is None:
             full_result = _fetch_object_data(ref, 0)
             if full_result is None or full_result.mode != FetchMode.FULL:
@@ -705,11 +759,7 @@
                 data=full_result.base,
                 patches=full_result.patches,
             )
-        else:
-            cached.patches.extend(result.patches)
-            cached.version = result.version
-            cached.materialized.clear()
-        _cache_put(cache_key, cached)
+        _cache_put(cache_key, cached)
     else:
         raise ValueError(f"Unexpected object fetch mode: {result.mode}")
diff --git a/sdk/python/tests/test_cache.py b/sdk/python/tests/test_cache.py
index 3769a067..9db9d8ce 100644
--- a/sdk/python/tests/test_cache.py
+++ b/sdk/python/tests/test_cache.py
@@ -345,6 +345,57 @@
         assert cached.version == 3
         assert [patch.version for patch in cached.patches] == [2, 3]

+    def test_concurrent_patch_only_fetches_do_not_duplicate_patches(self, monkeypatch):
+        from flamepy.core import cache as cache_module
+
+        cache_key = ("grpc://host:9090", "app/session/obj-concurrent-patch")
+        cached_obj = Object(version=1, data=[1])
+
+        with _cache_lock:
+            _object_cache[cache_key] = cached_obj
+
+        barrier = threading.Barrier(2)
+
+        def mock_fetch_object_data(ref, cached_version):
+            assert cached_version == 1
+            barrier.wait(timeout=5)
+            return FetchResult(
+                mode=FetchMode.PATCHES,
+                version=2,
+                patches=[Patch(version=2, data=[2])],
+            )
+
+        monkeypatch.setattr(cache_module, "_fetch_object_data", mock_fetch_object_data)
+
+        def merge_lists(base_data, deltas):
+            result = list(base_data)
+            for delta in deltas:
+                result.extend(delta)
+            return result
+
+        ref = ObjectRef(endpoint="grpc://host:9090", key="app/session/obj-concurrent-patch", version=1)
+        results = []
+        errors = []
+
+        def worker():
+            try:
+                results.append(cache_module.get_object(ref, deserializer=merge_lists))
+            except Exception as exc:
+                errors.append(exc)
+
+        threads = [threading.Thread(target=worker) for _ in range(2)]
+        for thread in threads:
+            thread.start()
+        for thread in threads:
+            thread.join()
+
+        assert errors == []
+        assert results == [[1, 2], [1, 2]]
+        with _cache_lock:
+            cached = _object_cache[cache_key]
+            assert cached.version == 2
+            assert [patch.version for patch in cached.patches] == [2]
+
     def test_patch_only_response_without_cache_falls_back_to_full_fetch(self, monkeypatch):
         from flamepy.core import cache as cache_module
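One edge remains after keying by the callable: a class that defines `__eq__` without `__hash__` gets `__hash__` set to `None`, so its instances cannot be dictionary keys at all. That is the case `_materialized_cache_key` covers with `_IdentityKey`, and the test added in the next hunk exercises it end to end. A quick illustration of the language rule:

```python
class UnhashableCallable:
    # Defining __eq__ without __hash__ sets __hash__ to None.
    def __eq__(self, other):
        return self is other

    def __call__(self, base, deltas):
        return list(base) + [item for delta in deltas for item in delta]


fn = UnhashableCallable()
try:
    hash(fn)
except TypeError:
    # _materialized_cache_key falls back to _IdentityKey here, which
    # hashes by id() and compares by object identity.
    pass
```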
@@ -426,6 +477,46 @@ def merge_lists(self, base_data, deltas):
         assert merger.calls == 1
         assert len(cached_obj.materialized) == 1

+    def test_not_modified_accepts_unhashable_callable_deserializer(self, monkeypatch):
+        from flamepy.core import cache as cache_module
+
+        cache_key = ("grpc://host:9090", "app/session/obj-unhashable")
+        cached_obj = Object(
+            version=2,
+            data=[1],
+            patches=[Patch(version=2, data=[2])],
+        )
+
+        with _cache_lock:
+            _object_cache[cache_key] = cached_obj
+
+        def mock_fetch_object_data(ref, cached_version):
+            assert cached_version == 2
+            return None
+
+        monkeypatch.setattr(cache_module, "_fetch_object_data", mock_fetch_object_data)
+
+        class UnhashableMerger:
+            def __init__(self):
+                self.calls = 0
+
+            def __eq__(self, other):
+                return self is other
+
+            def __call__(self, base_data, deltas):
+                self.calls += 1
+                result = list(base_data)
+                for delta in deltas:
+                    result.extend(delta)
+                return result
+
+        merger = UnhashableMerger()
+        ref = ObjectRef(endpoint="grpc://host:9090", key="app/session/obj-unhashable", version=2)
+
+        assert cache_module.get_object(ref, deserializer=merger) == [1, 2]
+        assert cache_module.get_object(ref, deserializer=merger) == [1, 2]
+        assert merger.calls == 1
+
     def test_version_zero_bypasses_cache(self, monkeypatch):
         from flamepy.core import cache as cache_module

From 3e27e2e764c3c060f56a4ae1baf9ae39e49a2a2b Mon Sep 17 00:00:00 2001
From: Klaus Ma
Date: Tue, 12 May 2026 07:48:38 +0800
Subject: [PATCH 4/5] fix: align incremental get design and impl

---
 .../RFE445-incremental-object-get/FS.md | 11 +++---
 e2e/tests/test_cache.py                 | 25 ++++++++++---
 object_cache/src/cache.rs               | 12 +++++--
 object_cache/src/storage/disk.rs        | 35 +++++++++++++++++++
 4 files changed, 71 insertions(+), 12 deletions(-)

diff --git a/docs/designs/RFE445-incremental-object-get/FS.md b/docs/designs/RFE445-incremental-object-get/FS.md
index 9bdec7ab..43db5c4b 100644
--- a/docs/designs/RFE445-incremental-object-get/FS.md
+++ b/docs/designs/RFE445-incremental-object-get/FS.md
@@ -406,7 +406,7 @@

 **Observability:**

-Add debug logs and optional counters for:
+Add debug logs. Optional cache counters can be added before formal replay-buffer benchmarking for:

 - `get_object_not_modified_total`
 - `get_object_patch_response_total`
@@ -469,7 +469,7 @@

 - Put base, patch twice, read from one Python process twice, verify the second read fetches only the second patch.
 - Put base, read, update base, read again, verify full response and correct final data.
-- Replay buffer smoke test with metrics enabled verifies full/patch/not-modified counters are internally consistent.
+- Replay buffer benchmark mode writes configured per-iteration metrics and preserves workload correctness.

 ## 4. Use Cases
@@ -511,10 +511,13 @@

 **Metrics:**

-Primary read-path metrics:
+Primary read-path metrics available from the replay-buffer benchmark hooks:

-- Total bytes downloaded by `get_object`.
 - `get_object` latency for `ReplayBuffer.state()` and `ReplayBuffer.sample()`.
+
+Optional cache-observability metrics to add before collecting formal numbers:
+
+- Total bytes downloaded by `get_object`.
 - Deserializer/materialization CPU time inside `_fetch()`.
 - Base size, patch count, and patch rows downloaded per read.
 - Number of full, patch-only, and not-modified responses.
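The counters named in the design text map directly onto the three client-visible outcomes of `_fetch_object_data`. A hypothetical client-side tally, shown only to pin down that mapping (the `get_object_full_response_total` name is assumed; only the first two appear in the design list):

```python
from collections import Counter

from flamepy.core.cache import FetchMode

response_totals = Counter()


def record_fetch_outcome(result) -> None:
    # result is the Optional[FetchResult] returned by _fetch_object_data.
    if result is None:
        response_totals["get_object_not_modified_total"] += 1
    elif result.mode == FetchMode.PATCHES:
        response_totals["get_object_patch_response_total"] += 1
    else:
        response_totals["get_object_full_response_total"] += 1  # assumed name
```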
diff --git a/e2e/tests/test_cache.py b/e2e/tests/test_cache.py
index ff56e210..36e1bfa4 100644
--- a/e2e/tests/test_cache.py
+++ b/e2e/tests/test_cache.py
@@ -168,24 +168,39 @@ def test_incremental_get_applies_remote_patch_only_response():
     """Test a cached client applies only patches appended by another client."""
     key_prefix = f"test-app/test-incremental-patch-{uuid.uuid4().hex[:8]}"
     base_data = {"items": ["base"]}
-    delta_data = {"items": ["patch-1"]}
+    delta_data_1 = {"items": ["patch-1"]}
+    delta_data_2 = {"items": ["patch-2"]}

     ref = put_object(key_prefix, base_data)
     assert get_object(ref, deserializer=_raw_deserializer) == {"base": base_data, "deltas": []}

-    patched_ref = _remote_patch_without_local_cache_invalidation(ref, delta_data)
+    patched_ref = _remote_patch_without_local_cache_invalidation(ref, delta_data_1)
     fetch_result = cache_module._fetch_object_data(ref, ref.version)

     assert fetch_result.mode == cache_module.FetchMode.PATCHES
     assert fetch_result.version == patched_ref.version
-    assert [patch.data for patch in fetch_result.patches] == [delta_data]
+    assert [patch.data for patch in fetch_result.patches] == [delta_data_1]

     result = get_object(ref, deserializer=_raw_deserializer)
-    assert result == {"base": base_data, "deltas": [delta_data]}
+    assert result == {"base": base_data, "deltas": [delta_data_1]}

     cached = _cached_object(ref)
     assert cached.version == patched_ref.version
-    assert [patch.data for patch in cached.patches] == [delta_data]
+    assert [patch.data for patch in cached.patches] == [delta_data_1]
+
+    patched_ref_2 = _remote_patch_without_local_cache_invalidation(ref, delta_data_2)
+    second_fetch_result = cache_module._fetch_object_data(ref, cached.version)
+
+    assert second_fetch_result.mode == cache_module.FetchMode.PATCHES
+    assert second_fetch_result.version == patched_ref_2.version
+    assert [patch.data for patch in second_fetch_result.patches] == [delta_data_2]
+
+    result = get_object(ref, deserializer=_raw_deserializer)
+    assert result == {"base": base_data, "deltas": [delta_data_1, delta_data_2]}
+
+    cached = _cached_object(ref)
+    assert cached.version == patched_ref_2.version
+    assert [patch.data for patch in cached.patches] == [delta_data_1, delta_data_2]

 def test_version_zero_forces_full_response_with_cached_object():
diff --git a/object_cache/src/cache.rs b/object_cache/src/cache.rs
index 4866f1f7..ce622171 100644
--- a/object_cache/src/cache.rs
+++ b/object_cache/src/cache.rs
@@ -241,9 +241,7 @@ impl Object {
     pub fn current_version(&self) -> u64 {
         self.deltas
             .iter()
-            .map(|delta| delta.version)
-            .max()
-            .unwrap_or(self.version)
+            .fold(self.version, |current, delta| current.max(delta.version))
     }
 }
@@ -1498,6 +1496,14 @@ mod tests {
         assert_eq!(obj.deltas[1].data, vec![6, 7]);
     }

+    #[test]
+    fn current_version_includes_base_version() {
+        let delta = Object::new(3, vec![4, 5]);
+        let obj = Object::with_deltas(5, vec![1, 2, 3], vec![delta]);
+
+        assert_eq!(obj.current_version(), 5);
+    }
+
     #[test]
     fn object_clone_works() {
         let obj = Object::new(42, vec![10, 20, 30]);
diff --git a/object_cache/src/storage/disk.rs b/object_cache/src/storage/disk.rs
index e574bcff..d282a84c 100644
--- a/object_cache/src/storage/disk.rs
+++ b/object_cache/src/storage/disk.rs
@@ -327,10 +327,25 @@ fn read_deltas_sync(delta_dir: &Path, base_version: u64) -> Result<Vec<Object>, FlameError> {
             delta.version = base_version + idx as u64 + 1;
         }
     }
+    validate_delta_versions(&deltas, base_version)?;

     Ok(deltas)
 }

+fn validate_delta_versions(deltas: &[Object], base_version: u64) -> Result<(), FlameError> {
+    let mut previous_version = base_version;
+    for delta in deltas {
+        if delta.version <= previous_version {
+            return Err(FlameError::InvalidState(format!(
+                "Patch versions must be strictly increasing after base version {}; found {} after {}",
+                base_version, delta.version, previous_version
+            )));
+        }
+        previous_version = delta.version;
+    }
+    Ok(())
+}
+
 fn write_batch_to_file(path: &Path, batch: &RecordBatch) -> Result<(), FlameError> {
     let file = fs::File::create(path)?;
     write_batch_to_writer(file, batch)
@@ -484,6 +499,26 @@ mod tests {
         assert_eq!(loaded.deltas[1].version, 7);
     }

+    #[tokio::test]
+    async fn test_disk_storage_rejects_non_monotonic_delta_versions() {
+        let temp_dir = tempdir().unwrap();
+        let storage = DiskStorage::new(temp_dir.path().to_path_buf()).unwrap();
+
+        let key = test_key("test-app", "test-session", "obj1");
+        storage
+            .write_object(&key, &Object::new(5, vec![1, 2, 3]))
+            .await
+            .unwrap();
+        storage
+            .patch_object(&key, &Object::new(4, vec![4]))
+            .await
+            .unwrap();
+
+        let result = storage.read_object(&key).await;
+
+        assert!(result.is_err());
+    }
+
     #[tokio::test]
     async fn test_disk_storage_delete_objects() {
         let temp_dir = tempdir().unwrap();

From 98050a30693c0bc3d625b609a6df664097166374 Mon Sep 17 00:00:00 2001
From: Klaus Ma
Date: Tue, 12 May 2026 10:35:15 +0800
Subject: [PATCH 5/5] docs: add replay buffer perf comparison steps

---
 examples/rl/replay_buffer/README.md | 77 +++++++++++++++++++++++++++--
 1 file changed, 73 insertions(+), 4 deletions(-)

diff --git a/examples/rl/replay_buffer/README.md b/examples/rl/replay_buffer/README.md
index 6fbb38bf..6733ffa3 100644
--- a/examples/rl/replay_buffer/README.md
+++ b/examples/rl/replay_buffer/README.md
@@ -56,10 +56,79 @@ uv run main.py --local
 |------|-------------|---------|
 | `--env` | Gymnasium environment | CartPole-v1 |
 | `--local` | Run without Flame cluster | Off |
-| `--iterations` | Collection iterations | 10 |
-| `--collections` | Collections per iteration | 4 |
-| `--steps-per-collection` | Steps per collection task | 100 |
-| `--batch-size` | Sample batch size | 32 |
+| `--iterations` | Collection iterations | 50 |
+| `--collections` | Collections per iteration | 20 |
+| `--steps-per-collection` | Steps per collection task | 500 |
+| `--batch-size` | Sample batch size | 64 |
+| `--metrics-json` | Write distributed-mode metrics to a JSON file | Off |
+| `--merge-every` | Merge replay-buffer patches every N iterations | 5 |
+| `--no-merge` | Disable patch merging to stress patch-only reads | Off |
+| `--force-full-get` | Force replay-buffer reads to request full objects with version 0 | Off |
+
+## Performance Comparison
+
+Run the baseline and incremental cases on the same cluster, code revision, and workload shape. The baseline forces every replay-buffer read to request version `0`, so the cache returns the full base object plus all patches. The incremental case uses the cached nonzero version after the first read, so the cache can return only new patches when the base is still valid.
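+The only difference between the two runs is the version the reader sends: `--force-full-get` pins the request version to `0`, while the incremental mode reuses the locally cached version after the first read. A hypothetical sketch of that wiring (the real `main.py` plumbing may differ); the walkthrough below then runs both modes back to back:
+
+```python
+def effective_request_version(cached_version: int, force_full_get: bool) -> int:
+    # Version 0 makes the cache stream the base object plus every patch;
+    # a nonzero cached version lets it stream only the new patch suffix.
+    return 0 if force_full_get else cached_version
+
+
+assert effective_request_version(7, force_full_get=True) == 0
+assert effective_request_version(7, force_full_get=False) == 7
+```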
+
+```shell
+docker compose exec -it flame-console /bin/bash
+cd /opt/examples/rl/replay_buffer
+mkdir -p /tmp/replay-buffer-metrics
+
+uv run main.py \
+  --force-full-get \
+  --metrics-json /tmp/replay-buffer-metrics/full.json \
+  --iterations 50 \
+  --collections 20 \
+  --steps-per-collection 500 \
+  --batch-size 64
+
+uv run main.py \
+  --metrics-json /tmp/replay-buffer-metrics/incremental.json \
+  --iterations 50 \
+  --collections 20 \
+  --steps-per-collection 500 \
+  --batch-size 64
+```
+
+Compare total throughput and read-path latency:
+
+```shell
+python - <<'PY'
+import json
+import statistics
+
+full = json.load(open("/tmp/replay-buffer-metrics/full.json"))
+incremental = json.load(open("/tmp/replay-buffer-metrics/incremental.json"))
+
+
+def median_metric(report, key):
+    values = [row[key] for row in report["iterations"] if row[key] > 0]
+    return statistics.median(values) if values else 0.0
+
+
+def pct_improvement(old, new):
+    return 0.0 if old == 0 else (old - new) / old * 100.0
+
+
+for label, report in [("full", full), ("incremental", incremental)]:
+    print(
+        f"{label:12} throughput={report['summary']['throughput']:.1f}/s "
+        f"median_state={median_metric(report, 'state_secs'):.4f}s "
+        f"median_sample={median_metric(report, 'sample_secs'):.4f}s"
+    )
+
+print(
+    "state improvement: "
+    f"{pct_improvement(median_metric(full, 'state_secs'), median_metric(incremental, 'state_secs')):.1f}%"
+)
+print(
+    "sample improvement: "
+    f"{pct_improvement(median_metric(full, 'sample_secs'), median_metric(incremental, 'sample_secs')):.1f}%"
+)
+PY
+```
+
+For a stronger patch-only signal, repeat both runs with `--no-merge`. That keeps a long patch history and should make the forced-full baseline re-download the entire patch history on every read, while the incremental run downloads only the patch suffix after the cached version.

 ## Example Output