# Object Detection Batch Inference with PyTorch & torchvision

This example demonstrates how to do object detection batch inference at scale with a pre-trained PyTorch model and Ray Data.

Here is what you'll do:
1. Perform object detection on a single image with a pre-trained PyTorch model.
1. Scale the PyTorch model with Ray Data, and perform object detection batch inference on a large set of images.
1. Evaluate the results and save them to S3 or local disk.
1. Learn how to use Ray Data with multiple GPU workers.


## Before You Begin

Install the following dependencies if you haven't already:

In [None]:
!pip install torchvision ipywidgets tabulate


## Object Detection on a Single Image with PyTorch

First, let's take a look at this [object detection example](https://pytorch.org/vision/stable/models.html#object-detection) from PyTorch's official documentation. 

This example includes the following steps:
1. Download an image file from the Internet.
2. Load and initialize a pre-trained PyTorch model.
3. Apply inference preprocessing transforms.
4. Use the model for inference.
5. Visualize the result.

In [None]:
import requests
from PIL import Image
from torchvision import transforms
from torchvision.io.image import read_image
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image

# 1: Download an image file from the Internet.
url = "https://s3-us-west-2.amazonaws.com/air-example-data/AnimalDetection/JPEGImages/2007_000063.jpg"
img = Image.open(requests.get(url, stream=True).raw)
display(img)

# 2: Load and initialize the model.
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
model.eval()

# 3: Apply inference preprocessing transforms.
img = transforms.Compose([transforms.PILToTensor()])(img)
preprocess = weights.transforms()
batch = [preprocess(img)]

# 4: Use the model for inference.
prediction = model(batch)[0]

# 5: Visualize the result.
labels = [weights.meta["categories"][i] for i in prediction["labels"]]
box = draw_bounding_boxes(img, boxes=prediction["boxes"],
                          labels=labels,
                          colors="red",
                          width=4, font_size=30)
im = to_pil_image(box.detach())
display(im)
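
Step 5 turns the integer class IDs in `prediction["labels"]` into readable names by indexing into `weights.meta["categories"]`. The sketch below mimics that lookup with plain Python lists. The category names and IDs are invented for illustration and are not the model's real label set.

```python
# Hypothetical category table; the real one comes from weights.meta["categories"].
categories = ["__background__", "cat", "dog", "sheep"]

# Hypothetical class IDs, standing in for prediction["labels"].
predicted_ids = [2, 2, 3]

# Same lookup pattern as in the visualization step above.
labels = [categories[i] for i in predicted_ids]

print(labels)  # ['dog', 'dog', 'sheep']
```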

## Parallelizing with Ray Data

Next, we scale the previous example to a large set of images, using Ray Data to run batch inference in a distributed fashion.

### Loading Image Dataset

First, we use the {meth}`ray.data.read_images <ray.data.read_images>` API to read image data from S3. 

In [None]:
import ray

ds = ray.data.read_images("s3://anonymous@air-example-data/AnimalDetection/JPEGImages")
display(ds.schema())

### Understanding Batching

To boost performance with hardware acceleration, we usually do inference in batches. In Ray Data, a batch is by default defined as a `Dict[str, np.ndarray]`. 

In the following code snippet, we use the {meth}`take_batch <ray.data.Dataset.take_batch>` API to fetch a single batch and inspect its internal data structure. As we can see, the batch is a dict that has one key named "image", and the value is an array of images represented in the `np.ndarray` format.

In [None]:
single_batch = ds.take_batch(batch_size=3)
display(single_batch)
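
To make the batch layout concrete, here is a minimal sketch that builds a dict of the same form by hand with NumPy instead of reading it from Ray Data. The image sizes and the `fake_batch` name are made up for illustration; real batches come from `take_batch` and hold the decoded image pixels.

```python
import numpy as np

# Hypothetical stand-in for a Ray Data batch: one "image" column holding
# three same-sized RGB images in height x width x channel (HWC) layout.
fake_batch = {"image": np.zeros((3, 32, 48, 3), dtype=np.uint8)}

# The first axis of each column indexes the rows in the batch.
num_rows = len(fake_batch["image"])

# Iterating over a column yields one row (here, one image) at a time.
row_shapes = [image.shape for image in fake_batch["image"]]

print(num_rows)    # 3
print(row_shapes)  # [(32, 48, 3), (32, 48, 3), (32, 48, 3)]
```

Real image datasets often mix image sizes, in which case the rows cannot be stacked into one dense array like this.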

### Batch Inference with Ray Data

Next, we demonstrate how to convert the PyTorch example to Ray Data and run batch inference on a distributed cluster.

The first step is to package the model in a Python class. The class consists of two main parts. The `__init__` constructor holds the code that loads and initializes the model, and the `__call__` method holds the code that applies the model to each batch.

In [None]:
import numpy as np
from torchvision import transforms
from torchvision.models.detection import (FasterRCNN_ResNet50_FPN_V2_Weights,
                                          fasterrcnn_resnet50_fpn_v2)
from typing import Dict
from ray.data.extensions.tensor_extension import create_ragged_ndarray

class ObjectDetectionModel:
    def __init__(self):
        # Define the model loading and initialization code in `__init__`
        self.weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
        self.model = fasterrcnn_resnet50_fpn_v2(
            weights=self.weights,
            box_score_thresh=0.9,
        )
        self.model.eval()
        # Note, since the data in the batch input is `np.ndarray`s, 
        # we need `transforms.ToTensor` to convert the data to torch tensors.
        self.preprossor = transforms.Compose(
            [transforms.ToTensor(), self.weights.transforms()]
        )

    def __call__(self, input_batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        # Define the per-batch inference code in `__call__`
        # Preprocess the images.
        batch = [self.preprossor(image) for image in input_batch["image"]]
        # Do inference on the images.
        predictions = self.model(batch)
        return {
            "image": input_batch["image"],
            "labels": create_ragged_ndarray([pred["labels"].detach().numpy() for pred in predictions]),
            "boxes": create_ragged_ndarray([pred["boxes"].detach().numpy() for pred in predictions]),
        }
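
The split between `__init__` and `__call__` matters because the setup code runs once when an instance is created, while the per-batch code runs many times on that same instance. The following is a minimal sketch of that pattern in plain Python, with no Ray or PyTorch involved. `ToyModel` and its attributes are hypothetical names used only for illustration.

```python
from typing import Dict, List


class ToyModel:
    """Hypothetical stand-in for a model wrapper; not part of Ray or PyTorch."""

    def __init__(self) -> None:
        # One-time setup (for a real model: loading weights).
        self.scale = 2
        self.batches_seen = 0

    def __call__(self, batch: Dict[str, List[int]]) -> Dict[str, List[int]]:
        # Per-batch work that reuses the state created in __init__.
        self.batches_seen += 1
        return {"scaled": [value * self.scale for value in batch["value"]]}


model = ToyModel()  # setup happens once
outputs = [model({"value": [1, 2]}), model({"value": [3]})]

print(outputs)             # [{'scaled': [2, 4]}, {'scaled': [6]}]
print(model.batches_seen)  # 2
```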

Then, we use the `map_batches` API to apply the model to the whole image dataset. Here, the `compute` argument specifies an actor pool whose `size` sets the number of concurrent Ray actors that run the model, and the `batch_size` argument sets the number of images in each batch.

In [None]:
ds = ds.map_batches(
    ObjectDetectionModel,
    compute=ray.data.ActorPoolStrategy(size=4),
    batch_size=4,
)
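
As a rough illustration of how `batch_size` and the actor pool size interact, the sketch below splits a list of item indices into batches and hands them out round-robin to a fixed number of workers. This is only a mental model: the helper names are hypothetical, and Ray Data's real scheduler makes its own decisions about which actor receives which data.

```python
from typing import Dict, List


def make_batches(items: List[int], batch_size: int) -> List[List[int]]:
    # Chunk the items into consecutive batches; the last one may be smaller.
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


def assign_round_robin(
    batches: List[List[int]], num_workers: int
) -> Dict[int, List[List[int]]]:
    # Simplified assignment: batch k goes to worker k % num_workers.
    assignment: Dict[int, List[List[int]]] = {w: [] for w in range(num_workers)}
    for k, batch in enumerate(batches):
        assignment[k % num_workers].append(batch)
    return assignment


batches = make_batches(list(range(10)), batch_size=4)
assignment = assign_round_robin(batches, num_workers=4)

print(batches)     # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(assignment)  # {0: [[0, 1, 2, 3]], 1: [[4, 5, 6, 7]], 2: [[8, 9]], 3: []}
```

With 10 items and a batch size of 4, only three batches exist, so the fourth worker in this toy model stays idle.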

Lastly, we can use `ds.show` to inspect samples in the result, or the `ds.write_xxx` APIs to write results to external storage. Here, we fetch a small batch with `take_batch` and draw the predicted boxes on each image.

In [None]:
batch = ds.take_batch(batch_size=3)


In [None]:
import torch
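
# `weights`, `draw_bounding_boxes`, and `to_pil_image` come from the single-image example above.
for image, labels, boxes in zip(batch["image"], batch["labels"], batch["boxes"]):
    # Convert the HWC image array to a CHW tensor, as `draw_bounding_boxes` expects.
    image = torch.from_numpy(image.transpose(2, 0, 1))
    labels = [weights.meta["categories"][i] for i in labels]
    boxes = torch.from_numpy(boxes)
    box = draw_bounding_boxes(
        image,
        boxes,
        labels=labels,
        colors="red",
        width=4,
        font_size=30,
    )
    im = to_pil_image(box)
    display(im)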
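
The `transpose(2, 0, 1)` call in the loop is needed because the images in the batch are stored height-first (H, W, C), while `draw_bounding_boxes` expects channel-first (C, H, W) tensors. A minimal NumPy sketch of that reordering, using a made-up 4x6 image:

```python
import numpy as np

# Hypothetical 4x6 RGB image in height x width x channel (HWC) layout.
hwc = np.arange(4 * 6 * 3, dtype=np.uint8).reshape(4, 6, 3)

# Move the channel axis to the front: (H, W, C) -> (C, H, W).
chw = hwc.transpose(2, 0, 1)

print(hwc.shape)  # (4, 6, 3)
print(chw.shape)  # (3, 4, 6)

# The pixel values are unchanged; only the axis order differs.
assert chw[1, 2, 3] == hwc[2, 3, 1]
```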

### Using GPUs

To use GPUs in the previous example, we first need to change the model to run on the GPU.

```diff
         self.model = fasterrcnn_resnet50_fpn_v2(
             weights=self.weights,
             box_score_thresh=0.9,
-        )
+        ).cuda()
         self.model.eval()
         # Note, since the data in the batch input is `np.ndarray`s,
         # we need `transforms.ToTensor` to convert the data to torch tensors.
```

```diff
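     def __call__(self, input_batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
         # Define the per-batch inference code in `__call__`
         # Preprocess the images.
-        batch = [self.preprossor(image) for image in input_batch["image"]]
+        batch = [self.preprossor(image).cuda() for image in input_batch["image"]]
         # Do inference on the images.
         predictions = self.model(batch)
         return {
             "image": input_batch["image"],
-            "labels": create_ragged_ndarray([pred["labels"].detach().numpy() for pred in predictions]),
-            "boxes": create_ragged_ndarray([pred["boxes"].detach().numpy() for pred in predictions]),
+            # Move the outputs back to the CPU before converting them to NumPy.
+            "labels": create_ragged_ndarray([pred["labels"].detach().cpu().numpy() for pred in predictions]),
+            "boxes": create_ragged_ndarray([pred["boxes"].detach().cpu().numpy() for pred in predictions]),
         }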
```

Then we add the `num_gpus` argument to the `map_batches` call. This tells Ray to schedule the actors on nodes with GPU resources and to assign one GPU to each actor.

```diff
 ds = ds.map_batches(
     ObjectDetectionModel,
     compute=ray.data.ActorPoolStrategy(size=4),
     batch_size=4,
+    num_gpus=1,
 )
```
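
With these settings, each of the 4 actors reserves 1 GPU, so the pool as a whole needs 4 GPUs before every actor can be scheduled. A small sketch of that arithmetic, where `gpus_required` is a hypothetical helper and not a Ray API:

```python
def gpus_required(num_actors: int, num_gpus_per_actor: int) -> int:
    # Total GPUs the actor pool reserves once every actor is running.
    return num_actors * num_gpus_per_actor


total = gpus_required(num_actors=4, num_gpus_per_actor=1)
print(total)  # 4
```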