# Batch Image Object Detection using Ray Data

In this example, we demonstrate how to use Ray Data to run batch object detection with a PyTorch model on a large set of images.

# Walkthrough

## Vanilla PyTorch example

Let's take a look at the [vanilla PyTorch example](https://pytorch.org/vision/stable/models.html#object-detection) first. 

In this example, we use a pre-trained PyTorch model to detect objects in a single image. For simplicity, we skip visualizing the results.

```python
from torchvision.io.image import read_image
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights

img = read_image("/path/to/your/image.jpg")

# Step 1: Initialize model with the best available weights
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = [preprocess(img)]

# Step 4: Run the model and print the predicted labels
prediction = model(batch)[0]
labels = [weights.meta["categories"][i] for i in prediction["labels"]]
print(labels)
```
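The `box_score_thresh=0.9` argument asks the model to drop low-confidence detections. As a rough illustration of that effect (not torchvision's actual implementation), the sketch below filters a mock prediction by score. The prediction values, the `filter_by_score` helper, and the short `categories` list are all made up for illustration; in the real example the categories come from `weights.meta["categories"]`.

```python
import numpy as np

# Mock prediction for one image (made-up values, not real model output).
prediction = {
    "boxes": np.array(
        [[10, 20, 200, 220], [15, 30, 90, 120], [300, 40, 380, 160]],
        dtype=np.float32,
    ),
    "labels": np.array([1, 3, 1]),
    "scores": np.array([0.98, 0.42, 0.93], dtype=np.float32),
}

# Hypothetical category table used only for this sketch.
categories = ["__background__", "person", "bicycle", "car"]


def filter_by_score(pred, threshold):
    """Keep only the detections whose score exceeds the threshold."""
    keep = pred["scores"] > threshold
    return {key: value[keep] for key, value in pred.items()}


confident = filter_by_score(prediction, threshold=0.9)
labels = [categories[i] for i in confident["labels"]]
print(labels)  # ['person', 'person']
```

Only the two detections scoring above 0.9 survive, so the low-confidence `car` detection is not reported.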


## Parallelizing with Ray Data

Next, we want to apply this object detection model to a large set of images. We can use Ray Data to scale the inference.

### Reading Image Data 

First, we use the `ray.data.read_images` API to read image data from S3. The directory structure of the dataset is `{s3_uri}/{class_id}/{*.JPEG}`, so we use the `Partitioning` utility to load all the images for all labels. We also use the `parallelism` argument to specify the number of distributed tasks used for reading the data.

```python
import ray
from ray.data.datasource.partitioning import Partitioning

s3_uri = "s3://anonymous@air-example-data-2/imagenette2/val/"

# The S3 directory structure is {s3_uri}/{class_id}/{*.JPEG}
partitioning = Partitioning("dir", field_names=["class"], base_dir=s3_uri)

ds = ray.data.read_images(s3_uri, parallelism=4, partitioning=partitioning, mode="RGB")
```
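To make the idea behind directory-style partitioning concrete, here is a minimal plain-Python sketch that maps the directories between a base directory and a file name onto field names. It is not Ray's implementation; `parse_dir_partition` and the `class_a/image_0001.JPEG` path are hypothetical names chosen for illustration.

```python
def parse_dir_partition(path, base_dir, field_names):
    """Map the directories between base_dir and the file name to field names."""
    base = base_dir.rstrip("/") + "/"
    if not path.startswith(base):
        raise ValueError(f"{path!r} is not under {base_dir!r}")
    # Everything after base_dir, split into directories and the file name.
    *dirs, _filename = path[len(base):].split("/")
    if len(dirs) != len(field_names):
        raise ValueError(
            f"expected {len(field_names)} directory level(s), got {len(dirs)}"
        )
    return dict(zip(field_names, dirs))


base_dir = "s3://anonymous@air-example-data-2/imagenette2/val/"
# Hypothetical object path, used only to exercise the parser.
path = base_dir + "class_a/image_0001.JPEG"
fields = parse_dir_partition(path, base_dir, ["class"])
print(fields)  # {'class': 'class_a'}
```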



### Understanding Batches

To boost performance with hardware vectorization, we usually do inference in batches. In Ray Data, a batch is by default defined as a `Dict[str, np.ndarray]`. 

In this case, the dict will have only one key named "image", and the value is an array of images represented in the `np.ndarray` format.

In the following code snippet, we use the `take_batch` API to get a single batch and inspect its internal data structure.

```python
single_batch = ds.take_batch(batch_size=4)
print(single_batch["image"].shape, single_batch["image"][0].shape)
```



```
(4,) (375, 500, 3)
```
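The outer shape is `(4,)` rather than `(4, H, W, 3)`, which suggests that each image is stored as its own array. That is what one would expect when the images differ in size. The sketch below builds a mock batch with the same layout using only NumPy. It is an assumption about why the shape looks this way, not a description of Ray Data internals; `make_batch` is a hypothetical helper, and every shape except the first is invented.

```python
from typing import Dict

import numpy as np


def make_batch(shapes) -> Dict[str, np.ndarray]:
    """Build a mock batch whose images may all have different shapes."""
    # An object array holds one independent ndarray per element,
    # so the images do not need to share a common height and width.
    images = np.empty(len(shapes), dtype=object)
    for i, shape in enumerate(shapes):
        images[i] = np.zeros(shape, dtype=np.uint8)
    return {"image": images}


mock_batch = make_batch([(375, 500, 3), (320, 480, 3), (500, 375, 3), (256, 256, 3)])
print(mock_batch["image"].shape, mock_batch["image"][0].shape)  # (4,) (375, 500, 3)
```

Because each element can have a different shape, per-image preprocessing has to loop over the batch instead of operating on one stacked tensor.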


### Running Batch Inference with Ray Data

Next, we demonstrate how to use Ray Data to run batch inference on a distributed cluster.

The first step is to package the model in a Python class. The class mainly consists of two parts: the `__init__` constructor holds the code that loads and initializes the model, and the `__call__` method holds the code that runs inference on each batch.
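The two-part pattern can be illustrated with a toy class that needs no model weights. In this sketch, `BrightnessModel` is a hypothetical stand-in: its `__init__` performs the one-time setup and its `__call__` maps a `Dict[str, np.ndarray]` batch to another `Dict[str, np.ndarray]`. The class-level counter exists only to show that setup runs once however many batches are processed.

```python
from typing import Dict

import numpy as np


class BrightnessModel:
    """Hypothetical stand-in for a real model: labels images as bright or dark."""

    init_calls = 0  # Counts how many times the setup code has run.

    def __init__(self):
        # Setup that should happen once per instance (e.g. loading weights).
        BrightnessModel.init_calls += 1
        self.threshold = 128.0

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        # Per-batch work: compute one label per image in the batch.
        means = np.array([image.mean() for image in batch["image"]])
        labels = np.where(means > self.threshold, "bright", "dark")
        return {"labels": labels}


model = BrightnessModel()
bright = np.full((4, 4, 3), 200, dtype=np.uint8)
dark = np.zeros((4, 4, 3), dtype=np.uint8)

first = model({"image": np.stack([bright, dark])})
second = model({"image": np.stack([dark, dark, bright])})
print(first["labels"].tolist(), second["labels"].tolist())
print(BrightnessModel.init_calls)  # 1
```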

```python
from typing import Dict

import numpy as np
from torchvision import transforms
from torchvision.models.detection import (FasterRCNN_ResNet50_FPN_V2_Weights,
                                          fasterrcnn_resnet50_fpn_v2)


class ObjectDetectionModel:
    def __init__(self):
        # Define the model loading and initialization code in `__init__`.
        self.weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
        self.model = fasterrcnn_resnet50_fpn_v2(
            weights=self.weights,
            box_score_thresh=0.9,
        )
        self.model.eval()
        # Note: since the images in the input batch are `np.ndarray`s,
        # we need `transforms.ToTensor` to convert them to torch tensors.
        self.preprocess = transforms.Compose(
            [transforms.ToTensor(), self.weights.transforms()]
        )

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        # Define the per-batch inference code in `__call__`.
        # Preprocess the images.
        images = [self.preprocess(image) for image in batch["image"]]
        # Run inference on the images.
        predictions = self.model(images)
        # Join the inferred labels of each image into one comma-separated
        # string and convert the result to a np.ndarray.
        labels = np.array(
            [
                ",".join(
                    [self.weights.meta["categories"][i] for i in prediction["labels"]]
                )
                for prediction in predictions
            ],
            dtype="S",
        )
        # `__call__` also returns a `Dict[str, np.ndarray]`.
        return {"labels": labels}
```

Then, we use the `map_batches` API to apply the model to the image dataset. Here, the `compute` argument sets the number of concurrent model replicas (an actor pool of size 4), and the `batch_size` argument sets the number of images in each batch.

```python
ds = ds.map_batches(
    ObjectDetectionModel,
    compute=ray.data.ActorPoolStrategy(size=4),
    batch_size=16,
)
```
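To build intuition for how the pool size and `batch_size` interact, the following plain-Python sketch simulates a pool of model replicas that take turns processing batches. It is a deliberately simplified model: it runs sequentially and assigns batches round-robin, whereas Ray runs the replicas in parallel and makes its own scheduling decisions. `CountingModel`, `iter_batches`, and `simulate_actor_pool` are hypothetical names, and the row count of 100 is arbitrary.

```python
from itertools import cycle


def iter_batches(rows, batch_size):
    """Yield consecutive slices of at most batch_size rows."""
    for start in range(0, len(rows), batch_size):
        yield rows[start:start + batch_size]


class CountingModel:
    """Hypothetical replica that records how many batches it has handled."""

    def __init__(self, replica_id):
        self.replica_id = replica_id
        self.batches_seen = 0

    def __call__(self, batch):
        self.batches_seen += 1
        return [f"replica-{self.replica_id}:{row}" for row in batch]


def simulate_actor_pool(rows, pool_size, batch_size):
    # Each replica is constructed once and then reused for many batches.
    replicas = [CountingModel(i) for i in range(pool_size)]
    results = []
    for replica, batch in zip(cycle(replicas), iter_batches(rows, batch_size)):
        results.extend(replica(batch))
    return replicas, results


replicas, results = simulate_actor_pool(list(range(100)), pool_size=4, batch_size=16)
print([r.batches_seen for r in replicas])  # [2, 2, 2, 1]
print(len(results))  # 100
```

With 100 rows and a batch size of 16 there are seven batches (six full ones and one with four rows), so each replica is built once but called once or twice.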

Lastly, we can use `ds.show(5)` to inspect a few samples of the result.

```python
ds.show(5)
```
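Because `__call__` joins the labels of each image into one comma-separated byte string (`dtype="S"`), downstream code may want to turn each value back into a list of label names. A minimal sketch of that step follows. `decode_labels` is a hypothetical helper, and the sample values are made up rather than taken from a real run.

```python
import numpy as np


def decode_labels(value):
    """Turn a comma-separated bytes value into a list of label strings."""
    text = value.decode("utf-8") if isinstance(value, bytes) else str(value)
    # An image with no detections produces an empty string.
    return text.split(",") if text else []


# Made-up values in the same format that `__call__` produces.
mock_labels = np.array(["dog,person", "", "car"], dtype="S")
decoded = [decode_labels(value) for value in mock_labels]
print(decoded)  # [['dog', 'person'], [], ['car']]
```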
