# Performing GPU Batch Prediction on Images with a PyTorch Model

This example shows how to use the Ray AIR {class}`BatchPredictor <ray.train.batch_predictor.BatchPredictor>` for **large-scale batch inference with multiple GPU workers.**

In particular, we will:
- Load the Imagenette dataset from an S3 bucket and create a Ray Dataset.
- Define a preprocessor.
- Load a pretrained ResNet model and build a checkpoint.
- Construct a BatchPredictor from the checkpoint.
- Run batch prediction on multiple GPUs.
- Evaluate the predictions and save results to S3/local disk.

To run this example, you will need to install the following:

In [None]:
!pip install -q "ray[air]" boto3 torch torchvision

[Imagenette](https://github.com/fastai/imagenette) is a subset of ImageNet with 10 classes. First, we use {meth}`ray.data.read_images <ray.data.read_images>` to load the validation set from S3. Because each image is stored in a directory named after its class, we can use the {class}`Partitioning <ray.data.datasource.Partitioning>` API to extract the labels automatically.

In [None]:
import ray
from ray.data.datasource.partitioning import Partitioning

s3_uri = "s3://anonymous@air-example-data-2/imagenette2/val/"

# The S3 directory structure is {s3_uri}/{class_id}/{*.JPEG}
partitioning = Partitioning("dir", field_names=["class"], base_dir=s3_uri)

ds = ray.data.read_images(
    s3_uri, size=(256, 256), partitioning=partitioning, mode="RGB"
)
ds
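
To make the `"dir"` partitioning concrete, the sketch below re-implements the idea in plain Python: the first directory below `base_dir` becomes the `class` field. The helper `class_from_path` and the file name `example.JPEG` are hypothetical and only illustrate the mapping; Ray performs this step internally when you pass `partitioning` to `read_images`.

```python
from pathlib import PurePosixPath


def class_from_path(path: str, base_dir: str) -> str:
    """Illustrative only: return the first directory below base_dir, which is
    what "dir" partitioning with field_names=["class"] assigns to each file."""
    # PurePosixPath collapses repeated and trailing slashes in both strings,
    # so relative_to() can strip the common prefix.
    relative = PurePosixPath(path).relative_to(PurePosixPath(base_dir))
    # relative.parts looks like ("n01440764", "example.JPEG")
    return relative.parts[0]


base = "s3://anonymous@air-example-data-2/imagenette2/val/"
print(class_from_path(base + "n01440764/example.JPEG", base))  # n01440764
```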


In [None]:
ds.take(1)


As we can see, each example contains an image array of shape (256, 256, 3) and its class. The class of each image is the name of its directory (for example, `n01440764`). To find the indices of the model output that correspond to these names, we'll download a mapping file (`imagenet_class_index.json`) from the S3 bucket.

In [None]:
# If you want to run the full example, please set this to False
SMOKE_TEST = True


In [None]:
if SMOKE_TEST:
    ds = ds.limit(1000)


In [None]:
import json
import boto3
from botocore import UNSIGNED
from botocore.client import Config

# Download mapping file from S3
s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED))
s3.download_file(
    "air-example-data-2",
    "imagenette2/imagenet_class_index.json",
    "/tmp/imagenet_class_index.json",
)

# Build mappings between WordNet ids, model output indices, and readable names
with open("/tmp/imagenet_class_index.json", "r") as f:
    idx_to_class = json.load(f)
class_to_idx = {cls_name: int(index) for index, (cls_name, _) in idx_to_class.items()}
idx_to_class_name = {int(index): text for index, (_, text) in idx_to_class.items()}
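
The mapping file is assumed to follow the common `imagenet_class_index.json` layout, where each string index maps to a `[wordnet_id, human_readable_name]` pair. The sketch below builds the same two dictionaries from a two-entry excerpt written inline (a stand-in for the downloaded file; the `sample_` names are hypothetical so they don't overwrite the real mappings) and then converts a batch of directory names to integer ids with `np.vectorize`, the same lookup used to map labels later.

```python
import json

import numpy as np

# A two-entry excerpt in the assumed imagenet_class_index.json layout
sample_json = '{"0": ["n01440764", "tench"], "1": ["n01443537", "goldfish"]}'
sample_idx_to_class = json.loads(sample_json)

# WordNet id (directory name) -> integer index of the model output
sample_class_to_idx = {
    cls_name: int(index) for index, (cls_name, _) in sample_idx_to_class.items()
}
# Integer index -> human-readable class name
sample_idx_to_class_name = {
    int(index): text for index, (_, text) in sample_idx_to_class.items()
}

# Vectorized lookup over a batch of class names
sample_classes = np.array(["n01443537", "n01440764", "n01443537"])
sample_labels = np.vectorize(sample_class_to_idx.__getitem__)(sample_classes)
print(sample_labels)  # [1 0 1]
print(sample_idx_to_class_name[int(sample_labels[0])])  # goldfish
```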


Next, let's define a preprocessor to crop and normalize images, as well as convert the class names to indices with the map that we just constructed. 

We'll first use a {class}`TorchVisionPreprocessor <ray.data.preprocessors.TorchVisionPreprocessor>` to crop and normalize the images. Then, we'll implement a function that maps each class name to its respective index. This preprocessing logic is applied to the input dataset before the data is fed into the model.

In [None]:
import torch
import numpy as np
import pandas as pd
from torchvision import transforms
from ray.data.preprocessors import BatchMapper, TorchVisionPreprocessor, Chain
from typing import Dict


def to_tensor(batch: np.ndarray) -> torch.Tensor:
    tensor = torch.as_tensor(batch, dtype=torch.float)
    # (B, H, W, C) -> (B, C, H, W)
    tensor = tensor.permute(0, 3, 1, 2).contiguous()
    # [0., 255.] -> [0., 1.]
    tensor = tensor.div(255)
    return tensor


transform = transforms.Compose(
    [
        transforms.Lambda(to_tensor),
        transforms.CenterCrop(224),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Accelerate image processing with batched transformations
image_preprocessor = TorchVisionPreprocessor(
    columns=["image"], transform=transform, batched=True
)

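# Map the image labels from class names (strings) to integer ids
def map_labels(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    batch["label"] = np.vectorize(class_to_idx.__getitem__)(batch["class"])
    return batch


label_preprocessor = BatchMapper(map_labels, batch_format="numpy")

# Chain the image and label preprocessors and apply them to the dataset
preprocessor = Chain(image_preprocessor, label_preprocessor)
processed_ds = preprocessor.transform(ds)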
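
For intuition about what `transform` does to each batch, here is a NumPy-only re-implementation of the same three steps: scaling to [0, 1] with channels moved first, a 224x224 center crop, and per-channel normalization. It is an illustrative sketch (the name `preprocess_numpy` is hypothetical); the pipeline above uses the torchvision ops, whose center-crop offset `int(round((size - crop) / 2))` equals the integer division used here because 256 - 224 is even.

```python
import numpy as np

# ImageNet channel statistics, as passed to transforms.Normalize above
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess_numpy(batch: np.ndarray, crop: int = 224) -> np.ndarray:
    """Illustrative NumPy version of to_tensor + CenterCrop + Normalize.

    Input: uint8 array of shape (B, H, W, C). Output: float32 (B, C, crop, crop).
    """
    x = batch.astype(np.float32) / 255.0  # [0, 255] -> [0, 1]
    _, h, w, _ = x.shape
    top = (h - crop) // 2  # matches round((h - crop) / 2) when h - crop is even
    left = (w - crop) // 2
    x = x[:, top : top + crop, left : left + crop, :]  # center crop
    x = (x - MEAN) / STD  # per-channel normalization, broadcast over the last axis
    return x.transpose(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)


images = np.full((2, 256, 256, 3), 255, dtype=np.uint8)
out = preprocess_numpy(images)
print(out.shape)  # (2, 3, 224, 224)
```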


## Distributed Inference

As the last step, we will run inference with a {class}`~ray.train.batch_predictor.BatchPredictor` that wraps a {class}`~ray.train.torch.TorchPredictor`. Because `BatchPredictor` operates on Ray Datasets, it distributes the inference workload across multiple workers and runs prediction on multiple shards of data in parallel. You can find more details in [Using Predictors for Inference](air-predictors).

For the demo, we'll directly load a pretrained ResNet model from `torchvision.models` and wrap it in a {class}`~ray.train.torch.TorchCheckpoint`.

We specify several parameters in `batch_predictor.predict()`:
- `feature_columns` specifies which columns are passed to the model.
- The columns in `keep_columns` are returned together with the prediction results. Here, we keep the image labels for evaluation later.
- `BatchPredictor` uses CPUs for inference by default; set `num_gpus_per_worker` to use GPUs.
- `min_scoring_workers` and `max_scoring_workers` set how many workers to use for inference. Setting both to 4 gives a fixed pool of 4 GPU workers.

In [None]:
from torchvision import models
from ray.train.batch_predictor import BatchPredictor
from ray.train.torch import TorchCheckpoint, TorchPredictor

# Load the pretrained ResNet-152 weights (equivalent to the legacy `pretrained=True`)
model = models.resnet152(weights=models.ResNet152_Weights.IMAGENET1K_V1)

# Wrap the model in a checkpoint and build a BatchPredictor that creates
# one TorchPredictor per scoring worker
ckpt = TorchCheckpoint.from_model(model=model)
batch_predictor = BatchPredictor.from_checkpoint(ckpt, TorchPredictor)

predictions = batch_predictor.predict(
    processed_ds,
    feature_columns=["image"],
    keep_columns=["label"],
    batch_size=128,
    min_scoring_workers=4,
    max_scoring_workers=4,
    num_gpus_per_worker=1,
)

predictions.cache()
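
Conceptually, each inference worker takes a batch, feeds the `feature_columns` to the model, and returns the model output under `"predictions"` together with the `keep_columns`. The sketch below mimics that contract with a hypothetical `predict_batch` function and a random linear layer (`toy_model`) standing in for ResNet; it illustrates the data flow only and is not how Ray implements inference.

```python
from typing import Callable, Dict, List

import numpy as np


def predict_batch(
    batch: Dict[str, np.ndarray],
    model: Callable[[np.ndarray], np.ndarray],
    feature_columns: List[str],
    keep_columns: List[str],
) -> Dict[str, np.ndarray]:
    """Hypothetical stand-in for one scoring step of batch inference."""
    # This sketch assumes a single feature column, like ["image"] in this example
    (feature_column,) = feature_columns
    output = {"predictions": model(batch[feature_column])}
    # Kept columns pass through unchanged, row-aligned with the predictions;
    # every other column (here "class") is dropped
    for column in keep_columns:
        output[column] = batch[column]
    return output


rng = np.random.default_rng(0)
weights = rng.normal(size=(3 * 8 * 8, 1000)).astype(np.float32)


def toy_model(images: np.ndarray) -> np.ndarray:
    # Flatten each (C, H, W) image and apply a random linear layer -> 1000 logits
    return images.reshape(len(images), -1) @ weights


toy_batch = {
    "image": rng.random((4, 3, 8, 8), dtype=np.float32),
    "label": np.array([0, 1, 2, 3]),
    "class": np.array(["a", "b", "c", "d"]),
}
result = predict_batch(toy_batch, toy_model, ["image"], ["label"])
print(sorted(result))  # ['label', 'predictions']
print(result["predictions"].shape)  # (4, 1000)
```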


## Evaluating Prediction Accuracy

Our `predictions` dataset contains the model output in a column named `"predictions"`, plus all columns listed in `keep_columns`.

In this example, the ResNet model outputs a 1000-dimensional tensor containing the logits of each ImageNet class. We'll measure Top-1 and Top-5 accuracy, where Top-N accuracy is the percentage of predictions for which the true label is among the N highest-scoring classes.

In [None]:
def calculate_matches(batch: pd.DataFrame) -> pd.DataFrame:
    batch["top_5_pred"] = batch["predictions"].apply(lambda x: np.argsort(-x)[:5])
    batch["top_5_pred_name"] = batch["top_5_pred"].map(
        lambda x: [idx_to_class_name[idx] for idx in x]
    )
    batch["top_5_match"] = batch.apply(lambda x: x["label"] in x["top_5_pred"], axis=1)
    batch["top_1_match"] = batch.apply(
        lambda x: x["label"] == x["top_5_pred"][0], axis=1
    )
    return batch


predictions = predictions.map_batches(calculate_matches, batch_format="pandas")
print("Top-1 accuracy: ", predictions.mean(on="top_1_match"))
print("Top-5 accuracy: ", predictions.mean(on="top_5_match"))
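
The row-wise `apply` calls above are easy to read; for reference, the same Top-N definition can be computed on a whole array of logits at once with NumPy. The helper `top_k_accuracy` and the small `toy_logits` array are hypothetical and only illustrate the metric.

```python
import numpy as np


def top_k_accuracy(logits: np.ndarray, labels: np.ndarray, k: int) -> float:
    """Fraction of rows whose true label is among the k highest logits."""
    # Indices of the k largest logits per row (order within the top k is irrelevant)
    top_k = np.argsort(-logits, axis=1)[:, :k]
    hits = (top_k == labels[:, None]).any(axis=1)
    return float(hits.mean())


toy_logits = np.array(
    [
        [0.1, 0.9, 0.0, 0.3, 0.2, 0.5],  # ranking: 1, 5, 3, ...
        [0.8, 0.1, 0.7, 0.0, 0.3, 0.2],  # ranking: 0, 2, 4, ...
        [0.2, 0.3, 0.1, 0.9, 0.0, 0.4],  # ranking: 3, 5, 1, ...
    ]
)
toy_labels = np.array([1, 2, 0])
print(top_k_accuracy(toy_logits, toy_labels, k=1))  # 0.3333333333333333
print(top_k_accuracy(toy_logits, toy_labels, k=3))  # 0.6666666666666666
```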


In [None]:
from IPython.display import display
from PIL import Image

# Take an example from the model's prediction
sample_image = ds.take(1)[0]
sample_pred = predictions.take(1)[0]
display(Image.fromarray(sample_image["image"]))
print("Top-1 Matched: ", sample_pred["top_1_match"])
print("Top-5 Predictions: ", sample_pred["top_5_pred_name"])


## Save Prediction Results

There are a few options for saving your prediction results:
- Call `predictions.repartition(n)` to split your results into n blocks; `write_parquet()` then writes one file per block, n files in total.
- Pass a local path or an S3 URI to `write_parquet()` to store the files on your local disk or in an S3 bucket.
- Other output formats are described in [Ray Data Input/Output](https://docs.ray.io/en/latest/data/api/input_output.html).


In [None]:
predictions.repartition(1).write_parquet("local:///tmp/single_parquet")
# >>> /tmp/single_parquet/d757569dfb2845589b0ccbcb263e8cc3_000000.parquet

predictions.repartition(3).write_parquet("local:///tmp/multiple_parquet")
# >>> /tmp/multiple_parquet/2b529dc5d8eb45e5ad03e69fb7ad8bc0_000000.parquet
# >>> /tmp/multiple_parquet/2b529dc5d8eb45e5ad03e69fb7ad8bc0_000001.parquet
# >>> /tmp/multiple_parquet/2b529dc5d8eb45e5ad03e69fb7ad8bc0_000002.parquet

# You can also save results to S3 by replacing the local path with an S3 URI
# predictions.write_parquet(YOUR_S3_BUCKET_URI)
