## Create batches of data and call the model
Define a Ray Data `map_batches` function that embeds text using the SentenceTransformer model. Ray Data applies this function to each batch in the dataset: it takes a batch of sentences, encodes them into embeddings, and returns the batch with the embeddings added.
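
As a rough sketch of what such a callable can look like: the class name `TextEmbedderSketch`, the model name `all-MiniLM-L6-v2`, the `embedding` output column, and the optional `model` argument are all assumptions for illustration, and not necessarily what this notebook's `TextEmbedder` uses.

```python
class TextEmbedderSketch:
    """Callable class in the shape map_batches accepts for stateful transforms."""

    def __init__(self, model_name="all-MiniLM-L6-v2", device="cpu", model=None):
        if model is None:
            # Imported lazily so the class can be defined without the library installed.
            from sentence_transformers import SentenceTransformer

            model = SentenceTransformer(model_name, device=device)
        self.model = model

    def __call__(self, batch):
        # batch maps column names to arrays; "text" holds the sentences.
        batch["embedding"] = self.model.encode(list(batch["text"]))
        return batch
```

Ray Data constructs one instance per worker, so the model is loaded once per worker rather than once per batch. The `model` argument exists only so the sketch can be exercised with a stand-in encoder.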

This section shows two ways to run batch inference, depending on whether the Ray cluster has GPU nodes or only CPU nodes. The second option also works on a local Ray cluster on an Apple Silicon Mac with MPS.

In [7]:
# Set manually so the code works on Ray clusters with CPU or GPU workers, or on a local Mac with MPS.
worker_device = "cpu"  # or "cuda" if the worker nodes have an NVIDIA GPU
if worker_device == "cuda":
    # Adjust batch_size based on the VRAM available on each GPU.
    ds = ds.map_batches(TextEmbedder, num_gpus=1, concurrency=2, batch_size=64)  # 2 nodes with 1 GPU each
else:
    ds = ds.map_batches(TextEmbedder, concurrency=2, batch_size=64)  # either CPU or MPS (on a Mac)

### Deploying at scale
- The batch size for encoding can be adjusted based on the available memory and performance requirements.
- The `device` parameter ensures that the model runs on the correct device (CPU, GPU, or MPS).
- The `concurrency` parameter controls how many copies of the model process batches in parallel, and `num_gpus` is the number of GPUs reserved for each copy. With 2 nodes with 1 GPU each, or 1 node with 2 GPUs, set `concurrency=2` and `num_gpus=1`.
- `map_batches()` is lazy: it does not execute until the results are needed (for example, when calling `take` or `show`).
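
The device-dependent settings above can be collected in one place. A minimal sketch, assuming a hypothetical helper `map_batches_kwargs` (not part of Ray) whose result is passed as keyword arguments to `ds.map_batches`:

```python
def map_batches_kwargs(worker_device, num_workers=2, batch_size=64):
    """Build keyword arguments for ds.map_batches from the worker device."""
    if worker_device not in ("cpu", "cuda", "mps"):
        raise ValueError(f"unsupported worker_device: {worker_device!r}")
    kwargs = {"concurrency": num_workers, "batch_size": batch_size}
    if worker_device == "cuda":
        # Reserve one GPU for each parallel copy of the model.
        kwargs["num_gpus"] = 1
    return kwargs


print(map_batches_kwargs("cuda"))  # {'concurrency': 2, 'batch_size': 64, 'num_gpus': 1}

# Usage in the notebook (requires Ray and the dataset defined earlier):
# ds = ds.map_batches(TextEmbedder, **map_batches_kwargs(worker_device))
```

Keeping the branch inside a helper means the call site stays the same whether the workers have GPUs or not.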

Run inference on a batch of 128 rows. This returns the 128 rows, with the embeddings added, to the caller's machine.
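
The returned batch is a dict mapping column names to arrays, so printing it shows every row. A small helper such as the hypothetical `summarize_batch` below (an illustration, not a Ray API) reports only the row count per column:

```python
def summarize_batch(batch):
    """Return the number of rows in each column of a batch dict."""
    return {column: len(values) for column, values in batch.items()}


# Stand-in data; in the notebook this would be summarize_batch(ds.take_batch(128)).
example = {"text": ["a", "b"], "embedding": [[0.1, 0.2], [0.3, 0.4]]}
print(summarize_batch(example))  # {'text': 2, 'embedding': 2}
```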

In [8]:
# Run inference on a batch of 128 rows for testing.
ds.take_batch(128)

2025-07-11 07:00:56,090	INFO logging.py:295 -- Registered dataset logger for dataset dataset_4_0
2025-07-11 07:00:56,094	INFO streaming_executor.py:117 -- Starting execution of Dataset dataset_4_0. Full logs are in /tmp/ray/session_2025-07-11_06-47-50_390429_98374/logs/ray-data
2025-07-11 07:00:56,095	INFO streaming_executor.py:118 -- Execution plan of Dataset dataset_4_0: InputDataBuffer[Input] -> AllToAllOperator[Repartition] -> ActorPoolMapOperator[MapBatches(TextEmbedder)] -> LimitOperator[limit=128]


2025-07-11 07:00:56,105	INFO actor_pool_map_operator.py:633 -- Scaling up actor pool by 2 (reason=scaling to min size, running=0, restarting=0, pending=0)


2025-07-11 07:01:12,263	INFO actor_pool_map_operator.py:661 -- Scaled down actor pool by 1 (reason=consumed all inputs; running=1, restarting=0, pending=0)
2025-07-11 07:01:12,377	INFO streaming_executor.py:227 -- ✔️  Dataset dataset_4_0 execution finished in 16.28 seconds


{'text': array(['"QT @user In the original draft of the 7th book, Remus Lupin survived the Battle of Hogwarts. #HappyBirthdayRemusLupin"',
        '"Ben Smith / Smith (concussion) remains out of the lineup Thursday, Curtis #NHL #SJ"',
        'Sorry bout the stream last night I crashed out but will be on tonight for sure. Then back to Minecraft in pc tomorrow night.',
        "Chase Headley's RBI double in the 8th inning off David Price snapped a Yankees streak of 33 consecutive scoreless innings against Blue Jays",
        '@user Alciato: Bee will invest 150 million in January, another 200 in the Summer and plans to bring Messi by 2017"',
        "@user LIT MY MUM 'Kerry the louboutins I wonder how many Willam owns!!! Look Kerry Warner Wednesday!'",
        '"\\"""" SOUL TRAIN\\"""" OCT 27 HALLOWEEN SPECIAL ft T.dot FINEST rocking the mic...CRAZY CACTUS NIGHT CLUB ..ADV ticket $10 wt out costume $15..."',
        'So disappointed in wwe summerslam! I want to see john cena wins his 16t