# Scalable Batch Inference with Ray

<img src="https://technical-training-assets.s3.us-west-2.amazonaws.com/Generic/ray_logo.png" width="20%" loading="lazy">

## About this notebook

### Is this module right for you?

This module presents several approaches for scaling batch inference on Ray. Through hands-on practice with inference on a computer vision task, you will implement and compare different inference architectures to better understand Ray AIR and Ray Core.

This notebook is most useful if one or more of the following scenarios apply to you:

* You observe performance bottlenecks when working on batch inference problems in computer vision projects.
* You want to scale or increase throughput of existing batch inference pipelines.
* You wish to explore different architectures for scaling batch inference with Ray AIR and Ray Core.

### Prerequisites

To complete this notebook, you should have:

* Practical Python and machine learning experience.
* Familiarity with batch inference in ML.
* Familiarity with Ray and Ray AIR equivalent to completing these training modules:
  * [Overview of Ray](https://github.com/ray-project/ray-educational-materials/blob/main/Introductory_modules/Overview_of_Ray.ipynb)
  * [Introduction to Ray AIR](https://github.com/ray-project/ray-educational-materials/blob/main/Introductory_modules/Introduction_to_Ray_AIR.ipynb)

### Learning objectives

* Understand common design patterns for distributed batch inference.
* Implement scalable batch inference with Ray.
* Extend each approach by tuning performance.
* Compare scalable batch inference architectures on Ray to evaluate which is most relevant to your work.

### What will you do?

* Learn about distributed batch inference design patterns with Ray.
* Get to know the inference task.
  * Semantic (image) segmentation using the SegFormer model.
* Implement sequential inference.
* Implement distributed inference patterns.
  * Inference with Ray AIR Datasets and **BatchPredictor** abstractions.
* Compare approaches to identify situations best fit for each.

## Part 1: Ray design patterns for scaling batch inference

The ultimate goal for machine learning models is often to generate predictions on a set of unseen data. In this notebook, you focus on the inference stage of the ML workflow and explore different approaches to scaling it.

Ray Core and Ray AIR provide APIs that allow you to perform batch inference at scale, processing millions of examples and offering various performance tuning options.

|<img src="https://technical-training-assets.s3.us-west-2.amazonaws.com/Scaling_inference/example_ml_workflow.png" width="70%" loading="lazy">|
|:--|
|An example of a machine learning workflow that starts with reading raw data and preprocessing it. These steps are followed by training and tuning that produce a trained model. This model is then used for inference, often on large datasets.|

### What is (batch) inference?

<div class="alert alert-info">
  <strong>Batch inference</strong> (also known as offline inference) is the process of generating predictions on a large set or "batch" of data.
</div>

Unlike *online inference*, where predictions are generated as each observation arrives, batch inference generates predictions for a large volume of input data at once, when an immediate response is not required or not feasible.

For example, batch inference is relevant when generating weekly product recommendations using historical customer data or sales forecasting using time-aggregated observations.

|<img src="https://technical-training-assets.s3.us-west-2.amazonaws.com/Scaling_inference/batch_inference.png" width="70%" loading="lazy">|
|:--|
|Batch inference is the process of applying a trained model to a batch of data to generate predictions.|

In a non-distributed setting, inference executes sequentially. The model processes incoming batches of data one at a time, limiting performance to a single machine or GPU. Below, you will learn how to distribute batch inference on Ray with Ray AIR.

### Batch inference using Ray AIR BatchPredictor

Ray AIR [BatchPredictor](https://docs.ray.io/en/latest/ray-air/predictors.html#batch-prediction) is a utility for large-scale, distributed batch inference. `BatchPredictor` offers several out-of-the-box features:

* Supports various predictors, such as [TorchPredictor](https://docs.ray.io/en/latest/ray-air/api/doc/ray.train.torch.TorchPredictor.html#ray.train.torch.TorchPredictor), [HuggingFacePredictor](https://docs.ray.io/en/latest/ray-air/api/doc/ray.train.huggingface.HuggingFacePredictor.html#ray.train.huggingface.HuggingFacePredictor), and [XGBoostPredictor](https://docs.ray.io/en/latest/ray-air/api/doc/ray.train.xgboost.XGBoostPredictor.html#ray.train.xgboost.XGBoostPredictor).
* Handles conversions to framework-native batch formats.
* Provides options to go directly from an AIR `Checkpoint` to predictions and to select feature columns or keep input columns in the output.

`BatchPredictor` takes in two components:

* **`Checkpoint`**. A trained model, for example from a training or tuning step.
* **`Predictor`**. A class that loads a model from a `Checkpoint` to perform inference; Ray AIR provides framework-specific predictors (e.g., `TorchPredictor` and `TensorflowPredictor`).

Once instantiated, BatchPredictor can call `predict()` on a Ray Dataset. [Ray Datasets](https://docs.ray.io/en/latest/data/dataset.html#datasets) are the standard way to load and exchange data in Ray AIR. Datasets load and preprocess data for parallel compute, internally handling operations like batching, pipelining, autoscaling the actor pool, and memory management.



|<img src="https://technical-training-assets.s3.us-west-2.amazonaws.com/Scaling_inference/air_batchpredictor.png" width="70%" loading="lazy">|
|:--|
|Ray Datasets parallelize data loading, preprocessing, and batching. Ray AIR `BatchPredictor` takes both `Checkpoint` and `Predictor` objects to call `predict()` on a Ray Dataset for distributed batch inference.|

These high-level abstractions automate the challenging aspects of scaling batch inference, in exchange for less direct control over how Ray distributes the work.

<img src="https://technical-training-assets.s3.us-west-2.amazonaws.com/Scaling_inference/code_batchpredictor.png" width="70%" loading="lazy">

## Part 2: Batch inference example using computer vision transformers

To demonstrate the approaches introduced in the previous section, you will apply them to a computer vision task: semantic segmentation.

Semantic segmentation is related to object detection, but instead of locating objects with bounding boxes, it assigns a semantic label to every pixel in an image. In this hands-on example, you will run batch inference on image data by using a pretrained model to generate predictions.

### Data

#### MIT ADE20K - scene parsing benchmark

The [MIT ADE20K Dataset](http://sceneparsing.csail.mit.edu/) (also known as "SceneParse150") is the largest open-source dataset for scene parsing. Because of its high-quality annotations, it is often used as a standard benchmark for assessing the performance of semantic segmentation models. For this example, you will use the unlabeled test data to implement different batch inference architectures.

|<img src="https://technical-training-assets.s3.us-west-2.amazonaws.com/Scaling_inference/scene.png" width="70%" loading="lazy">|
|:--|
|Unannotated scene image from MIT ADE20K on the left. Pixel-by-pixel predictions on the right. [*Date accessed: November 10, 2022*](https://github.com/CSAILVision/semantic-segmentation-pytorch)|

Dataset highlights

* 20k annotated, scene-centric training images
* 3.3k unlabeled test images
* 150 [semantic categories](https://docs.google.com/spreadsheets/d/1se8YEtb2detS7OuPE86fXGyD269pMycAWe2mtKUj2W8/edit?usp=sharing) (such as person, car, bed, sky, etc.)

### Model

#### SegFormer - transformer-based framework for semantic segmentation

[SegFormer](https://arxiv.org/pdf/2105.15203.pdf) is an effective semantic segmentation method based on a *transformer* architecture. [Transformers](https://en.wikipedia.org/wiki/Transformer_(machine_learning_model)) are a type of deep learning architecture that process sequential data with a stack of layers, each combining a self-attention mechanism with a feedforward neural network.
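
As a rough illustration of the self-attention step described above, here is a minimal NumPy sketch of single-head scaled dot-product self-attention over a toy sequence. The sizes and random weight matrices are arbitrary assumptions, and the function name `self_attention` is hypothetical. SegFormer itself uses a more efficient variant that shortens the key/value sequence, so treat this as the general mechanism rather than SegFormer's implementation.

```python
import numpy as np


def self_attention(x: np.ndarray, w_q: np.ndarray, w_k: np.ndarray, w_v: np.ndarray) -> np.ndarray:
    """Single-head scaled dot-product self-attention (illustrative sketch).

    x: (seq_len, d_model) input sequence, e.g., flattened image patches.
    w_q, w_k, w_v: (d_model, d_head) projection matrices.
    Returns: (seq_len, d_head) attended representations.
    """
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    # Similarity of every position with every other position, scaled by sqrt(d_head).
    scores = q @ k.T / np.sqrt(k.shape[-1])
    # Numerically stable softmax over the key dimension.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Each output position is a weighted average of the value vectors.
    return weights @ v


rng = np.random.default_rng(0)
seq_len, d_model, d_head = 4, 8, 8  # toy sizes, chosen arbitrarily
x = rng.normal(size=(seq_len, d_model))
w_q, w_k, w_v = (rng.normal(size=(d_model, d_head)) for _ in range(3))

out = self_attention(x, w_q, w_k, w_v)
print(out.shape)  # (4, 8)
```

In a full transformer layer, the attention output would then pass through a feedforward network; that step is omitted here.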

What sets SegFormer apart from previous transformer-based approaches are two key features:

1. A hierarchically structured transformer encoder that does not use positional encodings, which avoids having to interpolate them when training and testing resolutions differ.
2. A lightweight all-MLP decoder that avoids complex decoder designs.

You will use a pretrained SegFormer model finetuned on [MITADE20K](http://sceneparsing.csail.mit.edu/) to perform batch inference.

|<img src="https://technical-training-assets.s3.us-west-2.amazonaws.com/Scaling_inference/segformer_architecture.png" width="70%" loading="lazy">|
|:--|
|SegFormer architecture showcasing the hierarchical transformer encoder and all-MLP decoder. [*Date accessed: November 10, 2022*](https://arxiv.org/pdf/2105.15203.pdf).|


## Part 3: Sequential batch inference

To get familiar with this batch inference task, you will first implement a basic approach in which a single worker generates predictions on batches sequentially. To get set up, the semantic segmentation example requires the following steps:

1. Load the pretrained [SegFormer](https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512) model.
2. Load the [feature extractor](https://huggingface.co/docs/transformers/v4.16.2/en/model_doc/segformer#transformers.SegformerFeatureExtractor) (preprocessor for scene data).
3. Load [SceneParse150](https://huggingface.co/datasets/scene_parse_150) dataset.
4. Run batch inference on images from the test set.

|<img src="https://technical-training-assets.s3.us-west-2.amazonaws.com/Scaling_inference/single_sequential_timeline.png" width="90%" loading="lazy">|
|:--|
|Timeline of sequential batch inference using a single worker. Task runtimes can vary due to differences in complexity, data size, and more.|

### Install libraries

In [None]:
! pip install -U ray==2.3.0 transformers==4.26.1 torch==1.13.1 datasets==2.10.1

### Set up necessary imports and utilities

In [None]:
import json
import torch
import numpy as np
import pandas as pd

from typing import Union
from datasets.arrow_dataset import Dataset
from huggingface_hub import hf_hub_download
from matplotlib import pyplot as plt
from PIL import Image
from PIL.JpegImagePlugin import JpegImageFile

# Set the seeds to fixed values for reproducibility.
# torch.manual_seed only seeds PyTorch; the image sampling below uses NumPy.
torch.manual_seed(201)
np.random.seed(201)

### Load the model components from the HuggingFace Hub

From the [Hugging Face Hub](https://huggingface.co/docs/hub/index), retrieve the pretrained SegFormer model by specifying the model name and [label files](https://huggingface.co/datasets/huggingface/label-files/blob/main/ade20k-id2label.json) which map indices to semantic categories.

#### Load label mappings

In [None]:
# https://huggingface.co/datasets/huggingface/label-files/blob/main/ade20k-id2label.json
def get_labels():
    repo_id = "huggingface/label-files"
    filename = "ade20k-id2label.json"
    # Download the label file (cached locally by huggingface_hub) and read it.
    path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
    with open(path, "r") as f:
        id2label = json.load(f)

    # JSON keys are strings; convert them to integer class IDs.
    id2label = {int(k): v for k, v in id2label.items()}
    label2id = {v: k for k, v in id2label.items()}
    return id2label, label2id

In [None]:
id2label, label2id = get_labels()

print(f"Total number of labels: {len(id2label)}")
print(f"Example labels: {list(id2label.values())[:5]}")

The utility function `get_labels` downloads the ADE20K label file from [Hugging Face](https://huggingface.co/datasets/huggingface/label-files/blob/main/ade20k-id2label.json) and builds two dictionaries, `id2label` and `label2id`, which convert between numerical and string labels for the 150 [semantic categories](https://docs.google.com/spreadsheets/d/1se8YEtb2detS7OuPE86fXGyD269pMycAWe2mtKUj2W8/edit#gid=0) of objects.
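
For illustration, the sketch below applies the same two dictionary comprehensions to a small, hypothetical excerpt of the label file. JSON object keys are always strings, which is why `get_labels` casts them to `int`. The `demo_` names are placeholders chosen so that running the cell does not overwrite the real `id2label` and `label2id`.

```python
import json

# Hypothetical excerpt in the same shape as the downloaded file: JSON keys are strings.
demo_raw = json.loads('{"0": "wall", "1": "building", "2": "sky"}')

demo_id2label = {int(k): v for k, v in demo_raw.items()}  # int -> str
demo_label2id = {v: k for k, v in demo_id2label.items()}  # str -> int

print(demo_id2label[2])  # sky
print(demo_label2id["wall"])  # 0
```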

#### Load SegFormer

In [None]:
from transformers import SegformerForSemanticSegmentation

In [None]:
MODEL_NAME = "nvidia/segformer-b0-finetuned-ade-512-512"

segformer = SegformerForSemanticSegmentation.from_pretrained(
    MODEL_NAME, id2label=id2label, label2id=label2id
)

print(f"Number of model parameters: {segformer.num_parameters()/(10**6):.2f} M")

The [Hugging Face Hub](https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512) hosts many variants of SegFormer. Here, you specify a version fine-tuned on the MIT ADE20K (SceneParse150) dataset at a 512 x 512 image resolution.

Note: This "b0" model is the smallest, with [other options](https://huggingface.co/nvidia/segformer-b5-finetuned-ade-640-640) ranging up to and including "b5". Keep this in mind as something to experiment with when comparing different batch inference architectures later on.

#### Load the feature extractor

In [None]:
from transformers import SegformerFeatureExtractor

In [None]:
segformer_feature_extractor = SegformerFeatureExtractor.from_pretrained(
    MODEL_NAME, reduce_labels=True
)
segformer_feature_extractor

[Feature extractors](https://huggingface.co/docs/transformers/main_classes/feature_extractor) preprocess input features (e.g. image data) by normalizing, resizing, padding, and converting raw images into the shape expected by SegFormer.

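The [`reduce_labels`](https://huggingface.co/docs/transformers/model_doc/segformer#segformer) flag shifts every class label down by one and maps the background label (0) to 255, the index ignored by the loss, so the background of an image (anything that is not explicitly an object) is excluded when computing loss. Because it only transforms annotation maps, it does not affect inference on the unlabeled test images used here.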
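
To make the effect of `reduce_labels` concrete, here is a NumPy sketch of the transformation as we understand it from the Transformers implementation, applied to a toy annotation map. The function `reduce_labels_sketch` is hypothetical and not part of any library.

```python
import numpy as np


def reduce_labels_sketch(label_map: np.ndarray) -> np.ndarray:
    """Mimic the label reduction applied to annotation maps when reduce_labels=True."""
    reduced = label_map.astype(np.int64)  # astype returns a copy; the input is untouched
    reduced[reduced == 0] = 255  # background -> ignore index
    reduced = reduced - 1  # shift object classes 1..150 down to 0..149
    reduced[reduced == 254] = 255  # restore the ignore index after the shift
    return reduced


demo_annotation = np.array([[0, 1], [2, 150]])
print(reduce_labels_sketch(demo_annotation).tolist())  # [[255, 0], [1, 149]]
```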

### Load dataset

#### Set up necessary imports

In [None]:
from datasets import load_dataset

In [None]:
SMALL_DATA = True

<div class="alert alert-warning">
  <strong>SMALL_DATA</strong>: a flag to download a subset (160 images) of the available test data. Defaults to True. Set to False (recommended) to work with the full test data (3352 images).
</div>

If you set `SMALL_DATA` to `False`, expect the download to take a while (depending on your connection speed), because all test images are downloaded to your local machine or cluster.

#### Load SceneParse150

In [None]:
DATASET_NAME = "scene_parse_150"

# Load data from the Hugging Face datasets repository.
# A small training sample is used only for visualization.
train_dataset = load_dataset(DATASET_NAME, split="train[:10]")
if SMALL_DATA:
    test_dataset = load_dataset(DATASET_NAME, split="test[:160]")
else:
    test_dataset = load_dataset(DATASET_NAME, split="test")

The two datasets serve different purposes:

* **`train_dataset`**  
    * A small sample of 10 images, used for visualization only. Training samples include ground-truth annotated image regions. The full training dataset contains 20,210 images.
* **`test_dataset`**  
    * Used for batch inference. Test samples do not contain ground-truth labels. The full test dataset contains 3,352 images.

In [None]:
train_dataset

In [None]:
def convert_image_to_rgb(data_item):
    if data_item["image"].mode != "RGB":
        data_item["image"] = data_item["image"].convert(mode="RGB")

    return data_item

In [None]:
test_dataset = test_dataset.map(convert_image_to_rgb)
test_dataset

Each sample contains three components:
* **`image`** 
    * The PIL image.
* **`annotation`**  
    * Human annotations of image regions (annotation mask is `None` in testing set).
* **`category`**  
    * Category of the scene generally (e.g. driveway, voting booth, dairy_outdoor).

#### Display example images

In [None]:
# A colormap for visualizing segmentation results.
# https://github.com/tensorflow/models/blob/3f1ca33afe3c1631b733ea7e40c294273b9e406d/research/deeplab/utils/get_dataset_colormap.py#L51
# (date accessed: Nov 9th, 2022)
ade_palette = np.array(
    [
        [0, 0, 0],
        [120, 120, 120],
        [180, 120, 120],
        [6, 230, 230],
        [80, 50, 50],
        [4, 200, 3],
        [120, 120, 80],
        [140, 140, 140],
        [204, 5, 255],
        [230, 230, 230],
        [4, 250, 7],
        [224, 5, 255],
        [235, 255, 7],
        [150, 5, 61],
        [120, 120, 70],
        [8, 255, 51],
        [255, 6, 82],
        [143, 255, 140],
        [204, 255, 4],
        [255, 51, 7],
        [204, 70, 3],
        [0, 102, 200],
        [61, 230, 250],
        [255, 6, 51],
        [11, 102, 255],
        [255, 7, 71],
        [255, 9, 224],
        [9, 7, 230],
        [220, 220, 220],
        [255, 9, 92],
        [112, 9, 255],
        [8, 255, 214],
        [7, 255, 224],
        [255, 184, 6],
        [10, 255, 71],
        [255, 41, 10],
        [7, 255, 255],
        [224, 255, 8],
        [102, 8, 255],
        [255, 61, 6],
        [255, 194, 7],
        [255, 122, 8],
        [0, 255, 20],
        [255, 8, 41],
        [255, 5, 153],
        [6, 51, 255],
        [235, 12, 255],
        [160, 150, 20],
        [0, 163, 255],
        [140, 140, 140],
        [250, 10, 15],
        [20, 255, 0],
        [31, 255, 0],
        [255, 31, 0],
        [255, 224, 0],
        [153, 255, 0],
        [0, 0, 255],
        [255, 71, 0],
        [0, 235, 255],
        [0, 173, 255],
        [31, 0, 255],
        [11, 200, 200],
        [255, 82, 0],
        [0, 255, 245],
        [0, 61, 255],
        [0, 255, 112],
        [0, 255, 133],
        [255, 0, 0],
        [255, 163, 0],
        [255, 102, 0],
        [194, 255, 0],
        [0, 143, 255],
        [51, 255, 0],
        [0, 82, 255],
        [0, 255, 41],
        [0, 255, 173],
        [10, 0, 255],
        [173, 255, 0],
        [0, 255, 153],
        [255, 92, 0],
        [255, 0, 255],
        [255, 0, 245],
        [255, 0, 102],
        [255, 173, 0],
        [255, 0, 20],
        [255, 184, 184],
        [0, 31, 255],
        [0, 255, 61],
        [0, 71, 255],
        [255, 0, 204],
        [0, 255, 194],
        [0, 255, 82],
        [0, 10, 255],
        [0, 112, 255],
        [51, 0, 255],
        [0, 194, 255],
        [0, 122, 255],
        [0, 255, 163],
        [255, 153, 0],
        [0, 255, 10],
        [255, 112, 0],
        [143, 255, 0],
        [82, 0, 255],
        [163, 255, 0],
        [255, 235, 0],
        [8, 184, 170],
        [133, 0, 255],
        [0, 255, 92],
        [184, 0, 255],
        [255, 0, 31],
        [0, 184, 255],
        [0, 214, 255],
        [255, 0, 112],
        [92, 255, 0],
        [0, 224, 255],
        [112, 224, 255],
        [70, 184, 160],
        [163, 0, 255],
        [153, 0, 255],
        [71, 255, 0],
        [255, 0, 163],
        [255, 204, 0],
        [255, 0, 143],
        [0, 255, 235],
        [133, 255, 0],
        [255, 0, 235],
        [245, 0, 255],
        [255, 0, 122],
        [255, 245, 0],
        [10, 190, 212],
        [214, 255, 0],
        [0, 204, 255],
        [20, 0, 255],
        [255, 255, 0],
        [0, 153, 255],
        [0, 41, 255],
        [0, 255, 204],
        [41, 0, 255],
        [41, 255, 0],
        [173, 0, 255],
        [0, 245, 255],
        [71, 0, 255],
        [122, 0, 255],
        [0, 255, 184],
        [0, 92, 255],
        [184, 255, 0],
        [0, 133, 255],
        [255, 214, 0],
        [25, 194, 194],
        [102, 255, 0],
        [92, 0, 255],
    ]
)

In [None]:
def prepare_pixels_with_segmentation(
    image: JpegImageFile, segmentation_maps: Union[torch.Tensor, np.ndarray]
):
    segmentation_maps = np.array(segmentation_maps)
    color_segments = np.zeros(
        (segmentation_maps.shape[0], segmentation_maps.shape[1], 3), dtype=np.uint8
    )
    for label, color in enumerate(ade_palette):
        color_segments[segmentation_maps == label, :] = color
    color_segments = color_segments[..., ::-1]  # convert to BGR
    pixels_with_segmentation = np.array(image) * 0.5 + color_segments * 0.5
    return pixels_with_segmentation.astype(np.uint8)
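
The loop in `prepare_pixels_with_segmentation` colors pixels one label at a time. As an alternative sketch (not the notebook's method), NumPy fancy indexing maps every pixel's label to its palette row in a single step. The toy palette and label map below are hypothetical stand-ins for `ade_palette` and a real prediction.

```python
import numpy as np

# Hypothetical 3-color palette (RGB) and a tiny 2 x 3 label map with labels 0..2.
demo_palette = np.array([[0, 0, 0], [255, 0, 0], [0, 255, 0]], dtype=np.uint8)
demo_seg_map = np.array([[0, 1, 2], [2, 1, 0]])

# Loop version, as in prepare_pixels_with_segmentation.
loop_colors = np.zeros((*demo_seg_map.shape, 3), dtype=np.uint8)
for label, color in enumerate(demo_palette):
    loop_colors[demo_seg_map == label, :] = color

# Vectorized version: fancy indexing maps every pixel label to its palette row.
# Assumes every label is a valid row index into the palette.
vectorized_colors = demo_palette[demo_seg_map]

print(vectorized_colors.shape)  # (2, 3, 3)
print(np.array_equal(loop_colors, vectorized_colors))  # True
```

Because `ade_palette` is an integer array rather than `uint8`, the equivalent expression there would be `ade_palette[segmentation_maps].astype(np.uint8)`, applied before the BGR flip and assuming every predicted label is a valid palette index.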

In [None]:
def display_example_images(dataset: Dataset, n: int = 2):
    fig, axes = plt.subplots(nrows=n, ncols=n, figsize=(10, 10))
    fig.set_tight_layout(True)
    for i, j in enumerate(
        np.random.choice(dataset.num_rows, size=(n * n), replace=False)
    ):
        image_with_pixels = prepare_pixels_with_segmentation(
            image=dataset[int(j)]["image"],
            segmentation_maps=np.array(dataset[int(j)]["annotation"]),
        )
        axes[i // n, i % n].imshow(image_with_pixels)
        axes[i // n, i % n].axis("off")

In [None]:
# Try running this multiple times!
display_example_images(train_dataset)

### Run sequential inference on 1 batch

#### Define inference logic

This `predict` function forms the basis of the inference step; you will reuse variations of it in each batch inference approach.

In [None]:
from typing import List


def predict(
    model: SegformerForSemanticSegmentation,
    feature_extractor: SegformerFeatureExtractor,
    images: List[JpegImageFile],
) -> List[np.ndarray]:
    # Set the device on which PyTorch will run.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)  # Move the model to the specified device.
    model.eval()  # Set the model to evaluation mode for inference.

    # The feature extractor processes raw images.
    inputs = feature_extractor(images=images, return_tensors="pt")

    # The model is applied to input images in the inference step.
    with torch.no_grad():
        outputs = model(pixel_values=inputs.pixel_values.to(device))

    # Post-process the output to each image's original size.
    # PIL's `image.size` is (width, height), so reverse it to (height, width).
    image_sizes = [image.size[::-1] for image in images]
    segmentation_maps_postprocessed = (
        feature_extractor.post_process_semantic_segmentation(
            outputs=outputs, target_sizes=image_sizes
        )
    )

    # Return list of segmentation maps detached from the computation graph.
    return [j.detach().cpu().numpy() for j in segmentation_maps_postprocessed]

#### Prepare 1 batch of 16 images

In [None]:
def get_image_indices(dataset: Dataset, n: int):
    image_indices = np.random.choice(dataset.num_rows, size=n, replace=False)
    return [int(i) for i in image_indices]

In [None]:
BATCH_SIZE = 16

# Get BATCH_SIZE randomly shuffled image IDs from the test dataset.
image_indices = get_image_indices(dataset=test_dataset, n=BATCH_SIZE)
image_indices

In [None]:
# Create a list of images by extracting images from random indices sampled from the test data.
batch = [test_dataset[i]["image"] for i in image_indices]
batch

#### Run batch inference

In [None]:
%%time
segmentation_maps = predict(
    model=segformer,
    feature_extractor=segformer_feature_extractor,
    images=batch,
)

In [None]:
segmentation_maps[0]

Batch inference returns a list of segmentation maps, one 2D array per input image, with the same height and width as that image. Each element of a segmentation map is the [semantic category](https://docs.google.com/spreadsheets/d/1se8YEtb2detS7OuPE86fXGyD269pMycAWe2mtKUj2W8/edit#gid=0) predicted for the corresponding pixel of the input image.

You can visualize a predicted segmentation map by overlaying it onto the original image to see the regions assigned to each object category.
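
Conceptually, each pixel's value is the index of its highest-scoring class. The sketch below uses random toy scores rather than real SegFormer outputs (as we understand it, `post_process_semantic_segmentation` also upsamples the logits to each image's size before taking the argmax). It then counts how many pixels each category covers, using a hypothetical label excerpt in place of `id2label`.

```python
import numpy as np

rng = np.random.default_rng(0)
num_labels, height, width = 3, 4, 5  # toy sizes
# Toy per-class scores with shape (num_labels, height, width), like one image's logits.
demo_logits = rng.normal(size=(num_labels, height, width))

# A segmentation map holds, for every pixel, the index of its highest-scoring class.
demo_map = demo_logits.argmax(axis=0)  # shape (height, width)

# Count pixels per predicted category (hypothetical label names).
demo_id2label = {0: "wall", 1: "building", 2: "sky"}
labels, counts = np.unique(demo_map, return_counts=True)
pixel_counts = {demo_id2label[int(label_idx)]: int(count) for label_idx, count in zip(labels, counts)}

print(demo_map.shape)  # (4, 5)
print(pixel_counts)
```

With real outputs, the same `np.unique` counting could be applied to `segmentation_maps[0]`, looking up names in `id2label`.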

#### Visualize example predictions

In [None]:
def visualize_predictions(image: JpegImageFile, segmentation_maps: np.ndarray):
    pxs = prepare_pixels_with_segmentation(
        image=image, segmentation_maps=segmentation_maps
    )
    plt.imshow(pxs)
    plt.axis("off")

In [None]:
visualize_predictions(image=batch[0], segmentation_maps=segmentation_maps[0])

### Run sequential inference on 10 batches

Next, you will test the scalability of the sequential batch inference approach by increasing the number of batches from 1 to 10. This lets you observe how runtime grows with the number of batches when a single worker processes them one at a time.

#### Prepare batches

In [None]:
BATCH_SIZE = 16
N_BATCHES = 10

# Get BATCH_SIZE * N_BATCHES randomly shuffled image IDs from the test dataset.
image_indices = get_image_indices(dataset=test_dataset, n=BATCH_SIZE * N_BATCHES)

# Split indices into N_BATCHES
image_indices_grouped = np.split(np.asarray(image_indices), N_BATCHES)
image_indices_grouped
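
Here `np.split` works because the 160 sampled indices divide evenly into 10 batches. If you adapt this cell to split an arbitrary number of images (for example, the full 3,352-image test set) into a fixed number of batches, `np.split` raises a `ValueError` on an uneven division. A minimal sketch, with hypothetical toy indices, of `np.array_split` as an alternative:

```python
import numpy as np

demo_indices = np.arange(10)  # hypothetical: 10 image indices

# np.split requires an exact division and raises ValueError otherwise.
try:
    np.split(demo_indices, 3)
except ValueError as err:
    print(f"np.split failed: {err}")

# np.array_split allows uneven sizes: the first batches get one extra element.
demo_batches = np.array_split(demo_indices, 3)
print([b.tolist() for b in demo_batches])  # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
```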

In [None]:
batches = []

# Create a list of images for each batch of indices sampled from the test dataset.
for image_idx in image_indices_grouped:
    batch = [test_dataset[int(i)]["image"] for i in image_idx]
    batches.append(batch)

batches[0]

#### Run batch inference

In [None]:
predictions = []

In [None]:
%%time
for batch in batches:
    segmentation_maps = predict(
        model=segformer,
        feature_extractor=segformer_feature_extractor,
        images=batch,
    )
    predictions.append(segmentation_maps)

Notice that increasing the number of batches by a factor of 10 increases the runtime by roughly 10x. This is the expected result for a sequential approach, whose runtime scales linearly with the number of batches.
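
To see why a sequential loop scales linearly, the toy sketch below replaces the model with a hypothetical `fake_predict` that sleeps for a fixed 20 ms per batch (an assumption standing in for real inference cost). Because each batch waits for the previous one to finish, total time is roughly the per-batch time multiplied by the number of batches.

```python
import time


def fake_predict(batch):
    """Hypothetical stand-in for `predict`: sleeps a fixed time per batch."""
    time.sleep(0.02)
    return [0 for _ in batch]


def time_sequential(n_batches: int, batch_size: int = 16) -> float:
    """Run `fake_predict` on n_batches toy batches, one after another."""
    toy_batches = [list(range(batch_size)) for _ in range(n_batches)]
    start = time.perf_counter()
    for toy_batch in toy_batches:  # one batch at a time, as in the loop above
        fake_predict(toy_batch)
    return time.perf_counter() - start


t1 = time_sequential(1)
t10 = time_sequential(10)
print(f"1 batch: {t1:.3f}s, 10 batches: {t10:.3f}s, ratio: {t10 / t1:.1f}x")
```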

In [None]:
# Inspect the resulting segmentation maps array.
predictions[0][0]

### Summary: Sequential batch inference

|<img src="https://technical-training-assets.s3.us-west-2.amazonaws.com/Scaling_inference/single_sequential_timeline.png" width="90%" loading="lazy">|
|:--|
|Timeline of sequential batch inference using a single worker. Task runtimes can vary due to differences in complexity, data size, and more.|

#### Key concepts

<div class="alert alert-info">
  <strong>Batch inference</strong> (also known as offline inference) is the process of generating predictions on a large set or "batch" of data.
</div>

## Part 4: Distributed batch inference with Ray AIR

Ray AIR's high-level APIs automate the challenging aspects of parallelizing and distributing batch inference tasks, allowing you to focus on the inference logic.

There are four main abstractions that work together to optimize this process:

* [**`Datasets`**](https://docs.ray.io/en/latest/data/dataset.html)  
    * These are used to parallelize data loading, preprocessing, and exchanging data in Ray AIR.
* [**`Checkpoint`**](https://docs.ray.io/en/latest/tune/tutorials/tune-checkpoints.html)  
    * `Checkpoint` objects represent saved models created during training or tuning and provide a common interface for restoring the model's state for tasks such as inference.
* [**`Predictor`**](https://docs.ray.io/en/latest/ray-air/predictors.html)  
    * A Ray AIR `Predictor` is a class that loads a model from a `Checkpoint` to perform inference; `BatchPredictor` uses it to run large-scale inference.
* [**`BatchPredictor`**](https://docs.ray.io/en/latest/ray-air/predictors.html#batch-prediction)  
    * The Ray AIR `BatchPredictor` utility takes a `Checkpoint` and a `Predictor` class and runs large-scale distributed batch prediction on a Ray Dataset when you call `predict()`.

|<img src="https://technical-training-assets.s3.us-west-2.amazonaws.com/Scaling_inference/air_batchpredictor.png" width="70%" loading="lazy">|
|:--|
|Ray Datasets parallelize data loading, preprocessing, and batching. Ray AIR `BatchPredictor` takes both `Checkpoint` and `Predictor` objects to call `predict()` on a Ray Dataset for distributed batch inference.|

Ray handles operations such as batching, pipelining, actor pool autoscaling, and memory management internally, so you can benefit from the scalability and ease of use of Ray AIR without needing to worry about the details of task distribution. Using these abstractions does come with some trade-offs, as you have less control over how Ray distributes the workload.

### Initialize Ray runtime

In [None]:
import ray

In [None]:
ray.init()

### Create a Ray Dataset with 160 images

In [None]:
# Get BATCH_SIZE * N_BATCHES randomly shuffled image IDs from the test dataset.
image_indices = get_image_indices(dataset=test_dataset, n=BATCH_SIZE * N_BATCHES)

# Create a list of images for the indices sampled from the test dataset.
data = [test_dataset[i]["image"] for i in image_indices]

In [None]:
# Create a Ray Dataset from the list of images to use in Ray AIR.
dataset = ray.data.from_items(data)
dataset.show(limit=3)

### Define a custom Predictor for image data

`BatchPredictor` takes a `Checkpoint` (which you will construct from the SegFormer model and feature extractor) and a `Predictor`. Ray AIR supports multiple framework-specific [`Predictors`](https://docs.ray.io/en/latest/ray-air/package-ref.html#predictor), such as `TorchPredictor` and `TensorflowPredictor`, and also lets you implement a [custom](https://docs.ray.io/en/latest/ray-air/predictors.html#developer-guide-implementing-your-own-predictor) one.

Here, you will implement a custom `SemanticSegmentationPredictor` that uses the same core `predict()` logic as before, adapted to fit the `BatchPredictor` pattern.

In [None]:
from ray.air import Checkpoint
from ray.train.predictor import Predictor

In [None]:
class SemanticSegmentationPredictor(Predictor):
    # The constructor loads/caches the model and feature extractor once per predictor instance.
    def __init__(
        self,
        model: SegformerForSemanticSegmentation,
        feature_extractor: SegformerFeatureExtractor,
    ):
        super().__init__()
        # Set the device on which PyTorch will run and move the model there once.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.model.eval()  # Set the model to evaluation mode for inference.
        self.feature_extractor = feature_extractor

    # This is the same logic as the `predict()` function defined in Part 3,
    # only with pandas DataFrames as inputs and outputs.
    def _predict_pandas(self, batch: pd.DataFrame) -> pd.DataFrame:
        # The images are stored in the "value" column, one per row.
        # Process every image in the batch, not only the first one.
        images = list(batch["value"])

        # The feature extractor processes raw images.
        inputs = self.feature_extractor(images=images, return_tensors="pt")

        # The model is applied to input images in the inference step.
        with torch.no_grad():
            outputs = self.model(pixel_values=inputs.pixel_values.to(self.device))

        # Post-process the output to each image's original (height, width).
        image_sizes = [image.size[::-1] for image in images]
        segmentation_maps_postprocessed = (
            self.feature_extractor.post_process_semantic_segmentation(
                outputs=outputs, target_sizes=image_sizes
            )
        )

        # Return one row per input image, with each map as a NumPy array.
        return pd.DataFrame(
            {
                "segmentation_maps": [
                    m.detach().cpu().numpy() for m in segmentation_maps_postprocessed
                ]
            }
        )

    # Creates an instance of SemanticSegmentationPredictor using the model and
    # feature extractor contained in the Checkpoint.
    @classmethod
    def from_checkpoint(
        cls, checkpoint: Checkpoint, **kwargs
    ) -> "SemanticSegmentationPredictor":
        checkpoint_data = checkpoint.to_dict()
        return cls(
            model=checkpoint_data["model"],
            feature_extractor=checkpoint_data["feature_extractor"],
        )
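
The input/output contract of `_predict_pandas` can be checked without Ray or a model. In this sketch, the batch is a DataFrame with one image per row in the `value` column (following the predictor above), and the method returns one output row per input row in the same order. `fake_segment` and `predict_pandas_sketch` are hypothetical stand-ins, and NumPy arrays replace PIL images.

```python
import numpy as np
import pandas as pd


def fake_segment(image: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the model: an all-zeros label map of size (height, width)."""
    return np.zeros(image.shape[:2], dtype=np.int64)


def predict_pandas_sketch(batch: pd.DataFrame) -> pd.DataFrame:
    # Iterate over every row so that no image in the batch is dropped.
    images = list(batch["value"])
    maps = [fake_segment(img) for img in images]
    # One output row per input row, in the same order.
    return pd.DataFrame({"segmentation_maps": maps})


# Two toy "images" of different sizes, stored one per row.
demo_batch = pd.DataFrame(
    {"value": [np.zeros((4, 6, 3), dtype=np.uint8), np.zeros((2, 3, 3), dtype=np.uint8)]}
)
demo_out = predict_pandas_sketch(demo_batch)
print(len(demo_out), demo_out["segmentation_maps"][0].shape)  # 2 (4, 6)
```

Returning one row per input image keeps the output aligned with the input batch, so larger `batch_size` values do not silently drop images.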

### Create a BatchPredictor

In [None]:
from ray.train.batch_predictor import BatchPredictor

In [None]:
# Construct a BatchPredictor using the SegFormer model and feature extractor along with an instance
# of the custom SemanticSegmentationPredictor class.
batch_predictor = BatchPredictor(
    checkpoint=Checkpoint.from_dict(
        {"model": segformer, "feature_extractor": segformer_feature_extractor}
    ),
    predictor_cls=SemanticSegmentationPredictor,
)

### Run parallel batch inference on a Ray Dataset

In [None]:
predictions_dataset = batch_predictor.predict(data=dataset, batch_size=1)

In [None]:
# Inspect the first record of the resulting dataset of segmentation maps.
predictions_dataset.take(limit=1)

### Shut down Ray runtime

In [None]:
# Terminate processes started by ray.init().
ray.shutdown()

### Summary: Distributed batch inference with Ray AIR

#### Key API elements

* [**`Datasets`**](https://docs.ray.io/en/latest/data/dataset.html)  
    * These are used to parallelize data loading, preprocessing, and exchanging data in Ray AIR.
* [**`Checkpoint`**](https://docs.ray.io/en/latest/tune/tutorials/tune-checkpoints.html)  
    * `Checkpoint` objects represent saved models created during training or tuning and provide a common interface for restoring the model's state for tasks such as inference.
* [**`Predictor`**](https://docs.ray.io/en/latest/ray-air/predictors.html)  
    * A Ray AIR `Predictor` is a class that loads a model from a `Checkpoint` to perform inference; `BatchPredictor` uses it to run large-scale inference.
* [**`BatchPredictor`**](https://docs.ray.io/en/latest/ray-air/predictors.html#batch-prediction)  
    * The Ray AIR `BatchPredictor` utility takes a `Checkpoint` and a `Predictor` class and runs large-scale distributed batch prediction on a Ray Dataset when you call `predict()`.



# Connect with the Ray community

You can learn and get more involved with the Ray community of developers and researchers:

* [**Ray documentation**](https://docs.ray.io/en/latest)

* [**Official Ray Website**](https://www.ray.io/)  
Browse the ecosystem and use this site as a hub to get the information that you need to get going and building with Ray.

* [**Join the Community on Slack**](https://forms.gle/9TSdDYUgxYs8SA9e8)  
Find friends to discuss your new learnings in our Slack space.

* [**Use the Discussion Board**](https://discuss.ray.io/)  
Ask questions, follow topics, and view announcements on this community forum.

* [**Join a Meetup Group**](https://www.meetup.com/Bay-Area-Ray-Meetup/)  
Tune in on meet-ups to listen to compelling talks, get to know other users, and meet the team behind Ray.

* [**Open an Issue**](https://github.com/ray-project/ray/issues/new/choose)  
Ray is constantly evolving to improve developer experience. Submit feature requests and bug reports, and get help via GitHub issues.

* [**Become a Ray contributor**](https://docs.ray.io/en/latest/ray-contribute/getting-involved.html)  
We welcome community contributions to improve our documentation and Ray framework.

<img src="https://technical-training-assets.s3.us-west-2.amazonaws.com/Generic/ray_logo.png" width="20%" loading="lazy">