[Data] Speedup checkpoint filter 5x #60002
Conversation
Code Review
This pull request introduces a significant performance optimization for checkpoint filtering by converting checkpointed IDs to a NumPy array once, rather than for every block. The changes are well-implemented and consistent across the modified files. My review includes a couple of suggestions to enhance code clarity and maintainability.
```python
combined_ckpt_block = transform_pyarrow.combine_chunks(pyarrow_checkpointed_ids)
combine_ckpt_chunks = combined_ckpt_block[ID_COL].chunks
assert len(combine_ckpt_chunks) == 1
# Convert checkpoint chunk to numpy for fast search.
# Use internal helper function for consistency and robustness
# (handles null-typed arrays, etc.)
ckpt_ids = transform_pyarrow.to_numpy(combine_ckpt_chunks[0], zero_copy_only=False)
```
This logic for converting a pyarrow `Table` to a numpy array of IDs is duplicated from `_combine_chunks` in `checkpoint_filter.py`. To improve maintainability, consider extracting this logic into a non-remote helper function in `checkpoint_filter.py` and calling it from both `_combine_chunks` and this test. This would avoid having to update the logic in two places if it ever changes.
This is nice. Some optimizations can be considered for future PRs:

Just to understand better, this is the total time spent in the filter function, right?
@wingkitlee0 yes, this is the total time spent in the filter function. This PR addresses the time overhead caused by the repeated Arrow-to-NumPy copies of the checkpoint IDs in each read task.
Seems kind of messy; I will split this into 3 PRs.
Signed-off-by: xiaowen.wxw <wxw403883@alibaba-inc.com>
Moved to #60294.
> Source code for issue https://github.com/issues/created?issue=ray-project%7Cray%7C60200

### Current checkpoint:

<img width="2796" height="1198" alt="Image" src="https://github.com/user-attachments/assets/528ad72b-6975-4e96-8f01-39e373990647" />

The current implementation has two issues:

