# Batch Inference with Ray Data
© 2025, Anyscale. All Rights Reserved

💻 **Launch Locally**: You can run this notebook locally.

🚀 **Launch on Cloud**: Consider running this notebook on a Ray cluster (click [here](http://console.anyscale.com/register) to start a Ray cluster on Anyscale).

This example shows how to do batch inference with Ray Data.

Batch inference with Ray Data enables you to efficiently generate predictions from machine learning models on large datasets by processing multiple data points at once. Instead of running inference on one row at a time, which can be slow and resource-inefficient, batch inference leverages vectorized computation and parallelism to maximize throughput. This is especially useful when working with modern deep learning models, which are optimized for batch processing on CPUs, GPUs, or Apple Silicon devices.
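To make the row-at-a-time versus batch contrast concrete, here is a minimal, stdlib-only sketch (not Ray code). `CountingModel` and `iter_batches` are hypothetical names: the toy model only counts how often it is invoked, which stands in for fixed per-call overhead such as a GPU kernel launch or a Python-to-framework round trip.

```python
from typing import Iterator, List, Sequence


def iter_batches(rows: Sequence[str], batch_size: int) -> Iterator[List[str]]:
    """Yield consecutive slices of at most batch_size rows."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    for start in range(0, len(rows), batch_size):
        yield list(rows[start:start + batch_size])


class CountingModel:
    """Toy stand-in for a model that counts how often it is invoked."""

    def __init__(self) -> None:
        self.calls = 0

    def predict(self, inputs: List[str]) -> List[int]:
        self.calls += 1  # each call stands in for fixed per-call overhead
        return [len(text) for text in inputs]  # toy "prediction": text length


rows = [f"tweet {i}" for i in range(1000)]

row_model = CountingModel()
row_preds = [row_model.predict([row])[0] for row in rows]  # one call per row

batch_model = CountingModel()
batch_preds = [p for batch in iter_batches(rows, 64) for p in batch_model.predict(batch)]

print(row_model.calls, batch_model.calls)  # 1000 16
```

Both approaches produce identical predictions; batching simply pays the per-call cost 16 times instead of 1,000 times.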

The typical workflow begins by loading your dataset, such as a public dataset from Hugging Face, into a Ray Dataset. Ray Data can automatically partition the data for parallel processing, or you can repartition it explicitly to control the number of data blocks. Once the data is loaded, you define a callable class (such as a text embedder) that loads the machine learning model in its constructor and implements a `__call__` method to process each batch. You then use Ray Data’s `map_batches` API to apply this callable to each batch of data, with options to control concurrency and resource allocation (e.g., number of GPUs).

This approach allows you to spin up multiple concurrent model instances, each processing different batches of data in parallel. The result is a significant speedup in inference time, especially for large datasets. After inference, you can materialize the results, inspect the output, and shut down the Ray cluster to free up resources. Batch inference with Ray Data is scalable, flexible, and integrates seamlessly with modern ML workflows, making it a powerful tool for production and research environments alike.

### Outline
<div class="alert alert-block alert-info">
<b>In this notebook, we go through a typical ML batch inference workflow:</b>

<ul>
    <li>Architecture
    <li>Import Libraries
    <li>Load a public dataset from Hugging Face and convert it into a Ray Dataset.
    <li>Batch Inference Class: create a callable class, run by Ray Data as actors, that loads an ML model. In this example, we use the SentenceTransformer library to load a sentence embedding model from Hugging Face.
    <li>Create batches of data and run inference.
    <li>Deploying at Scale
    <li>Inference on the entire dataset
    <li>Out of memory errors
    <li>Summary
</ul>
</div>

## Architecture

![Architecture Diagram](https://lz-public-demo.s3.us-east-1.amazonaws.com/anyscale101/01_examples/01_Ray_Data_batch_inf_arch.svg?sanitize=true)

### Import Libraries

```python
import ray
import torch
from typing import Dict
import numpy as np
```

```python
from sentence_transformers import SentenceTransformer  # Hugging Face sentence-transformers
from datasets import load_dataset  # Hugging Face datasets
```

## Load a dataset
Load a dataset from Hugging Face (or from local files) and convert it into a Ray Dataset. Ray automatically initializes a cluster, either locally or on the Anyscale platform, the first time you call a Ray API. You can also call **ray.init()** to explicitly create a new Ray cluster or connect to an existing one.

See the [`ray.init()` API reference](https://docs.ray.io/en/latest/ray-core/api/doc/ray.init.html#ray.init).

```python
# Load a Hugging Face dataset (the "sentiment" subset of tweet_eval, training split)
hf_dataset = load_dataset("cardiffnlp/tweet_eval", "sentiment", split="train")
# Convert the Hugging Face dataset to a Ray Dataset.
# repartition(2) splits the data into 2 blocks so two workers can process it in parallel;
# this is unnecessary if Ray already creates enough blocks for a large dataset.
ds = ray.data.from_huggingface(hf_dataset).repartition(2)
```

2025-07-11 06:47:51,776	INFO worker.py:1908 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8265


```python
# Dataset metadata
print(ds)
```

Repartition
+- Dataset(num_rows=45615, schema={text: string, label: int64})
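The `Repartition` step in this plan divides the 45,615 rows into 2 blocks, and each block is a unit of parallel work. The plain-Python sketch below shows what an even split of that row count looks like; `even_block_sizes` is a hypothetical helper that assumes a perfectly even split, not a reimplementation of Ray's repartitioning, whose exact block boundaries may differ.

```python
from typing import List


def even_block_sizes(num_rows: int, num_blocks: int) -> List[int]:
    """Split num_rows as evenly as possible across num_blocks.

    The first (num_rows % num_blocks) blocks each get one extra row.
    """
    if num_blocks < 1:
        raise ValueError("num_blocks must be >= 1")
    base, extra = divmod(num_rows, num_blocks)
    return [base + 1 if i < extra else base for i in range(num_blocks)]


print(even_block_sizes(45615, 2))  # [22808, 22807]
print(even_block_sizes(45615, 4))  # [11404, 11404, 11404, 11403]
```

More blocks allow more parallel workers, at the cost of more per-block scheduling overhead.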


```python
# Show the first 10 rows.
# Each row has "text" and "label" columns.
ds.show(10)
```

2025-07-11 06:49:24,021	INFO dataset.py:3046 -- Tip: Use `take_batch()` instead of `take() / show()` to return records in pandas or numpy batch format.
2025-07-11 06:49:24,031	INFO logging.py:295 -- Registered dataset logger for dataset dataset_2_0
2025-07-11 06:49:24,056	INFO streaming_executor.py:117 -- Starting execution of Dataset dataset_2_0. Full logs are in /tmp/ray/session_2025-07-11_06-47-50_390429_98374/logs/ray-data
2025-07-11 06:49:24,056	INFO streaming_executor.py:118 -- Execution plan of Dataset dataset_2_0: InputDataBuffer[Input] -> AllToAllOperator[Repartition] -> LimitOperator[limit=10]



2025-07-11 06:49:24,378	INFO streaming_executor.py:227 -- ✔️  Dataset dataset_2_0 execution finished in 0.32 seconds


{'text': '"QT @user In the original draft of the 7th book, Remus Lupin survived the Battle of Hogwarts. #HappyBirthdayRemusLupin"', 'label': 2}
{'text': '"Ben Smith / Smith (concussion) remains out of the lineup Thursday, Curtis #NHL #SJ"', 'label': 1}
{'text': 'Sorry bout the stream last night I crashed out but will be on tonight for sure. Then back to Minecraft in pc tomorrow night.', 'label': 1}
{'text': "Chase Headley's RBI double in the 8th inning off David Price snapped a Yankees streak of 33 consecutive scoreless innings against Blue Jays", 'label': 1}
{'text': '@user Alciato: Bee will invest 150 million in January, another 200 in the Summer and plans to bring Messi by 2017"', 'label': 2}
{'text': "@user LIT MY MUM 'Kerry the louboutins I wonder how many Willam owns!!! Look Kerry Warner Wednesday!'", 'label': 2}
{'text': '"\\"""" SOUL TRAIN\\"""" OCT 27 HALLOWEEN SPECIAL ft T.dot FINEST rocking the mic...CRAZY CACTUS NIGHT CLUB ..ADV ticket $10 wt out costume $15..."', 'label': 

## Batch Inference Class
Many machine learning models are optimized for processing a batch of inputs at once. When working with a large dataset, there could be many batches of data. Instead of loading machine learning models repeatedly to run each batch of data, you want to spin up a number of actor processes that are **initialized once** with your model **and reused** to process multiple batches. 

To implement this, pass a callable class to the `map_batches` API. The class implements:

- `__init__`: Initialize any expensive state.
- `__call__`: Perform the stateful transformation.
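Ray Data manages the actor pool for you. The stdlib-only simulation below (hypothetical `ToyEmbedder` and `run_with_pool`, with round-robin batch assignment as a simplification of Ray's actual scheduling) only illustrates why the pattern pays off: the expensive `__init__` runs once per pool member, while `__call__` runs once per batch.

```python
from itertools import cycle
from typing import Dict, List, Tuple


class ToyEmbedder:
    """Hypothetical callable class following the __init__/__call__ pattern."""

    init_count = 0  # class-level counter: how many times "the model" was loaded

    def __init__(self) -> None:
        ToyEmbedder.init_count += 1  # expensive setup (e.g. loading a model) goes here
        self.batches_seen = 0

    def __call__(self, batch: Dict[str, List[str]]) -> Dict[str, list]:
        self.batches_seen += 1
        # Cheap per-batch work: add a column computed from the "text" column
        return {**batch, "length": [len(t) for t in batch["text"]]}


def run_with_pool(
    batches: List[Dict[str, List[str]]], pool_size: int
) -> Tuple[List[Dict[str, list]], List[ToyEmbedder]]:
    """Create pool_size instances once, then hand batches out round-robin."""
    pool = [ToyEmbedder() for _ in range(pool_size)]
    outputs = [actor(batch) for actor, batch in zip(cycle(pool), batches)]
    return outputs, pool


batches = [{"text": ["a", "bb", "ccc"]} for _ in range(10)]
outputs, pool = run_with_pool(batches, pool_size=2)
print(ToyEmbedder.init_count, [a.batches_seen for a in pool])  # 2 [5, 5]
```

Ten batches cost only two "model loads"; with a stateless function, the model would have to be reloaded for every batch.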

In this example, a lightweight sentence-transformer model, **all-MiniLM-L6-v2**, generates embeddings for the text data.

```python
# A callable class that Ray Data runs inside actors to embed text with SentenceTransformer
class TextEmbedder:
    def __init__(self):
        # Load a pretrained sentence-transformer model once per actor
        model_name = "all-MiniLM-L6-v2"  # a popular, lightweight sentence-transformer model
        self.model = SentenceTransformer(model_name)  # automatically picks CUDA, MPS, or CPU

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        sentences = batch["text"]  # use the "text" column
        batch["embedding"] = self.model.encode(sentences)  # one embedding vector per sentence
        return batch
```
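Because `TextEmbedder` downloads a real model in `__init__`, it can be handy to check the batch contract (a dict in, the same dict plus an `embedding` column out) with a stand-in model. The sketch below is an assumption-laden illustration: `FakeSentenceModel` and `InjectableTextEmbedder` are hypothetical names, the fake "embeddings" are derived from a SHA-256 hash rather than a neural network, and it uses plain lists instead of the NumPy arrays Ray Data passes in.

```python
import hashlib
from typing import Dict, List, Sequence

EMBED_DIM = 384  # matches shape=(384,) for all-MiniLM-L6-v2 in the materialized schema


class FakeSentenceModel:
    """Hypothetical stand-in exposing an encode() method like SentenceTransformer."""

    def encode(self, sentences: Sequence[str]) -> List[List[float]]:
        vectors = []
        for s in sentences:
            digest = hashlib.sha256(s.encode("utf-8")).digest()  # 32 deterministic bytes
            # Repeat the digest bytes to fill EMBED_DIM floats in [0, 1]
            vectors.append([digest[i % len(digest)] / 255.0 for i in range(EMBED_DIM)])
        return vectors


class InjectableTextEmbedder:
    """Same __call__ logic as TextEmbedder, but the model is passed in."""

    def __init__(self, model=None):
        self.model = model if model is not None else FakeSentenceModel()

    def __call__(self, batch: Dict[str, list]) -> Dict[str, list]:
        batch["embedding"] = self.model.encode(batch["text"])
        return batch


embedder = InjectableTextEmbedder()
out = embedder({"text": ["hello", "world", "hello"], "label": [1, 2, 1]})
print(len(out["embedding"]), len(out["embedding"][0]))  # 3 384
```

Identical sentences map to identical fake vectors, and the `label` column passes through untouched, which is the same behavior you would expect from the real class.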


## Create batches of data and call the model
Use Ray Data's `map_batches` to apply `TextEmbedder` to each batch of the dataset. For every batch, it encodes the sentences in the `text` column into embeddings and returns the batch with a new `embedding` column.

The cell below shows two options, depending on whether the Ray cluster has GPU nodes or only CPU nodes. The CPU option also works on a local Ray cluster on an Apple Silicon Mac with MPS.

```python
# Set manually so the code works on Ray clusters with CPU or GPU workers, or on a local Mac with MPS
worker_device = "cpu"  # or "cuda" if your worker nodes have NVIDIA GPUs
if worker_device == "cuda":
    # Adjust batch_size to the VRAM available on each GPU
    ds = ds.map_batches(TextEmbedder, num_gpus=1, concurrency=2, batch_size=64)  # e.g. 2 nodes with 1 GPU each
else:
    ds = ds.map_batches(TextEmbedder, concurrency=2, batch_size=64)  # CPU, or MPS on a Mac
```
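If you switch between cluster types often, the two branches in the cell above can be folded into one hypothetical helper that builds the keyword arguments. This is a sketch under the assumption that each GPU actor reserves exactly one whole GPU; `map_batches_kwargs` is not part of Ray.

```python
from typing import Any, Dict


def map_batches_kwargs(worker_device: str, num_replicas: int, batch_size: int = 64) -> Dict[str, Any]:
    """Return map_batches keyword arguments mirroring the cell above.

    num_replicas is the number of model copies (actors). With num_gpus=1 per
    actor, it should not exceed the total number of GPUs in the cluster.
    """
    if num_replicas < 1 or batch_size < 1:
        raise ValueError("num_replicas and batch_size must be >= 1")
    kwargs: Dict[str, Any] = {"concurrency": num_replicas, "batch_size": batch_size}
    if worker_device == "cuda":
        kwargs["num_gpus"] = 1  # reserve one whole GPU per actor
    elif worker_device not in ("cpu", "mps"):
        raise ValueError(f"unsupported device: {worker_device!r}")
    return kwargs


print(map_batches_kwargs("cuda", num_replicas=2))
# {'concurrency': 2, 'batch_size': 64, 'num_gpus': 1}
print(map_batches_kwargs("cpu", num_replicas=2))
# {'concurrency': 2, 'batch_size': 64}
```

It would be used as `ds.map_batches(TextEmbedder, **map_batches_kwargs(worker_device, 2))`.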

### Deploying at scale
- The encoding `batch_size` can be tuned to the available memory and your throughput requirements.
- `worker_device` selects the code path: with `"cuda"`, each actor reserves one GPU (`num_gpus=1`); otherwise, actors run without a GPU reservation. Inside each actor, `SentenceTransformer` automatically runs on CUDA, MPS, or CPU, whichever is available.
- The `concurrency` parameter sets how many `TextEmbedder` actors (model replicas) run in parallel, each processing different batches. With 2 nodes with 1 GPU each, or 1 node with 2 GPUs, set `concurrency=2` and `num_gpus=1`.
- `map_batches()` is lazy: it isn't executed until you consume the dataset, for example with `take_batch()`, `show()`, or `materialize()`.
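The laziness described in the last point can be illustrated with a toy pipeline built on Python generators. `LazyPipeline` is a hypothetical class, not Ray's executor: Ray streams data block by block and may process more rows than a limit strictly needs, but the principle is the same, since recording a transformation costs nothing until a consumer asks for rows.

```python
from typing import Callable, Iterable, Iterator, List


class LazyPipeline:
    """Hypothetical toy: records transformations and runs them only when consumed."""

    def __init__(self, source: Iterable[int]) -> None:
        self._source = source
        self._steps: List[Callable[[int], int]] = []
        self.executed = 0  # how many times a transformation actually ran

    def map(self, fn: Callable[[int], int]) -> "LazyPipeline":
        self._steps.append(fn)  # nothing runs here; the step is only recorded
        return self

    def _iterate(self) -> Iterator[int]:
        for item in self._source:
            for fn in self._steps:
                self.executed += 1
                item = fn(item)
            yield item

    def take(self, n: int) -> List[int]:
        out: List[int] = []
        if n <= 0:
            return out
        for item in self._iterate():
            out.append(item)
            if len(out) == n:
                break  # stop pulling rows once the limit is reached
        return out


pipe = LazyPipeline(range(1_000_000)).map(lambda x: x * 2)
print(pipe.executed)  # 0: map() alone runs nothing
print(pipe.take(3))   # [0, 2, 4]
print(pipe.executed)  # 3: only the rows needed for take() were processed
```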

Run inference on 128 rows as a quick test. `take_batch(128)` triggers execution and returns those 128 rows, with embeddings added, to the caller as a single batch (a dict of NumPy arrays, as the output below shows).

```python
# Run inference on a batch of 128 rows for testing.
ds.take_batch(128)
```

2025-07-11 07:00:56,090	INFO logging.py:295 -- Registered dataset logger for dataset dataset_4_0
2025-07-11 07:00:56,094	INFO streaming_executor.py:117 -- Starting execution of Dataset dataset_4_0. Full logs are in /tmp/ray/session_2025-07-11_06-47-50_390429_98374/logs/ray-data
2025-07-11 07:00:56,095	INFO streaming_executor.py:118 -- Execution plan of Dataset dataset_4_0: InputDataBuffer[Input] -> AllToAllOperator[Repartition] -> ActorPoolMapOperator[MapBatches(TextEmbedder)] -> LimitOperator[limit=128]



2025-07-11 07:00:56,105	INFO actor_pool_map_operator.py:633 -- Scaling up actor pool by 2 (reason=scaling to min size, running=0, restarting=0, pending=0)



2025-07-11 07:01:12,263	INFO actor_pool_map_operator.py:661 -- Scaled down actor pool by 1 (reason=consumed all inputs; running=1, restarting=0, pending=0)
2025-07-11 07:01:12,377	INFO streaming_executor.py:227 -- ✔️  Dataset dataset_4_0 execution finished in 16.28 seconds


{'text': array(['"QT @user In the original draft of the 7th book, Remus Lupin survived the Battle of Hogwarts. #HappyBirthdayRemusLupin"',
        '"Ben Smith / Smith (concussion) remains out of the lineup Thursday, Curtis #NHL #SJ"',
        'Sorry bout the stream last night I crashed out but will be on tonight for sure. Then back to Minecraft in pc tomorrow night.',
        "Chase Headley's RBI double in the 8th inning off David Price snapped a Yankees streak of 33 consecutive scoreless innings against Blue Jays",
        '@user Alciato: Bee will invest 150 million in January, another 200 in the Summer and plans to bring Messi by 2017"',
        "@user LIT MY MUM 'Kerry the louboutins I wonder how many Willam owns!!! Look Kerry Warner Wednesday!'",
        '"\\"""" SOUL TRAIN\\"""" OCT 27 HALLOWEEN SPECIAL ft T.dot FINEST rocking the mic...CRAZY CACTUS NIGHT CLUB ..ADV ticket $10 wt out costume $15..."',
        'So disappointed in wwe summerslam! I want to see john cena wins his 16t
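`take_batch()` returns one column-oriented batch: a dict mapping each column name to an array. If you would rather inspect it row by row, a small helper like the hypothetical `batch_to_rows` below pivots it (alternatively, `ds.take(n)` returns rows as a list of dicts directly). The sketch uses plain lists; since it relies only on `len()` and integer indexing, it should also work on the NumPy arrays in a real batch.

```python
from typing import Any, Dict, List, Sequence


def batch_to_rows(batch: Dict[str, Sequence[Any]]) -> List[Dict[str, Any]]:
    """Turn a column-oriented batch {column: values} into a list of row dicts."""
    lengths = {len(values) for values in batch.values()}
    if len(lengths) > 1:
        raise ValueError(f"columns have different lengths: {sorted(lengths)}")
    num_rows = lengths.pop() if lengths else 0
    return [{col: values[i] for col, values in batch.items()} for i in range(num_rows)]


sample = {"text": ["first tweet", "second tweet"], "label": [2, 1]}
print(batch_to_rows(sample))
# [{'text': 'first tweet', 'label': 2}, {'text': 'second tweet', 'label': 1}]
```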

## Run inference on the entire dataset
Execute the pipeline and materialize the results into the Ray object store. `materialize()` triggers execution of the lazy transformations, so `TextEmbedder` runs on every row of the dataset.

```python
# Run inference on the entire dataset.
# This does not mutate the original Dataset; it returns a new, materialized one.
materialized_ds = ds.materialize()
```

2025-07-11 07:05:06,928	INFO logging.py:295 -- Registered dataset logger for dataset dataset_5_0
2025-07-11 07:05:06,931	INFO streaming_executor.py:117 -- Starting execution of Dataset dataset_5_0. Full logs are in /tmp/ray/session_2025-07-11_06-47-50_390429_98374/logs/ray-data
2025-07-11 07:05:06,932	INFO streaming_executor.py:118 -- Execution plan of Dataset dataset_5_0: InputDataBuffer[Input] -> AllToAllOperator[Repartition] -> ActorPoolMapOperator[MapBatches(TextEmbedder)]



2025-07-11 07:05:06,942	INFO actor_pool_map_operator.py:633 -- Scaling up actor pool by 2 (reason=scaling to min size, running=0, restarting=0, pending=0)



2025-07-11 07:05:12,188	INFO actor_pool_map_operator.py:661 -- Scaled down actor pool by 1 (reason=consumed all inputs; running=1, restarting=0, pending=0)
2025-07-11 07:05:38,090	INFO actor_pool_map_operator.py:661 -- Scaled down actor pool by 1 (reason=consumed all inputs; running=0, restarting=0, pending=0)
2025-07-11 07:05:38,097	INFO streaming_executor.py:227 -- ✔️  Dataset dataset_5_0 execution finished in 31.17 seconds
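The log gives enough to estimate end-to-end throughput for this run. This is a rough figure for this particular local execution: the 31.17 seconds include actor start-up and repartitioning, so steady-state throughput is likely somewhat higher.

```python
NUM_ROWS = 45_615   # num_rows in the dataset metadata printed earlier
ELAPSED_S = 31.17   # "execution finished in 31.17 seconds" in the log above

rows_per_second = NUM_ROWS / ELAPSED_S
print(f"~{rows_per_second:.0f} rows/s end to end")  # ~1463 rows/s end to end
```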


```python
# Metadata after inference
print('** Original dataset:', ds)
print('\n** Materialized dataset:', materialized_ds)
```

** Original dataset: MapBatches(TextEmbedder)
+- Repartition
   +- Dataset(num_rows=45615, schema={text: string, label: int64})

** Materialized dataset: MaterializedDataset(
   num_blocks=2,
   num_rows=45615,
   schema={
      text: string,
      label: int64,
      embedding: numpy.ndarray(shape=(384,), dtype=float)
   }
)
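The schema above also lets you estimate how much object-store memory the embeddings alone occupy once materialized. The sketch assumes float32 values (4 bytes each), which the printed `dtype=float` does not state explicitly; the `text` and `label` columns add to this total.

```python
NUM_ROWS = 45_615
EMBED_DIM = 384        # shape=(384,) in the schema above
BYTES_PER_FLOAT = 4    # assumption: float32 embeddings

embedding_bytes = NUM_ROWS * EMBED_DIM * BYTES_PER_FLOAT
print(f"{embedding_bytes:,} bytes ≈ {embedding_bytes / 2**20:.1f} MiB")
# 70,064,640 bytes ≈ 66.8 MiB
```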


```python
# Show a few rows of the materialized dataset with embeddings
materialized_ds.show(3)
```

2025-07-11 07:06:33,183	INFO logging.py:295 -- Registered dataset logger for dataset dataset_7_0
2025-07-11 07:06:33,184	INFO streaming_executor.py:117 -- Starting execution of Dataset dataset_7_0. Full logs are in /tmp/ray/session_2025-07-11_06-47-50_390429_98374/logs/ray-data
2025-07-11 07:06:33,185	INFO streaming_executor.py:118 -- Execution plan of Dataset dataset_7_0: InputDataBuffer[Input] -> LimitOperator[limit=3]



2025-07-11 07:06:33,232	INFO streaming_executor.py:227 -- ✔️  Dataset dataset_7_0 execution finished in 0.05 seconds


{'text': '"QT @user In the original draft of the 7th book, Remus Lupin survived the Battle of Hogwarts. #HappyBirthdayRemusLupin"', 'label': 2, 'embedding': array([-1.17218718e-01,  7.51002133e-02,  8.44237953e-03,  2.10537948e-02,
       -8.00926834e-02,  6.09376505e-02,  8.09841380e-02,  6.39745891e-02,
       -1.34842591e-02, -1.22891478e-02,  9.70130879e-03,  8.13238472e-02,
        1.59728657e-02, -6.84826868e-03, -7.90290013e-02, -2.23982316e-02,
       -5.93358018e-02,  4.27904241e-02, -5.25669474e-03, -4.10776436e-02,
       -3.37784477e-02, -2.12912727e-02,  1.09729558e-01,  2.08834168e-02,
        6.42482564e-02, -5.55025972e-02, -3.46165411e-02,  6.61124960e-02,
       -5.21334969e-02, -3.30999158e-02, -1.77784879e-02,  3.47602344e-03,
       -2.97606084e-02, -6.36240765e-02, -4.66033891e-02,  6.25401214e-02,
        2.89564747e-02, -5.32266051e-02,  5.21380231e-02, -1.88834351e-02,
       -3.40600796e-02,  1.02842702e-02, -2.32401919e-02,  6.98712543e-02,
        5.05110919

### Out of memory errors
During inference, device memory (GPU VRAM, MPS unified memory, or CPU RAM) must hold both the machine learning model and the current batch of data. If `batch_size` is too large, the device can run out of memory and raise an out-of-memory error. In that case, reduce `batch_size`.
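One simple way to find a workable value is to halve the batch size until it fits. The sketch below assumes a hypothetical probe function `fits(batch_size)`; in practice, such a probe might run one `encode()` call and catch the framework's out-of-memory error. The memory model used for the demo (fixed model cost plus a per-row cost) and its numbers are illustrative only.

```python
from typing import Callable


def find_max_batch_size(fits: Callable[[int], bool], start: int = 512, minimum: int = 1) -> int:
    """Halve the batch size until fits(batch_size) returns True."""
    batch_size = start
    while batch_size >= minimum:
        if fits(batch_size):
            return batch_size
        batch_size //= 2  # too large: try half as many rows per batch
    raise RuntimeError("even the minimum batch size does not fit in memory")


# Toy memory model (illustrative numbers only): fixed model cost plus per-row cost, in MB
MODEL_MB, PER_ROW_MB, DEVICE_MB = 100, 3, 400


def toy_fits(batch_size: int) -> bool:
    return MODEL_MB + PER_ROW_MB * batch_size <= DEVICE_MB


print(find_max_batch_size(toy_fits))  # 64
```

Halving converges in a handful of probes; the trade-off is that the result is a power-of-two fraction of `start` rather than the exact maximum that fits.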

### Shutdown Ray cluster

```python
# Shut down Ray to free resources and avoid collisions with other notebooks running Ray jobs on the same machine
ray.shutdown()
```

### Summary
This notebook demonstrates how to perform efficient batch inference on large datasets using Ray Data. It walks through loading a public dataset from Hugging Face, converting it into a Ray Dataset, and defining a callable class to load and apply a machine learning model (SentenceTransformer) for embedding text. The notebook shows how to use Ray Data’s `map_batches` API to process data in parallel batches, leveraging available CPUs or GPUs for high-throughput inference. It also covers best practices for scaling and for handling memory constraints, and summarizes how Ray Data enables scalable, distributed batch inference for modern ML workflows.