# Batch Prediction with a Pretrained PyTorch Model

In this example, we show how to use {class}`BatchPredictor <ray.train.BatchPredictor>` for efficient, large-scale batch inference.

We will use a pretrained ResNet-152 model for image classification. For the demo we use the [Imagenette](https://github.com/fastai/imagenette) dataset, a subset of ImageNet with 10 easily classified classes (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute).

First, we use {meth}`ray.data.read_images <ray.data.read_images>` to load the validation set from S3. Since the dataset folders are already organized by class, we can use the {class}`Partitioning <ray.data.datasource.Partitioning>` API to extract image labels automatically from the directory names.

In [1]:
import ray
from ray.data.datasource.partitioning import Partitioning

s3_uri = "s3://air-example-data-2/imagenette2/val/"
partitioning = Partitioning("dir", field_names=["class"], base_dir=s3_uri)
ds = ray.data.read_images(s3_uri, partitioning=partitioning, mode="RGB")
print(ds.schema())

2023-02-11 14:38:07,841	INFO worker.py:1352 -- Connecting to existing Ray cluster at address: 10.0.29.191:6379...
2023-02-11 14:38:08,111	INFO worker.py:1529 -- Connected to Ray cluster. View the dashboard at 127.0.0.1:8265
2023-02-11 14:38:08,117	INFO packaging.py:373 -- Pushing file package 'gcs://_ray_pkg_448ef99926e5e34674f2510a0848d7c6.zip' (0.71MiB) to Ray cluster...
2023-02-11 14:38:08,124	INFO packaging.py:386 -- Successfully pushed file package 'gcs://_ray_pkg_448ef99926e5e34674f2510a0848d7c6.zip'.


image: extension<arrow.py_extension_type<ArrowVariableShapedTensorType>>
class: string
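
To make the directory-based partitioning concrete, here is a minimal pure-Python sketch of how a file path under the base directory could be mapped to field values. It is an illustration only, not Ray's implementation; the helper `parse_dir_partition` and the file name `example.JPEG` are hypothetical.

```python
def parse_dir_partition(path, base_dir, field_names):
    """Map the directories between base_dir and the file name to field names.

    Illustrative sketch only; Ray's Partitioning class does this internally.
    """
    base = base_dir.rstrip("/") + "/"
    if not path.startswith(base):
        raise ValueError(f"{path!r} is not under {base_dir!r}")
    # Everything after the base directory, split into directories and file name.
    *dirs, _filename = path[len(base):].split("/")
    if len(dirs) != len(field_names):
        raise ValueError("directory depth does not match the number of fields")
    return dict(zip(field_names, dirs))


base_dir = "s3://air-example-data-2/imagenette2/val/"
# Hypothetical file name; only the parent directory matters here.
path = base_dir + "n01440764/example.JPEG"
fields = parse_dir_partition(path, base_dir, ["class"])
print(fields)
```

With one field name and one directory level, each image row gets a `class` value equal to its parent folder name, which matches the `class: string` column in the schema above.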


In [None]:
ds = ds.limit(1000)

Notice that the image labels are folder names (e.g. n01440764), not class indices. Let's download `imagenet_class_index.json` to construct a mapping from these class names to the integer indices that the model predicts:

In [None]:
import boto3

s3 = boto3.resource('s3')
s3.meta.client.download_file('air-example-data-2', 'imagenette2/imagenet_class_index.json', '/tmp/imagenet_class_index.json')

In [3]:
import json

# Each entry maps an index string to a pair [class_name, human_readable_label].
with open("/tmp/imagenet_class_index.json") as f:
    idx_to_class = json.load(f)
class_to_idx = {cls_name: int(index) for index, (cls_name, _) in idx_to_class.items()}
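
As a quick illustration of what the comprehension above produces, the sketch below runs it on a two-entry, hand-written excerpt in the same shape as `imagenet_class_index.json`. The excerpt and the `sample_` names are assumptions made for the example; the real file has 1000 entries.

```python
# Hand-written excerpt in the same shape as imagenet_class_index.json
# (assumed entries, for illustration only).
sample_idx_to_class = {
    "0": ["n01440764", "tench"],
    "217": ["n02102040", "English_springer"],
}

# Invert the mapping: folder-style class name -> integer class index.
sample_class_to_idx = {
    cls_name: int(index) for index, (cls_name, _) in sample_idx_to_class.items()
}
print(sample_class_to_idx)
# {'n01440764': 0, 'n02102040': 217}
```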

Now we define a preprocessor that resizes, crops, and normalizes the input images, and converts class names to integer labels. You can wrap any custom preprocessing function in a {class}`BatchMapper <ray.data.preprocessors.BatchMapper>` in the same way. It applies the transformation to the input dataset before the data is fed into the model.

In [4]:
import numpy as np
from torchvision import transforms
from ray.data.preprocessors import BatchMapper

def preprocess_image(batch):
    preprocess = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    batch["image"] = np.array([preprocess(img).numpy() for img in batch["image"]])
    batch["label"] = np.array([class_to_idx[cls_name] for cls_name in batch["class"]])
    return batch

preprocessor = BatchMapper(fn=preprocess_image, batch_format="numpy")
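
The arithmetic behind `ToTensor` followed by `Normalize` can be sketched for a single pixel in plain Python. This is a simplified illustration, not torchvision's implementation: it assumes 8-bit RGB input, and the helper `normalize_pixel` is hypothetical.

```python
# Per-channel statistics used in the preprocessing function above.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb_uint8):
    """Scale an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    scaled = [c / 255.0 for c in rgb_uint8]
    return [(v - m) / s for v, m, s in zip(scaled, MEAN, STD)]


white = normalize_pixel((255, 255, 255))
black = normalize_pixel((0, 0, 0))
print(white)  # red channel is (1 - 0.485) / 0.229, about 2.25
print(black)  # red channel is -0.485 / 0.229, about -2.12
```

After normalization the channel values are roughly centered on zero, which is the input range the pretrained ImageNet weights expect.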

## Build a BatchPredictor

A BatchPredictor takes a checkpoint and a predictor class (e.g. `TorchPredictor`, `TensorflowPredictor`) and provides an interface for batch scoring on Ray datasets. When you call `predict()`, it distributes the inference workload across the available workers. You can find more details in [Using Predictors for Inference](https://docs.ray.io/en/latest/ray-air/predictors.html).

Here we load a pretrained ResNet model directly from `torchvision.models` and construct a TorchCheckpoint that bundles it with the preprocessor. You can also reload a Ray AIR checkpoint from your own previous experiments. For more details about checkpoint loading, see the {class}`Checkpoint <ray.air.checkpoint.Checkpoint>` API.

In [None]:
from torchvision import models
from ray.train.batch_predictor import BatchPredictor
from ray.train.torch import TorchCheckpoint, TorchPredictor

# Load the pretrained resnet model and construct a checkpoint
model = models.resnet152(pretrained=True)
checkpoint = TorchCheckpoint.from_model(model=model, preprocessor=preprocessor)

# Build a BatchPredictor from checkpoint
batch_predictor = BatchPredictor(checkpoint, TorchPredictor)

Call {meth}`predict() <ray.train.batch_predictor.BatchPredictor.predict>` on target datasets:
- `feature_columns` specifies which columns are required for the model. The selected columns will be concatenated to build a batch tensor.
- The columns in `keep_columns` will be returned together with the prediction results. For example, you can keep image labels for evaluation later.
- The BatchPredictor uses CPUs for inference by default. Specify `num_gpus_per_worker` if you want to use GPUs.

In [6]:
predictions = batch_predictor.predict(
    ds,
    feature_columns=["image"],
    keep_columns=["label"],
    batch_size=128,
    max_scoring_workers=3,
    num_gpus_per_worker=1,
)

2023-02-11 14:38:27,442	INFO batch_predictor.py:184 -- `num_gpus_per_worker` is set for `BatchPreditor`.Automatically enabling GPU prediction for this predictor. To disable set `use_gpu` to `False` in `BatchPredictor.predict`.
Map_Batches: 100%|██████████| 50/50 [00:05<00:00,  9.27it/s]
Map Progress (2 actors 1 pending): 100%|██████████| 9/9 [00:16<00:00,  1.82s/it]


## Result Evaluation and Saving

`BatchPredictor.predict()` returns a Ray dataset that contains the model output in a column named "predictions", together with all columns specified in `keep_columns`.

In this example, the output of the ResNet model is a 1000-dimensional tensor containing the logit for each class. We select the five classes with the highest logits (equivalently, the highest softmax probabilities) and calculate the Top-1 and Top-5 errors:


In [13]:
def calculate_matches(batch):
    batch["top_5_pred"] = batch["predictions"].apply(lambda x: np.argsort(-x)[:5])
    batch["top_5_match"] = batch.apply(lambda x: x["label"] in x["top_5_pred"], axis=1)
    batch["top_1_match"] = batch.apply(lambda x: x["label"] == x["top_5_pred"][0], axis=1)
    return batch

predictions = predictions.map_batches(calculate_matches)
print("Top-1 error: ", 1 - predictions.mean(on="top_1_match"))
print("Top-5 error: ", 1 - predictions.mean(on="top_5_match"))

Map_Batches: 100%|██████████| 9/9 [00:00<00:00, 139.29it/s]
Shuffle Map: 100%|██████████| 9/9 [00:00<00:00, 613.36it/s]
Shuffle Reduce: 100%|██████████| 1/1 [00:00<00:00, 88.26it/s]


Top-1 error:  0.15100000000000002


Shuffle Map: 100%|██████████| 9/9 [00:00<00:00, 699.70it/s]
Shuffle Reduce: 100%|██████████| 1/1 [00:00<00:00, 130.66it/s]


Top-5 error:  0.010000000000000009
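
To make the Top-k error metric concrete, the following pure-Python sketch computes it on a tiny hand-made example with four classes. The data and the helpers `top_k_indices` and `top_k_error` are hypothetical and independent of Ray; they only mirror the logic of `calculate_matches` above.

```python
def top_k_indices(logits, k):
    """Return the indices of the k largest logits, highest first."""
    return sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]


def top_k_error(rows, k):
    """Fraction of rows whose true label is not among the top-k predictions."""
    misses = sum(1 for logits, label in rows if label not in top_k_indices(logits, k))
    return misses / len(rows)


# (logits, true_label) pairs; toy data for illustration only.
rows = [
    ([0.1, 2.0, 0.3, 0.5], 1),  # true label has the highest logit
    ([1.5, 0.2, 1.0, 0.1], 2),  # true label has the second-highest logit
    ([0.9, 0.8, 0.1, 0.7], 2),  # true label has the lowest logit
]

top1 = top_k_error(rows, 1)  # 2 of 3 rows miss
top2 = top_k_error(rows, 2)  # 1 of 3 rows misses
print(top1, top2)
```

The Top-k error can only decrease as k grows, which is why the Top-5 error reported above is much lower than the Top-1 error.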


Save your prediction results:

- Call `ds.repartition(n)` to split the prediction results into `n` partitions; `write_csv()` then creates `n` files, one per partition.
- To store the files on local disk or in an S3 bucket, pass a local path or an S3 URI to `write_csv()`.
- Other output file formats are described in [Ray Data Input/Output](https://docs.ray.io/en/latest/data/api/input_output.html).
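
The relationship between partitions and output files can be pictured with a small standard-library sketch that splits rows into `n` contiguous chunks and writes one CSV per chunk. This is only an analogy under simplifying assumptions: the helper `write_partitions`, the `part_*.csv` naming, and the toy rows are hypothetical, and Ray's actual block splitting and file naming differ.

```python
import csv
import os
import tempfile


def write_partitions(rows, n, out_dir):
    """Split rows into n near-equal contiguous chunks; write one CSV per chunk."""
    os.makedirs(out_dir, exist_ok=True)
    base, extra = divmod(len(rows), n)
    paths, start = [], 0
    for i in range(n):
        # The first `extra` chunks take one additional row each.
        size = base + (1 if i < extra else 0)
        path = os.path.join(out_dir, f"part_{i:06d}.csv")
        with open(path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["label", "top_1_match"])
            writer.writerows(rows[start:start + size])
        paths.append(path)
        start += size
    return paths


# Toy rows standing in for prediction results.
rows = [(i % 10, i % 2 == 0) for i in range(10)]
out_dir = tempfile.mkdtemp()
paths = write_partitions(rows, 3, out_dir)
print(sorted(os.listdir(out_dir)))
```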


In [None]:
predictions.repartition(1).write_csv("/tmp/single_csv")
# >>> /tmp/single_csv/d757569dfb2845589b0ccbcb263e8cc3_000000.csv

predictions.repartition(3).write_csv("/tmp/multiple_csv")
# >>> /tmp/multiple_csv/2b529dc5d8eb45e5ad03e69fb7ad8bc0_000000.csv
# >>> /tmp/multiple_csv/2b529dc5d8eb45e5ad03e69fb7ad8bc0_000001.csv
# >>> /tmp/multiple_csv/2b529dc5d8eb45e5ad03e69fb7ad8bc0_000002.csv

# You can also save results to S3 by replacing the local path with an S3 URI
# predictions.write_csv(YOUR_S3_BUCKET_URI)