# Scaling Batch Inference with Ray Data

This template is a quickstart to using [Ray Data](https://docs.ray.io/en/latest/data/data.html) for batch inference. Ray Data is one of many libraries under the [Ray AI Runtime](https://docs.ray.io/en/latest/ray-air/getting-started.html). See [this blog post](https://www.anyscale.com/blog/model-batch-inference-in-ray-actors-actorpool-and-datasets) for more information on why and how you should perform batch inference with Ray!

This template walks through GPU batch prediction on an image dataset using a PyTorch model, but the framework and data format are only examples: swap in your own to build your application!

At a high level, this template will:
1. [Load your dataset using Ray Data.](https://docs.ray.io/en/latest/data/creating-datastreams.html)
2. [Preprocess your dataset before feeding it to your model.](https://docs.ray.io/en/latest/data/transforming-datastreams.html)
3. [Initialize your model and perform inference on a shard of your dataset with a remote actor.](https://docs.ray.io/en/latest/data/transforming-datastreams.html#callable-class-udfs)
4. [Save your prediction results.](https://docs.ray.io/en/latest/data/api/input_output.html)

> Slot in your code below wherever you see the ✂️ icon to build a batch inference Ray application off of this template!

In [None]:
import torch
import numpy as np
import tempfile
from typing import Dict

import ray


>✂️ Play around with these values!
>
>For example, for a cluster with 4 GPU nodes, you may want 4 workers, each using 1 GPU.
>Be sure to stay within the resource constraints of your Ray Cluster if autoscaling is not enabled.
>You can check the available resources in your Ray Cluster with: `ray status`

In [None]:
!ray status

In [None]:
NUM_WORKERS: int = 4
NUM_GPUS_PER_WORKER: float = 1


```{tip}
Try setting `NUM_GPUS_PER_WORKER` to a fractional amount! This will leverage Ray's fractional resource allocation, which means you can schedule multiple batch inference workers to run on the same GPU.
```
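
As a rough illustration of the tip above, the following sketch estimates how many workers can share one GPU for a given fractional request. The helper `workers_per_gpu` is hypothetical and not part of Ray; it only mirrors the scheduling arithmetic. Ray does not limit how much GPU memory each worker uses, so it is up to you to make sure the co-located models fit in memory together.

```python
import math


def workers_per_gpu(num_gpus_per_worker: float) -> int:
    """Estimate how many workers can be packed onto a single GPU.

    Hypothetical helper for illustration; not part of the Ray API.
    """
    if num_gpus_per_worker <= 0:
        raise ValueError("num_gpus_per_worker must be positive")
    if num_gpus_per_worker >= 1:
        # A worker that requests one or more whole GPUs does not share.
        return 1
    # Small epsilon guards against floating-point error in the division.
    return math.floor(1 / num_gpus_per_worker + 1e-9)


for fraction in (1, 0.5, 0.25):
    print(f"NUM_GPUS_PER_WORKER={fraction}: {workers_per_gpu(fraction)} worker(s) per GPU")
```

With 4 GPUs and `NUM_GPUS_PER_WORKER = 0.5`, for instance, this arithmetic suggests that up to 8 workers could be scheduled.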

> ✂️ Replace this function with logic to load your own data with Ray Data.

In [None]:
def load_ray_dataset() -> ray.data.Datastream:
    from ray.data.datasource.partitioning import Partitioning

    s3_uri = "s3://anonymous@air-example-data-2/imagenette2/val/"
    partitioning = Partitioning("dir", field_names=["class"], base_dir=s3_uri)
    ds = ray.data.read_images(
        s3_uri, size=(256, 256), partitioning=partitioning, mode="RGB"
    )
    return ds


In [None]:
ds = load_ray_dataset()
ds.schema()


> ✂️ Replace this function with your own data preprocessing logic.

In [None]:
def preprocess(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    from torchvision import transforms

    def to_tensor(batch: np.ndarray) -> torch.Tensor:
        tensor = torch.as_tensor(batch, dtype=torch.float)
        # (B, H, W, C) -> (B, C, H, W)
        tensor = tensor.permute(0, 3, 1, 2).contiguous()
        # [0., 255.] -> [0., 1.]
        tensor = tensor.div(255)
        return tensor

    transform = transforms.Compose(
        [
            transforms.Lambda(to_tensor),
            transforms.CenterCrop(224),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    return {"image": transform(batch["image"]).numpy()}


In [None]:
ds = ds.map_batches(preprocess, batch_format="numpy")
ds.schema()
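
To see what the preprocessing does to the shape and value range of a batch, here is a NumPy-only sketch of the same steps, assuming a batch of 256x256 RGB images as loaded above. The random pixel values are placeholders, and the crop offsets match `CenterCrop` for this input size only; this is an approximation for inspection, not a replacement for the torchvision transforms.

```python
import numpy as np

# Stand-in for one batch of the "image" column: 2 RGB images of 256x256.
# The random pixel values are placeholders, not real data.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(2, 256, 256, 3), dtype=np.uint8)

# (B, H, W, C) -> (B, C, H, W), then scale [0, 255] -> [0, 1].
x = images.astype(np.float32).transpose(0, 3, 1, 2) / 255.0

# Center crop to 224x224.
crop = 224
top = (x.shape[2] - crop) // 2
left = (x.shape[3] - crop) // 2
x = x[:, :, top : top + crop, left : left + crop]

# Per-channel normalization with the same mean/std as in `preprocess`.
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 3, 1, 1)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 3, 1, 1)
x = (x - mean) / std

print(x.shape, x.dtype)
```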


> ✂️ Replace parts of this Callable class with your own model initialization and inference logic.

In [None]:
class PredictCallable:
    def __init__(self):
        # <Replace this with your own model initialization>
        from torchvision import models

        self.model = models.resnet152(pretrained=True)
        self.model.eval()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        # <Replace this with your own model inference logic>
        input_data = torch.as_tensor(batch["image"], device=self.device)
        with torch.no_grad():
            result = self.model(input_data)
        return {"predictions": result.cpu().numpy()}


Now, perform batch prediction using Ray Data! Ray Data will perform model inference using `NUM_WORKERS` copies of the `PredictCallable` class you defined.

In [None]:
predictions = ds.map_batches(
    PredictCallable,
    batch_size=128,
    compute=ray.data.ActorPoolStrategy(
        # Fix the number of batch inference workers to a specified value.
        size=NUM_WORKERS,
    ),
    num_gpus=NUM_GPUS_PER_WORKER,
    batch_format="numpy",
)
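
To make the role of the callable class concrete, here is a plain-Python sketch, with no Ray involved, of the pattern that an actor pool relies on: each worker constructs the class once and then reuses it for many batches. `FakeModelCallable` and the round-robin loop are hypothetical stand-ins; how Ray Data actually assigns batches to actors is not modeled here.

```python
from typing import Dict, List


class FakeModelCallable:
    """Hypothetical stand-in for PredictCallable; no real model is loaded."""

    init_count = 0  # Tracks how many times the expensive setup ran.

    def __init__(self) -> None:
        # In the real class, this is where the model is loaded onto the GPU.
        FakeModelCallable.init_count += 1

    def __call__(self, batch: Dict[str, List[float]]) -> Dict[str, List[float]]:
        # Placeholder "inference": double every value.
        return {"predictions": [2 * v for v in batch["image"]]}


num_workers = 2
# Setup runs once per worker, not once per batch.
workers = [FakeModelCallable() for _ in range(num_workers)]

batches = [{"image": [float(i), float(i + 1)]} for i in range(0, 10, 2)]
# Simple round-robin assignment; an assumption made for this sketch only.
results = [workers[i % num_workers](batch) for i, batch in enumerate(batches)]

print(f"{len(batches)} batches processed with {FakeModelCallable.init_count} initializations")
```

This is why the model is loaded in `__init__` rather than in `__call__`: the loading cost is paid once per worker and amortized over every batch that worker handles.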


In [None]:
preds = predictions.materialize()
preds.schema()


In [None]:
preds.take(1)
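
Each output row holds the raw model outputs (one score per class) in the `predictions` column. If you need class indices instead, a post-processing step along the following lines could be applied, for example in another `map_batches` call. The function name `to_top1` and the `predicted_class` column are hypothetical, and the random scores below are placeholders.

```python
from typing import Dict

import numpy as np


def to_top1(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Convert a (B, num_classes) array of scores into top-1 class indices."""
    logits = batch["predictions"]
    return {"predicted_class": logits.argmax(axis=1)}


# Placeholder scores for 4 rows and 1000 classes.
fake_batch = {"predictions": np.random.default_rng(0).normal(size=(4, 1000))}
print(to_top1(fake_batch)["predicted_class"].shape)
```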


```{tip}
Replace `size` with the `min_size` and `max_size` parameters of `ActorPoolStrategy` to enable autoscaling!
For example, try setting only `min_size` and leaving out `max_size`: the pool can then scale up without an upper bound, as long as the cluster has free resources.
```

Shard the predictions into a few partitions, and save each partition to a file!

```{note}
This example saves to a temporary directory on the local filesystem, which is deleted when the `with` block exits. To keep the results, write to a persistent path or a cloud bucket (e.g., `s3://predictions-bucket`) instead.
```

In [None]:
num_shards = 3

with tempfile.TemporaryDirectory() as temp_dir:
    # Reuse the materialized `preds` so that inference is not recomputed.
    preds.repartition(num_shards).write_parquet(f"local://{temp_dir}")
    print(f"Predictions saved to `{temp_dir}`!")