1. Each `ReadTask` copies an Arrow-typed `checkpoint_id` array and then converts it into a NumPy-typed array. This step is very time-consuming (see [previous testing](#60002)), and this most expensive operation is repeated in every `ReadTask`.
2. Each `ReadTask` holds a copy of the `checkpoint_id` array, resulting in high memory usage across the cluster.

### Improved checkpoint (initial design, single actor):

Maintain a global `checkpoint_filter` actor that holds the `checkpoint_ids` array; this actor is responsible for filtering all input blocks.

<img width="2096" height="1278" alt="Image" src="https://github.com/user-attachments/assets/b9956eff-c807-45c4-bc4c-f0497974370d" />

There are two advantages to this approach:

1. The most time-consuming operation, the conversion from an Arrow-typed array to a NumPy-typed array, is performed only once.
2. Reduced memory usage: each read task no longer needs to hold a large array; only the `checkpoint_filter` actor holds it.
### Performance test

Test code:

```python
import os
import shutil
import time
from typing import Dict

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

import ray
from ray.data.checkpoint import CheckpointConfig

INPUT_PATH = "/tmp/ray_test/input/"
OUTPUT_PATH = "/tmp/ray_test/output/"
CKPT_PATH = "/tmp/ray_test/ckpt/"


class Qwen3ASRPredictor:
    def __init__(self):
        print("download ckpt")

    def __call__(self, batch_input: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        return batch_input


def setup():
    if os.path.exists(INPUT_PATH):
        shutil.rmtree(INPUT_PATH)
    if os.path.exists(CKPT_PATH):
        shutil.rmtree(CKPT_PATH)
    if os.path.exists(OUTPUT_PATH):
        shutil.rmtree(OUTPUT_PATH)

    # generate input data
    if not os.path.exists(INPUT_PATH):
        os.makedirs(INPUT_PATH)
        for i in range(10000):
            ids = [str(i) for i in range(i * 10000, (i + 1) * 10000)]
            df = pd.DataFrame({"id": ids})
            table = pa.Table.from_pandas(df)
            pq.write_table(table, os.path.join(INPUT_PATH, f"{i}.parquet"))

    # generate checkpoint
    if not os.path.exists(CKPT_PATH):
        os.makedirs(CKPT_PATH)
        ids = [str(i) for i in range(0, 80_000_000)]
        df = pd.DataFrame({"id": ids})
        table = pa.Table.from_pandas(df)
        pq.write_table(table, os.path.join(CKPT_PATH, "ckpt.parquet"))


if __name__ == "__main__":
    ray.init()
    setup()

    ctx = ray.data.DataContext.get_current()
    ctx.checkpoint_config = CheckpointConfig(
        id_column="id",
        checkpoint_path=CKPT_PATH,
        delete_checkpoint_on_success=False,
    )

    start_time = time.time()
    input = ray.data.read_parquet(
        INPUT_PATH,
        parallelism=1000,
        # memory=8 * 1024 ** 3  # set for origin ray to avoid oom
    )
    pred = input.map_batches(Qwen3ASRPredictor, batch_size=1000)
    pred.write_parquet(OUTPUT_PATH)
    end_time = time.time()
    print(f"costs: {end_time - start_time}s")

    # check result
    result_ds = ray.data.read_parquet(OUTPUT_PATH)
    assert result_ds.count() == 20_000_000
```

Node: 16 cores with 64 GB memory (make sure you have at least 16 GB of free memory to avoid OOM).

#### Origin Ray:

```
pip install ray==2.54.0
python test.py
```

#### Speedup:

```
pip install https://ray-wheel.oss-cn-beijing.aliyuncs.com/speedup/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl
python test.py
```

#### Test result

- origin: 680 s
- speedup: 190 s

Even the end-to-end running time of the task has been accelerated, by 3.6x.

#### Memory

If we delete this row:

```
memory=8 * 1024 ** 3  # set for origin ray to avoid oom
```

the original Ray will OOM, while the fixed Ray passes. This demonstrates that this PR also improves stability.

----

### Updated 20260225 (ActorPool)

As @owenowenisme mentioned, if filtering is performed by a single actor, that actor could become the bottleneck. Therefore, I extended the single actor into an ActorPool. For more details, please refer to #60294 (comment).

Co-authored-by: xiaowen.wxw <wxw403883@alibaba-inc.com>
Co-authored-by: You-Cheng Lin <106612301+owenowenisme@users.noreply.github.com>
Modification
I'm using Ray Data's checkpoint. My data has 115 million records, with primary key `{"id": str}`. When I use Checkpoint to filter the input blocks, it takes several hours.

I checked the performance bottleneck and found it occurs in the `filter_with_ckpt_chunk` function in `checkpoint_filter.py`. I added some logs: `ckpt_chunk` has shape (115022113,), and `block_ids` has shape (14534,). We can see from the perf test that:

- `ckpt_chunks` has only one chunk, because we have already combined chunks in `_combine_chunks`.
- `ckpt_chunk` is a very large chunk that holds 115 million ids; converting it from pyarrow to numpy costs about 6 s.
- `ckpt_ids = transform_pyarrow.to_numpy(ckpt_chunk, zero_copy_only=False)` is executed once per filtered block, causing a large time overhead.

This PR obtains the `ckpt_ids` numpy array in advance, avoiding the repeated conversion. In my tests, this reduces the filtering time from 5 hours to 40 minutes.

Notes: In this PR, each read task needs to read `ckpt_ids` (a `numpy.ndarray`) from the object store, rather than the Arrow format. This increases I/O and memory overhead, because Arrow arrays usually take less space. In my experiment, the pyarrow array (115 million rows, string-typed) used 1.7 GB of memory, while the numpy array used 9 GB. However, I think this memory overhead is acceptable given the performance improvement.