Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
603 changes: 603 additions & 0 deletions docs/designs/RFE445-incremental-object-get/FS.md

Large diffs are not rendered by default.

110 changes: 110 additions & 0 deletions e2e/tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import time
import uuid

import flamepy.core.cache as cache_module
import pyarrow.flight as flight
import pytest
from flamepy.core import FlameContext, ObjectRef, get_object, patch_object, put_object, update_object

Expand Down Expand Up @@ -98,6 +100,27 @@ def _raw_deserializer(base, deltas):
return {"base": base, "deltas": deltas}


def _remote_patch_without_local_cache_invalidation(ref: ObjectRef, delta):
    """Append *delta* to the remote object via Flight, bypassing the local cache.

    Emulates a patch issued by a different client process: the server-side
    object advances while this process's cached copy is left untouched.
    """
    serialized = cache_module._serialize_object(delta)
    tls_config = cache_module._get_cache_tls_config()
    flight_client = cache_module._get_flight_client(ref.endpoint, tls_config)
    command = f"PATCH:{ref.key}".encode()
    return cache_module._do_put_remote(
        flight_client, flight.FlightDescriptor.for_command(command), serialized
    )


def _remote_update_without_local_cache_invalidation(ref: ObjectRef, new_obj):
    """Replace the remote object via Flight, bypassing the local cache.

    Emulates a full update issued by a different client process, so this
    process's cached base becomes stale.
    """
    serialized = cache_module._serialize_object(new_obj)
    tls_config = cache_module._get_cache_tls_config()
    flight_client = cache_module._get_flight_client(ref.endpoint, tls_config)
    return cache_module._do_put_remote(
        flight_client, flight.FlightDescriptor.for_path(ref.key), serialized
    )


def _cached_object(ref: ObjectRef):
    """Return this process's cached entry for *ref* (KeyError if not cached)."""
    cache_key = (ref.endpoint, ref.key)
    with cache_module._cache_lock:
        return cache_module._object_cache[cache_key]


def test_patch_single_delta():
"""Test patching an object with a single delta."""
key_prefix = "test-app/test-patch-001"
Expand Down Expand Up @@ -141,6 +164,93 @@ def test_patch_multiple_deltas():
assert result["deltas"][2] == delta3


def test_incremental_get_applies_remote_patch_only_response():
    """A cached client should fetch and apply only the remotely appended patches."""
    key = f"test-app/test-incremental-patch-{uuid.uuid4().hex[:8]}"
    base = {"items": ["base"]}
    first_delta = {"items": ["patch-1"]}
    second_delta = {"items": ["patch-2"]}

    ref = put_object(key, base)
    # Prime the local cache with the base object.
    assert get_object(ref, deserializer=_raw_deserializer) == {"base": base, "deltas": []}

    # Another client appends a patch; our cache is now one version behind.
    after_first = _remote_patch_without_local_cache_invalidation(ref, first_delta)
    fetched = cache_module._fetch_object_data(ref, ref.version)

    assert fetched.mode == cache_module.FetchMode.PATCHES
    assert fetched.version == after_first.version
    assert [p.data for p in fetched.patches] == [first_delta]

    assert get_object(ref, deserializer=_raw_deserializer) == {
        "base": base,
        "deltas": [first_delta],
    }

    entry = _cached_object(ref)
    assert entry.version == after_first.version
    assert [p.data for p in entry.patches] == [first_delta]

    # A second remote patch should yield only the new suffix on refetch.
    after_second = _remote_patch_without_local_cache_invalidation(ref, second_delta)
    refetched = cache_module._fetch_object_data(ref, entry.version)

    assert refetched.mode == cache_module.FetchMode.PATCHES
    assert refetched.version == after_second.version
    assert [p.data for p in refetched.patches] == [second_delta]

    assert get_object(ref, deserializer=_raw_deserializer) == {
        "base": base,
        "deltas": [first_delta, second_delta],
    }

    entry = _cached_object(ref)
    assert entry.version == after_second.version
    assert [p.data for p in entry.patches] == [first_delta, second_delta]


def test_version_zero_forces_full_response_with_cached_object():
    """Requesting version 0 must return the full base plus patches, cache or not."""
    key = f"test-app/test-incremental-full-{uuid.uuid4().hex[:8]}"
    base = {"items": ["base"]}
    delta = {"items": ["patch-1"]}

    ref = put_object(key, base)
    # Prime the local cache, then advance the remote object from "elsewhere".
    assert get_object(ref, deserializer=_raw_deserializer) == {"base": base, "deltas": []}
    after_patch = _remote_patch_without_local_cache_invalidation(ref, delta)

    # version=0 is the "send everything" sentinel: expect a FULL response.
    zero_ref = ObjectRef(endpoint=ref.endpoint, key=ref.key, version=0)
    fetched = cache_module._fetch_object_data(zero_ref, 0)

    assert fetched.mode == cache_module.FetchMode.FULL
    assert fetched.version == after_patch.version
    assert fetched.base == base
    assert [p.data for p in fetched.patches] == [delta]

    assert get_object(zero_ref, deserializer=_raw_deserializer) == {
        "base": base,
        "deltas": [delta],
    }


