
[Data] speedup checkpoint filter 5x #60002

Closed
wxwmd wants to merge 1 commit into ray-project:master from wxwmd:speedup_ckpt_filter

Conversation

@wxwmd
Contributor

@wxwmd wxwmd commented Jan 9, 2026

Modification

I'm using Ray Data's checkpoint feature. My data has 115 million records, with primary key `{"id": str}`. When I use the checkpoint to filter the input blocks, it takes several hours.

I profiled the job and found the bottleneck in the `filter_with_ckpt_chunk` function in checkpoint_filter.py. I added some logs:

# Get all chunks of the checkpointed ID column.
ckpt_chunks = checkpointed_ids[self.id_column].chunks
# Convert the block's ID column to a numpy array for fast processing.
block_ids = block[self.id_column].to_numpy()

# Note: each element of `.chunks` is a pyarrow.Array, not a ChunkedArray.
def filter_with_ckpt_chunk(ckpt_chunk: pyarrow.Array) -> numpy.ndarray:
    t1 = time.time()
    ckpt_ids = transform_pyarrow.to_numpy(ckpt_chunk, zero_copy_only=False)
    print(f"ckpt_ids to numpy cost time {time.time() - t1}s")

    ...
    t2 = time.time()
    sorted_indices = numpy.searchsorted(ckpt_ids, block_ids)
    print(f"searchsorted costs {time.time() - t2}s")

The `ckpt_chunk` has shape `(115022113,)`, and `block_ids` has shape `(14534,)`. I got:

ckpt_ids to numpy cost time: 6.057122468948364s
searchsorted costs 0.11587834358215332s

We can see from the measurements that:

  1. `ckpt_chunks` has only one chunk, because the chunks were already combined by `_combine_chunks`.
  2. that single chunk is very large, holding 115 million IDs, and converting it from pyarrow to numpy costs about 6s.
  3. for every input block, `ckpt_ids = transform_pyarrow.to_numpy(ckpt_chunk, zero_copy_only=False)` is executed once, causing a large cumulative time overhead.

This PR obtains the `ckpt_ids` numpy array once in advance, avoiding the repeated conversions. In my tests, this reduces the filtering time from 5 hours to 40 minutes.
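To illustrate the idea, here is a minimal numpy-only sketch (not the PR's actual code; `build_filter` and `filter_block` are hypothetical names): the Arrow-to-NumPy conversion is paid once, and every block reuses the resulting array.

```python
import numpy as np

# Hypothetical sketch of the PR's idea: pay the Arrow -> NumPy conversion
# once, then reuse the resulting array for every input block.
def build_filter(ckpt_ids_np: np.ndarray):
    # ckpt_ids_np must already be sorted, as searchsorted assumes.
    def filter_block(block_ids: np.ndarray) -> np.ndarray:
        # Find candidate insertion points, then check for exact matches.
        idx = np.searchsorted(ckpt_ids_np, block_ids)
        idx = np.clip(idx, 0, len(ckpt_ids_np) - 1)
        seen = ckpt_ids_np[idx] == block_ids
        # Keep only rows whose IDs are NOT in the checkpoint.
        return block_ids[~seen]

    return filter_block

ckpt = np.array(["0", "1", "2", "3"])  # converted to numpy exactly once
f = build_filter(ckpt)
remaining = f(np.array(["2", "5", "9"]))  # only "5" and "9" survive
```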

Notes:

In this PR, each read task reads the `ckpt_ids` (`numpy.ndarray`) from the object store instead of the Arrow-format data. This increases I/O and memory overhead, because Arrow arrays usually take less space: in my experiment, the pyarrow array (115 million rows, string-typed) used 1.7 GB of memory, while the numpy array used 9 GB. However, I think this memory overhead is acceptable given the performance improvement.

@wxwmd wxwmd requested a review from a team as a code owner January 9, 2026 09:15
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant performance optimization for checkpoint filtering by converting checkpointed IDs to a NumPy array once, rather than for every block. The changes are well-implemented and consistent across the modified files. My review includes a couple of suggestions to enhance code clarity and maintainability.

Comment thread python/ray/data/checkpoint/checkpoint_filter.py Outdated
Comment on lines +689 to +695
combined_ckpt_block = transform_pyarrow.combine_chunks(pyarrow_checkpointed_ids)

combine_ckpt_chunks = combined_ckpt_block[ID_COL].chunks
assert len(combine_ckpt_chunks) == 1
# Convert checkpoint chunk to numpy for fast search.
# Use internal helper function for consistency and robustness (handles null-typed arrays, etc.)
ckpt_ids = transform_pyarrow.to_numpy(combine_ckpt_chunks[0], zero_copy_only=False)

Severity: medium

This logic for converting a pyarrow Table to a numpy array of IDs is duplicated from _combine_chunks in checkpoint_filter.py. To improve maintainability, consider extracting this logic into a non-remote helper function in checkpoint_filter.py and calling it from both _combine_chunks and this test. This would avoid having to update the logic in two places if it ever changes.

@wxwmd wxwmd changed the title [Data] speedup ckpt filter 5x [Data] speedup checkpoint filter 5x Jan 9, 2026
Comment thread python/ray/data/checkpoint/checkpoint_filter.py Outdated
Comment thread python/ray/data/checkpoint/checkpoint_filter.py Outdated
@ray-gardener ray-gardener bot added data Ray Data-related issues community-contribution Contributed by the community labels Jan 9, 2026
@owenowenisme owenowenisme self-assigned this Jan 10, 2026
Comment thread python/ray/data/checkpoint/checkpoint_filter.py Outdated
Comment thread python/ray/data/checkpoint/util.py Outdated
@wingkitlee0
Contributor

This is nice. Some optimizations can be considered for future PRs:

  • It may be worth sorting `block_ids` before performing `searchsorted(checkpoint_ids, block_ids)`; there are some numpy-internal optimizations for sorted queries. We may want to restore the original order for the output, though.
  • The industry-standard sortedcontainers library uses a list of lists (i.e., chunking). We may be able to do something similar: chunk the long array into multiple shorter ones (<1M elements) so that each fits in cache individually.
  • Related to the second point, partitioning may help to avoid `repartition(1)` when loading the checkpoint (I haven't read through how the checkpoint is constructed yet, but `repartition(1)` seems heavy if the pipeline almost finishes).

> filtering time from 5 hours to 40 minutes.

Just to understand better, this is the total time spent in the filter function, right?
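The first suggestion (sort the queries, search, then restore the original order) could look like the following illustrative numpy-only sketch; `membership_mask` is a hypothetical name, not code from the PR:

```python
import numpy as np

# Sort block_ids before searchsorted (sorted queries have a friendlier
# memory-access pattern), then scatter the result back to original order.
def membership_mask(ckpt_ids: np.ndarray, block_ids: np.ndarray) -> np.ndarray:
    order = np.argsort(block_ids, kind="stable")
    sorted_ids = block_ids[order]
    idx = np.searchsorted(ckpt_ids, sorted_ids)
    idx = np.clip(idx, 0, len(ckpt_ids) - 1)
    mask_sorted = ckpt_ids[idx] == sorted_ids
    # Undo the sort so the mask lines up with the original rows.
    mask = np.empty_like(mask_sorted)
    mask[order] = mask_sorted
    return mask

ckpt = np.array(["00", "03", "07"])  # sorted checkpoint IDs
mask = membership_mask(ckpt, np.array(["07", "01", "00"]))
```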

@wxwmd
Contributor Author

wxwmd commented Jan 12, 2026

> This is nice. Some optimizations can be considered for future PRs:
>
> • it may be worth sorting the block_ids when performing searchsorted(checkpoint_ids, block_ids). There are some numpy internal optimization. We may want to use the original order for output tho.
> • The industry-standard sortedcontainers library uses a list of list (i.e., chunking). We may be able to do something similar: chunking the long array into multiple shorter ones (<1M elements), so that they all fit in cache (individually).
> • related to the second point, partitioning may help to avoid repartition(1) when loading the checkpoint (I haven't read thru how the checkpoint is constructed yet, but repartition(1) seems heavy if the pipeline almost finishes..)
>
> filtering time from 5 hours to 40 minutes.
>
> Just to understand better, this is the total time spent in the filter function, right?

@wingkitlee0 yes, this is the total time spent in the filter function. This PR addresses the time overhead caused by repeated pyarrow-to-numpy copies. After this, I believe the points you mentioned can further improve performance, and I'm interested in implementing them.

Comment thread python/ray/data/checkpoint/checkpoint_filter.py Outdated
@wxwmd wxwmd force-pushed the speedup_ckpt_filter branch from c6d23db to 0fb1a5d Compare January 12, 2026 07:17
Comment thread python/ray/data/checkpoint/checkpoint_filter.py
Comment thread python/ray/data/checkpoint/checkpoint_filter.py Outdated
Comment thread python/ray/data/checkpoint/checkpoint_filter.py Outdated
Comment thread python/ray/data/_internal/planner/planner.py Outdated
Comment thread python/ray/data/_internal/planner/planner.py Outdated
@wxwmd wxwmd force-pushed the speedup_ckpt_filter branch from 8282751 to 8be95cf Compare January 12, 2026 12:14
Comment thread python/ray/data/checkpoint/checkpoint_filter.py
Comment thread python/ray/data/checkpoint/load_checkpoint_callback.py Outdated
@wxwmd
Contributor Author

wxwmd commented Jan 13, 2026

seems kind of messy, I will split this into 3 PRs

@wxwmd wxwmd force-pushed the speedup_ckpt_filter branch 2 times, most recently from c4eab5f to 0e213dc Compare January 13, 2026 11:05
@wxwmd wxwmd marked this pull request as draft January 14, 2026 07:35
Signed-off-by: xiaowen.wxw <wxw403883@alibaba-inc.com>

keep this pr simple
Signed-off-by: xiaowen.wxw <wxw403883@alibaba-inc.com>
@wxwmd
Contributor Author

wxwmd commented Jan 19, 2026

moved to #60294

richardliaw pushed a commit that referenced this pull request Apr 16, 2026
> source code for issue
https://github.com/issues/created?issue=ray-project%7Cray%7C60200

### Current checkpoint:

<img width="2796" height="1198" alt="Image"
src="https://github.com/user-attachments/assets/528ad72b-6975-4e96-8f01-39e373990647"
/>

The current implementation has two issues:
1. Each ReadTask copies an Arrow-typed checkpoint_id array and then
converts it into a NumPy-typed array. This step is very time-consuming
(see [previous testing](#60002)), and it is repeated in every ReadTask.
2. Each ReadTask holds a copy of the checkpoint_id array, resulting in
high memory usage across the cluster.


### Improved Checkpoint (Initial design, single actor):
Maintain a global `checkpoint_filter` actor that holds the
`checkpoint_ids` array; this actor is responsible for filtering all
input blocks.

<img width="2096" height="1278" alt="Image"
src="https://github.com/user-attachments/assets/b9956eff-c807-45c4-bc4c-f0497974370d"
/>

There are two advantages to this approach:
1. The most time-consuming operation, the conversion from the Arrow-typed
array to the NumPy-typed array, is performed only once.
2. Reduced memory usage: each read task no longer needs to hold a large
array; only the `checkpoint_filter` actor holds it.
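A plain-Python sketch of that design (illustrative only: in real Ray code the class would be decorated with `@ray.remote` and invoked via `.remote()`, and the names here are hypothetical):

```python
import numpy as np

# The actor pays the conversion and sort cost once at construction;
# every subsequent filter call is just a binary search.
class CheckpointFilterActor:
    def __init__(self, ckpt_ids: np.ndarray):
        self.ckpt_ids = np.sort(ckpt_ids)

    def filter(self, block_ids: np.ndarray) -> np.ndarray:
        idx = np.searchsorted(self.ckpt_ids, block_ids)
        idx = np.clip(idx, 0, len(self.ckpt_ids) - 1)
        seen = self.ckpt_ids[idx] == block_ids
        return block_ids[~seen]  # drop rows already checkpointed

actor = CheckpointFilterActor(np.array(["3", "1", "2"]))
remaining = actor.filter(np.array(["2", "4"]))  # only "4" survives
```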

### Performance test
test code:
```python
import shutil
from typing import Dict
import os
import time

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import ray
from ray.data.checkpoint import CheckpointConfig

INPUT_PATH="/tmp/ray_test/input/"
OUTPUT_PATH="/tmp/ray_test/output/"
CKPT_PATH="/tmp/ray_test/ckpt/"


class Qwen3ASRPredictor:
    def __init__(self):
        print("download ckpt")

    def __call__(self, batch_input: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        return batch_input

def setup():
    if os.path.exists(INPUT_PATH):
        shutil.rmtree(INPUT_PATH)
    if os.path.exists(CKPT_PATH):
        shutil.rmtree(CKPT_PATH)
    if os.path.exists(OUTPUT_PATH):
        shutil.rmtree(OUTPUT_PATH)

    # generate input data
    if not os.path.exists(INPUT_PATH):
        os.makedirs(INPUT_PATH)
    for i in range(10000):
        ids = [str(i) for i in range(i * 10000, (i + 1) * 10000)]
        df = pd.DataFrame({'id': ids})
        table = pa.Table.from_pandas(df)
        pq.write_table(table, os.path.join(INPUT_PATH, f"{i}.parquet"))

    # generate checkpoint
    if not os.path.exists(CKPT_PATH):
        os.makedirs(CKPT_PATH)
    ids = [str(i) for i in range(0, 80_000_000)]
    df = pd.DataFrame({'id': ids})
    table = pa.Table.from_pandas(df)
    pq.write_table(table, os.path.join(CKPT_PATH, "ckpt.parquet"))



if __name__ == "__main__":
    ray.init()

    setup()

    ctx = ray.data.DataContext.get_current()
    ctx.checkpoint_config = CheckpointConfig(
        id_column="id",
        checkpoint_path=CKPT_PATH,
        delete_checkpoint_on_success=False,
    )

    start_time = time.time()

    input = ray.data.read_parquet(
        INPUT_PATH,
        parallelism=1000,
        # memory=8 * 1024 **3 # set for origin ray to avoid oom
    )

    pred = input.map_batches(Qwen3ASRPredictor, batch_size=1000)

    pred.write_parquet(OUTPUT_PATH)

    end_time = time.time()

    print(f"costs: {end_time - start_time}s")

    # check result
    result_ds = ray.data.read_parquet(OUTPUT_PATH)
    assert result_ds.count() == 20_000_000
```

Node: 16 cores with 64GB memory (make sure you have at least 16GB of free
memory to avoid OOM).

#### origin ray:
```
pip install ray==2.54.0
python test.py
```
#### Speedup:
```
pip install https://ray-wheel.oss-cn-beijing.aliyuncs.com/speedup/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl
python test.py
```

#### Test Result
origin: 680s
speedup: 190s
The end-to-end running time of the whole job was accelerated by about 3.6x.

#### Memory
If we delete this line:
```
memory=8 * 1024 **3 # set for origin ray to avoid oom
```
the original Ray OOMs, while the fixed Ray passes. This demonstrates that
the PR also improves stability.


----
### Updated 20260225 (ActorPool)
As @owenowenisme mentioned, if filtering is performed by a single actor,
that actor could become the bottleneck. Therefore, I extended the single
actor into an ActorPool. For more details, please refer to
#60294 (comment)


![20260223162250](https://github.com/user-attachments/assets/9bd1067f-f2a8-47dd-8f99-e232be64155e)
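A rough plain-Python stand-in for the pool design (illustrative; real code would use Ray actors, e.g. `ray.util.ActorPool`, and the class names here are hypothetical). Each pool member holds the sorted ID array, and blocks are dispatched round-robin so no single worker becomes the bottleneck:

```python
import numpy as np

class FilterWorker:
    """One pool member; holds its own sorted copy of the checkpoint IDs."""

    def __init__(self, ckpt_ids: np.ndarray):
        self.ckpt_ids = np.sort(ckpt_ids)

    def filter(self, block_ids: np.ndarray) -> np.ndarray:
        idx = np.clip(np.searchsorted(self.ckpt_ids, block_ids),
                      0, len(self.ckpt_ids) - 1)
        return block_ids[self.ckpt_ids[idx] != block_ids]

class FilterPool:
    """Round-robin dispatcher over a fixed set of workers."""

    def __init__(self, ckpt_ids: np.ndarray, size: int = 4):
        self.workers = [FilterWorker(ckpt_ids) for _ in range(size)]
        self.next_worker = 0

    def filter(self, block_ids: np.ndarray) -> np.ndarray:
        worker = self.workers[self.next_worker % len(self.workers)]
        self.next_worker += 1
        return worker.filter(block_ids)

pool = FilterPool(np.array(["1", "2"]), size=2)
out = pool.filter(np.array(["2", "3"]))  # only "3" survives
```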

---------

Co-authored-by: xiaowen.wxw <wxw403883@alibaba-inc.com>
Co-authored-by: You-Cheng Lin <106612301+owenowenisme@users.noreply.github.com>