def test_incremental_get_falls_back_to_full_after_remote_update():
    """A remote full update must replace the stale cached base via a FULL response."""
    key = f"test-app/test-incremental-update-{uuid.uuid4().hex[:8]}"
    first_base = {"version": 1}
    second_base = {"version": 2}

    ref = put_object(key, first_base)
    assert get_object(ref, deserializer=_raw_deserializer) == {"base": first_base, "deltas": []}

    # Another client overwrites the whole object; our cached base is now invalid.
    after_update = _remote_update_without_local_cache_invalidation(ref, second_base)
    fetched = cache_module._fetch_object_data(ref, ref.version)

    assert fetched.mode == cache_module.FetchMode.FULL
    assert fetched.version == after_update.version
    assert fetched.base == second_base
    assert fetched.patches == []

    assert get_object(ref, deserializer=_raw_deserializer) == {
        "base": second_base,
        "deltas": [],
    }

    entry = _cached_object(ref)
    assert entry.version == after_update.version
    assert entry.data == second_base
    assert entry.patches == []


def test_patch_preserves_delta_order():
"""Test that deltas are returned in the order they were appended."""
key_prefix = "test-app/test-patch-003"
Expand Down
77 changes: 73 additions & 4 deletions examples/rl/replay_buffer/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,79 @@ uv run main.py --local
|------|-------------|---------|
| `--env` | Gymnasium environment | CartPole-v1 |
| `--local` | Run without Flame cluster | Off |
| `--iterations` | Collection iterations | 10 |
| `--collections` | Collections per iteration | 4 |
| `--steps-per-collection` | Steps per collection task | 100 |
| `--batch-size` | Sample batch size | 32 |
| `--iterations` | Collection iterations | 50 |
| `--collections` | Collections per iteration | 20 |
| `--steps-per-collection` | Steps per collection task | 500 |
| `--batch-size` | Sample batch size | 64 |
| `--metrics-json` | Write distributed-mode metrics to a JSON file | Off |
| `--merge-every` | Merge replay-buffer patches every N iterations | 5 |
| `--no-merge` | Disable patch merging to stress patch-only reads | Off |
| `--force-full-get` | Force replay-buffer reads to request full objects with version 0 | Off |

## Performance Comparison

Run the baseline and incremental cases on the same cluster, code revision, and workload shape. The baseline forces every replay-buffer read to request version `0`, so the cache returns the full base object plus all patches. The incremental case uses the cached nonzero version after the first read, so the cache can return only new patches when the base is still valid.

```shell
docker compose exec -it flame-console /bin/bash
cd /opt/examples/rl/replay_buffer
mkdir -p /tmp/replay-buffer-metrics

uv run main.py \
--force-full-get \
--metrics-json /tmp/replay-buffer-metrics/full.json \
--iterations 50 \
--collections 20 \
--steps-per-collection 500 \
--batch-size 64

uv run main.py \
--metrics-json /tmp/replay-buffer-metrics/incremental.json \
--iterations 50 \
--collections 20 \
--steps-per-collection 500 \
--batch-size 64
```

Compare total throughput and read-path latency:

```shell
python - <<'PY'
import json
import statistics

full = json.load(open("/tmp/replay-buffer-metrics/full.json"))
incremental = json.load(open("/tmp/replay-buffer-metrics/incremental.json"))


def median_metric(report, key):
    """Median of the positive per-iteration values for *key*; 0.0 if none are positive."""
    positive = [it[key] for it in report["iterations"] if it[key] > 0]
    if not positive:
        return 0.0
    return statistics.median(positive)


def pct_improvement(old, new):
    """Percent reduction from *old* to *new*; 0.0 when *old* is zero (avoid div-by-zero)."""
    if old == 0:
        return 0.0
    return 100.0 * (old - new) / old


for label, report in [("full", full), ("incremental", incremental)]:
print(
f"{label:12} throughput={report['summary']['throughput']:.1f}/s "
f"median_state={median_metric(report, 'state_secs'):.4f}s "
f"median_sample={median_metric(report, 'sample_secs'):.4f}s"
)

print(
"state improvement: "
f"{pct_improvement(median_metric(full, 'state_secs'), median_metric(incremental, 'state_secs')):.1f}%"
)
print(
"sample improvement: "
f"{pct_improvement(median_metric(full, 'sample_secs'), median_metric(incremental, 'sample_secs')):.1f}%"
)
PY
```

For a stronger patch-only signal, repeat both runs with `--no-merge`. That keeps a long patch history, so the forced-full baseline must re-download the entire accumulated patch history on every read, while the incremental run fetches only the patches appended after its cached version.

## Example Output

Expand Down
93 changes: 90 additions & 3 deletions examples/rl/replay_buffer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ def run_distributed(
num_collections: int = 20,
steps_per_collection: int = 500,
batch_size: int = 64,
merge_every: int | None = 5,
metrics_json: str | None = None,
force_full_get: bool = False,
):
import json
import time

from flamepy.runner import Runner
Expand All @@ -31,30 +35,43 @@ def run_distributed(
print(f" Steps per collection: {steps_per_collection}")
print(f" Iterations: {num_iterations}")
print(f" Batch size: {batch_size}")
print(f" Merge every: {merge_every if merge_every else 'disabled'}")
print(f" Force full get: {force_full_get}")
print("\nStarting distributed collection...")

start_time = time.time()
metrics = []
total_added = 0

with Runner(f"replay-buffer-{env_name.lower()}") as rr:
buffer = ReplayBuffer(rr)
buffer = ReplayBuffer(rr, force_full_get=force_full_get)
buffer_svc = rr.service(buffer, autoscale=False, warmup=1)
collector = rr.service(Collector(env_name), autoscale=True)

for iteration in range(num_iterations):
iteration_start = time.time()
collect_futures = [
collector.collect(buffer, steps_per_collection)
for _ in range(num_collections)
]
collect_results = rr.get(collect_futures)
collect_elapsed = time.time() - iteration_start

if iteration % 5 == 4:
merge_elapsed = 0.0
if merge_every and iteration % merge_every == merge_every - 1:
merge_start = time.time()
buffer_svc.merge().wait()
merge_elapsed = time.time() - merge_start

state_start = time.time()
stats = buffer_svc.state().get()
state_elapsed = time.time() - state_start
total_size = stats["size"]
total_added = stats["total_added"]
total_episodes = sum(r["episode_count"] for r in collect_results)
avg_reward = sum(r["avg_reward"] * r["episode_count"] for r in collect_results) / max(1, total_episodes)
avg_reward = sum(
r["avg_reward"] * r["episode_count"] for r in collect_results
) / max(1, total_episodes)

print(
f"Iteration {iteration:2d} | "
Expand All @@ -63,10 +80,30 @@ def run_distributed(
f"Avg Reward: {avg_reward:7.1f}"
)

sample_elapsed = 0.0
sampled = 0
if total_size >= batch_size:
sample_start = time.time()
batch = buffer_svc.sample(batch_size).get()
sample_elapsed = time.time() - sample_start
sampled = len(batch)
print(f" | Sampled batch of {len(batch)} transitions")

metrics.append(
{
"iteration": iteration,
"collect_secs": collect_elapsed,
"merge_secs": merge_elapsed,
"state_secs": state_elapsed,
"sample_secs": sample_elapsed,
"buffer_size": total_size,
"total_added": total_added,
"total_episodes": total_episodes,
"avg_reward": avg_reward,
"sampled": sampled,
}
)

elapsed = time.time() - start_time
print("\n" + "=" * 60)
print("Collection Complete!")
Expand All @@ -75,6 +112,30 @@ def run_distributed(
print(f" Throughput: {total_added / elapsed:.1f} transitions/sec")
print("=" * 60)

if metrics_json:
with open(metrics_json, "w") as f:
json.dump(
{
"configuration": {
"env": env_name,
"iterations": num_iterations,
"collections": num_collections,
"steps_per_collection": steps_per_collection,
"batch_size": batch_size,
"merge_every": merge_every,
"force_full_get": force_full_get,
},
"summary": {
"total_time_secs": elapsed,
"total_transitions": total_added,
"throughput": total_added / elapsed,
},
"iterations": metrics,
},
f,
indent=2,
)


def run_local(
env_name: str = "CartPole-v1",
Expand Down Expand Up @@ -187,8 +248,31 @@ def main():
parser.add_argument(
"--batch-size", type=int, default=64, help="Batch size for sampling"
)
parser.add_argument(
"--metrics-json",
type=str,
default=None,
help="Write distributed-mode metrics to a JSON file",
)
parser.add_argument(
"--merge-every",
type=int,
default=5,
help="Merge replay-buffer patches every N iterations",
)
parser.add_argument(
"--no-merge",
action="store_true",
help="Disable replay-buffer patch merging",
)
parser.add_argument(
"--force-full-get",
action="store_true",
help="Force replay-buffer reads to request full objects with version 0",
)

args = parser.parse_args()
merge_every = None if args.no_merge else args.merge_every

if args.local:
run_local(
Expand All @@ -204,6 +288,9 @@ def main():
num_collections=args.collections,
steps_per_collection=args.steps_per_collection,
batch_size=args.batch_size,
merge_every=merge_every,
metrics_json=args.metrics_json,
force_full_get=args.force_full_get,
)


Expand Down
Loading
Loading