# Scaling Inference

<img src="../../_static/assets/Generic/ray_logo.png" width="20%" loading="lazy">

## About this notebook

### Is it right for you?

This module focuses on batch inference. It presents common design patterns for running batch inference and a few approaches to implementing it with Ray, depending on your needs. It is right for you if:

* you work with model (batch) inference problems and you observe performance bottlenecks
* you want to scale or increase throughput of your existing batch inference pipelines
* you wish to explore different architectures for batch inference with Ray Core and Ray AIR

### Prerequisites

For this notebook you should have:

* practical Python and machine learning experience
* familiarity with the batch inference problem in ML
* familiarity with Ray and Ray AIR. Equivalent to completing these training modules:
  * [Overview of Ray](https://github.com/ray-project/ray-educational-materials/blob/main/Introductory_modules/Overview_of_Ray.ipynb)
  * [Introduction to Ray AIR](https://github.com/ray-project/ray-educational-materials/blob/main/Introductory_modules/Introduction_to_Ray_AIR.ipynb)
  * [Ray Core](https://github.com/ray-project/ray-educational-materials/tree/main/Ray_Core)

### Learning objectives

Upon completion of this notebook, you will know about:

* batch inference patterns
* how to implement scalable batch inference with Ray

### What will you do?

* learn about scaling inference with common design patterns
* explore different architectures for predicting on semantic segmentation tasks
* implement parallelized inference through hands-on coding exercises

## Part 1: (Ray) architectures for scalable batch inference

The end goal for machine learning models is to generate performant predictions over a set of unseen data. In this module, you will parallelize batch inference using Ray Core's API as well as the high-level abstractions available in Ray AI Runtime (AIR).

|<img src="../../_static/assets/Scaling_inference/ml_workflow.png" width="70%" loading="lazy">|
|:--|
|Example of a machine learning workflow.|

### Stateless inference - Ray Tasks

Loading complex models into memory can be expensive and sequential processing of requests limits speed. *Stateless inference* allows an ML system to handle high volume requests by:

1. exporting the model's mathematical core into a language agnostic format
2. restoring the architecture and weights of a trained model in a stateless function (i.e. Ray tasks)

A Ray task is *stateless* because its output (e.g., predictions) is determined purely by its inputs (e.g., the trained model and the data); it keeps no state between calls. For online inference, this means loading the model for every request and serving results synchronously.

|<img src="../../_static/assets/Scaling_inference/task_inference.png" width="70%" loading="lazy">|
|:--|
|Stateless inference using Ray Tasks.|

In the figure above, you perform batch inference by preprocessing your big dataset into batches that are assigned to workers via Ray tasks. Each task loads the trained model and outputs predictions on batches as they are assigned.
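To make the cost of this pattern concrete, here is a minimal, Ray-free sketch. `load_toy_model` is a hypothetical stand-in for an expensive load such as `from_pretrained`, and the list comprehension plays the role of Ray's scheduler; in a real Ray program, `stateless_predict` would be decorated with `@ray.remote` and invoked with `.remote()`.

```python
import numpy as np

load_count = 0  # how many times the (hypothetical) model has been loaded


def load_toy_model():
    """Hypothetical stand-in for an expensive model load (e.g. from_pretrained)."""
    global load_count
    load_count += 1
    weights = np.arange(6, dtype=np.float32).reshape(3, 2)
    return lambda x: x @ weights


def stateless_predict(batch):
    # Stateless: the model is restored inside the function on every call,
    # so nothing is kept between calls.
    model = load_toy_model()
    return model(batch)


# Split 10 samples into 5 batches of 2 and predict on each batch.
batches = np.array_split(np.ones((10, 3), dtype=np.float32), 5)
predictions = [stateless_predict(b) for b in batches]
print(load_count)  # one load per batch
```

Every batch pays the loading cost, which is exactly the overhead that the stateful approach removes.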

### Stateful inference - Ray Actors

When your deployed model takes too long to generate immediate results, online prediction may not be the right approach. In addition, some use cases, such as curating personalized playlists, require predictions over large volumes of data. In these cases you can use *batch inference*: predictions are generated asynchronously and in advance over batches of observations, so that a high volume of samples can be processed efficiently.

Setting up distributed batch inference with Ray involves:

1. creating a number of replicas of your model; in Ray, these replicas are represented as Actors (i.e., stateful processes) that can be assigned to GPUs and hold instantiated model objects

2. feeding data into these model replicas in parallel, and retrieving inference results

|<img src="../../_static/assets/Scaling_inference/actor_inference.png" width="70%" loading="lazy">|
|:--|
|Stateful inference using Ray Actors|

Stateful inference follows the same structure as stateless inference with Ray tasks, but replaces the tasks with Ray actors. The model is placed in Ray's object store once and each actor holds it as state, so the model is not reloaded for every batch.
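The same toy setup can illustrate the stateful variant. In this plain-Python sketch, `ToyPredictionActor` is a hypothetical stand-in for a Ray actor and `load_toy_model` for an expensive model load; batches are assigned to the replicas round-robin.

```python
import numpy as np

load_count = 0  # how many times the (hypothetical) model has been loaded


def load_toy_model():
    """Hypothetical stand-in for an expensive model load (e.g. from_pretrained)."""
    global load_count
    load_count += 1
    weights = np.arange(6, dtype=np.float32).reshape(3, 2)
    return lambda x: x @ weights


class ToyPredictionActor:
    """Plain-Python stand-in for a stateful replica (a Ray actor in the real code)."""

    def __init__(self):
        self.model = load_toy_model()  # loaded once, when the replica is created

    def predict(self, batch):
        return self.model(batch)  # reuses the already-loaded model


batches = np.array_split(np.ones((10, 3), dtype=np.float32), 5)
replicas = [ToyPredictionActor() for _ in range(2)]
# Round-robin the batches over the replicas.
predictions = [replicas[i % len(replicas)].predict(b) for i, b in enumerate(batches)]
print(load_count)  # one load per replica, not per batch
```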

### Ray ActorPool - a utility on top of the previous approach

Ray provides a convenient [ActorPool utility](https://docs.ray.io/en/latest/ray-core/package-ref.html#ray-util-actorpool) that wraps the list of actors above, so you do not have to manage futures yourself.

|<img src="../../_static/assets/Scaling_inference/actor_pool.png" width="70%" loading="lazy">|
|:--|
|Using Actor Pools for Batch Inference.|

Building off of the stateful inference diagram, an Actor Pool wraps around the `n` actors so you do not have to manage idle actors and manually distribute workloads.
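The bookkeeping that an actor pool hides can be sketched in plain Python. This is not Ray's implementation: `ToyWorker` is a hypothetical stand-in for a model replica and a `ThreadPoolExecutor` stands in for Ray's workers. The generator hands each remaining item to whichever worker is idle and yields results in completion order, similar in spirit to `ActorPool.map_unordered`.

```python
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait


class ToyWorker:
    """Hypothetical stand-in for a model replica."""

    def predict(self, item):
        return item * 2  # stand-in for model inference


def map_unordered(workers, items):
    """Yield results as they finish, always giving the next item to an idle worker."""
    items = list(items)
    idle = list(workers)
    pending = {}  # future -> worker that is processing it
    with ThreadPoolExecutor(max_workers=len(workers)) as pool:
        while items or pending:
            # Submit work while there are both idle workers and remaining items.
            while idle and items:
                worker = idle.pop()
                pending[pool.submit(worker.predict, items.pop(0))] = worker
            # Block until at least one worker finishes, then mark it idle again.
            done, _ = wait(pending, return_when=FIRST_COMPLETED)
            for future in done:
                idle.append(pending.pop(future))
                yield future.result()


results = list(map_unordered([ToyWorker(), ToyWorker()], range(5)))
print(sorted(results))
```

Because results arrive in completion order, the sketch sorts them only for display.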

### Ray AIR Datasets

Ray Datasets allows for parallel reading and preprocessing of source data along with autoscaling of the ActorPool. As a part of Ray AIR, you specify what you want done through a set of declarative key-value arguments rather than concerning yourself with how to instruct Ray to scale.

|<img src="../../_static/assets/Scaling_inference/ray_datasets.png" width="70%" loading="lazy">|
|:--|
|Ray Datasets replace the 'Batch preprocessing' stage.|
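Conceptually, mapping a callable class over a dataset in batches works as follows. The single-process sketch below is a hypothetical simplification (`toy_map_batches` and `AddOffset` are illustrative names, not Ray APIs): the class is constructed once, which is where expensive state such as a model would be set up, and is then called on consecutive fixed-size batches. Ray Datasets performs these steps in parallel across a pool of actors.

```python
def toy_map_batches(items, fn_cls, batch_size, fn_constructor_args=()):
    """Hypothetical single-process imitation of mapping a callable class over batches."""
    udf = fn_cls(*fn_constructor_args)  # constructed once, like a single actor
    outputs = []
    for start in range(0, len(items), batch_size):
        outputs.extend(udf(items[start:start + batch_size]))
    return outputs


class AddOffset:
    def __init__(self, offset):
        self.offset = offset  # expensive state (e.g. a model) would be set up here

    def __call__(self, batch):
        return [x + self.offset for x in batch]


print(toy_map_batches(list(range(5)), AddOffset, batch_size=2, fn_constructor_args=(10,)))
```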

In Ray AIR, a trained model (for example, one produced by training or tuning) is stored in a `Checkpoint` object. An AIR `Predictor` loads the model from the `Checkpoint` to perform inference. Then, using the preprocessed batches provided by Ray Datasets, you generate predictions on the test data.

### Ray AIR BatchPredictor

Ray AIR's [`BatchPredictor`](https://docs.ray.io/en/latest/ray-air/package-ref.html#batch-predictor) takes in a [`Checkpoint`](https://docs.ray.io/en/latest/ray-air/package-ref.html#checkpoint) which represents the saved model. This high-level abstraction offers simple and composable APIs that enable preprocessing data in batches with [BatchMapper](https://docs.ray.io/en/latest/ray-air/package-ref.html#generic-preprocessors) and instantiating a distributed predictor from checkpoint data.

|<img src="../../_static/assets/Scaling_inference/air_batchpredictor.png" width="70%" loading="lazy">|
|:--|
|Using Ray AIR's `BatchPredictor` for Batch Inference.|

Finally, you can use an AIR `BatchPredictor` that takes both the `Checkpoint` and `Predictor` to replace the process of manually performing inference on a large dataset.
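The division of labor between these three pieces can be sketched with pandas. Every class below is a hypothetical toy that only mimics the roles described above (it is not Ray AIR's API): the checkpoint stores what is needed to rebuild the model, the predictor is restored from it, and the batch predictor runs the predictor batch by batch.

```python
import pandas as pd


class ToyCheckpoint:
    """Hypothetical stand-in for an AIR Checkpoint that simply wraps a dict."""

    def __init__(self, data):
        self._data = dict(data)

    @classmethod
    def from_dict(cls, data):
        return cls(data)

    def to_dict(self):
        return dict(self._data)


class ToyPredictor:
    """Hypothetical predictor whose 'model' is a scale factor restored from a checkpoint."""

    def __init__(self, scale):
        self.scale = scale

    @classmethod
    def from_checkpoint(cls, checkpoint):
        return cls(checkpoint.to_dict()["scale"])

    def predict_pandas(self, batch):
        # One output row per input row.
        return pd.DataFrame({"prediction": batch["x"] * self.scale})


class ToyBatchPredictor:
    """Hypothetical batch predictor: restores the predictor once, then predicts per batch."""

    def __init__(self, checkpoint, predictor_cls):
        self.checkpoint = checkpoint
        self.predictor_cls = predictor_cls

    def predict(self, df, batch_size=2):
        predictor = self.predictor_cls.from_checkpoint(self.checkpoint)
        parts = [
            predictor.predict_pandas(df.iloc[i:i + batch_size])
            for i in range(0, len(df), batch_size)
        ]
        return pd.concat(parts, ignore_index=True)


batch_predictor = ToyBatchPredictor(ToyCheckpoint.from_dict({"scale": 10}), ToyPredictor)
result = batch_predictor.predict(pd.DataFrame({"x": [1, 2, 3]}))
print(result["prediction"].tolist())
```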

## Part 2: Notes on data and model

### Data

Image segmentation takes a scene and classifies image objects [into semantic categories](https://docs.google.com/spreadsheets/d/1se8YEtb2detS7OuPE86fXGyD269pMycAWe2mtKUj2W8/edit?usp=sharing) pixel-by-pixel. [MIT ADE20K Dataset](http://sceneparsing.csail.mit.edu/) (SceneParse150) provides the largest open source dataset for scene parsing, and in this notebook, you will be scaling inference on image regions depicted in these samples.

|<img src="../../_static/assets/Scaling_inference/scene.png" width="70%" loading="lazy">|
|:--|
|Test image on the left vs. predicted result on the right. [Source](https://github.com/CSAILVision/semantic-segmentation-pytorch) *Date accessed: November 10, 2022*|

**Dataset Highlights**

- 20k annotated, scene-centric training images
- 2k validation images
- 150 total categories such as person, car, bed, sky, and more
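Classifying "pixel-by-pixel" means the model outputs one score per category for every pixel, and the predicted category is the one with the highest score. The NumPy sketch below uses made-up logits for a 2x2 image and a hypothetical three-category label map (the real dataset has 150 categories).

```python
import numpy as np

# Hypothetical label map and logits for a 2x2 "image" with 3 classes.
id2label = {0: "wall", 1: "sky", 2: "person"}
logits = np.array([
    [[2.0, 0.1], [0.3, 0.2]],  # scores for class 0 at each pixel
    [[0.5, 3.0], [0.1, 0.4]],  # scores for class 1
    [[0.1, 0.2], [1.5, 0.3]],  # scores for class 2
])  # shape: (num_classes, height, width)

# Each pixel gets the class with the highest score.
mask = logits.argmax(axis=0)  # shape: (height, width)

# Count how many pixels were assigned to each category.
ids, counts = np.unique(mask, return_counts=True)
pixel_counts = {id2label[int(i)]: int(n) for i, n in zip(ids, counts)}
print(mask)
print(pixel_counts)
```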

### Model

[SegFormer](https://arxiv.org/pdf/2105.15203.pdf) is a simple and powerful semantic segmentation method whose architecture consists of a hierarchical Transformer encoder and a lightweight All-MLP decoder. What sets SegFormer apart from previous approaches boils down to two key features:

1. a novel hierarchically structured Transformer encoder which does not depend on positional encoding, avoiding interpolation when test resolution differs from training
2. a lightweight All-MLP decoder, which avoids complex decoders

Because SegFormer has demonstrated strong results on benchmarks such as Cityscapes and the [MIT ADE20K Dataset](http://sceneparsing.csail.mit.edu/), you will use a pretrained version to perform inference on test images from the SceneParse150 dataset.

|<img src="../../_static/assets/Scaling_inference/segformer_architecture.png" width="70%" loading="lazy">|
|:--|
|Segformer architecture taken from [original paper](https://arxiv.org/pdf/2105.15203.pdf). *Date accessed: November 10, 2022*|


|<img src="../../_static/assets/Scaling_inference/single_seq_timeline.png" width="70%" loading="lazy">|
|:--|
|Timeline of batch inference on one worker.|

|<img src="../../_static/assets/Scaling_inference/seq_timeline.png" width="70%" loading="lazy">|
|:--|
|Timeline of sequential batch assignment spread across three workers.|

|<img src="../../_static/assets/Scaling_inference/distrib_timeline.png" width="70%" loading="lazy">|
|:--|
|Timeline of distributed batch inference where a scheduler assigns the next batch as soon as a worker becomes available.|
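The difference between the last two timelines can be quantified with a small simulation. The sketch below uses made-up batch durations: `static_makespan` pre-assigns contiguous chunks of batches to each worker (as in the sequential assignment), while `dynamic_makespan` gives the next batch to whichever worker becomes free first (as a scheduler does). Both return the time at which the last batch finishes.

```python
import heapq


def static_makespan(durations, num_workers):
    # Pre-assign contiguous, equally sized chunks of batches to each worker.
    chunk = -(-len(durations) // num_workers)  # ceiling division
    return max(sum(durations[i:i + chunk]) for i in range(0, len(durations), chunk))


def dynamic_makespan(durations, num_workers):
    # Hand each batch, in order, to the worker that becomes idle first.
    free_at = [0] * num_workers  # min-heap of times at which workers become idle
    heapq.heapify(free_at)
    for duration in durations:
        start = heapq.heappop(free_at)
        heapq.heappush(free_at, start + duration)
    return max(free_at)


durations = [4, 4, 1, 1, 1, 1]  # hypothetical per-batch processing times
print(static_makespan(durations, 1))   # one worker
print(static_makespan(durations, 3))   # sequential assignment to three workers
print(dynamic_makespan(durations, 3))  # scheduler-driven assignment
```

With the two slow batches at the front, static assignment gives both of them to the first worker (8 time units) while the others sit idle after 2 units; dynamic scheduling spreads them out and finishes in 4 units, compared with 12 on a single worker.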


## Part 3: Vanilla implementation

*(This implementation is inspired by the [Semantic segmentation](https://huggingface.co/docs/transformers/tasks/semantic_segmentation) guide. Date accessed: Nov 4th, 2022.)*

We start by importing the core libraries and fixing the random seed. Later in this part, we download the SceneParse150 dataset using Hugging Face's `datasets` library; its `load_dataset` utility lets us reference the dataset by name (`"scene_parse_150"`). Note that downloading can take a couple of minutes. We also pass a `split` argument to select which part of the data to load: here we load only the first 50 training examples (`"train[:50]"`) for fast experimentation and create our own train and test splits from them.

In [None]:
import torch
import numpy as np
from PIL import Image

In [None]:
torch.manual_seed(201)

### Get pre-trained model from the HuggingFace Hub

Next, we download the label mappings from the Hub (using the `huggingface_hub` library) and create two dictionaries, `id2label` and `label2id`, which map bidirectionally between label IDs (int) and label names (string). There are 150 labels, or categories, in this dataset, and we print the first 10 of them at the end of the cell. Since we're parsing scenes and want to classify segments of natural images, the labels describe entities such as walls, floors, roads, or grass.

In [None]:
# https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512
# https://huggingface.co/datasets/huggingface/label-files/blob/main/ade20k-id2label.json
import json
from huggingface_hub import hf_hub_download

MODEL_STR = "nvidia/segformer-b0-finetuned-ade-512-512"  # SegFormer-B0 fine-tuned on ADE20K at 512x512 resolution

repo_id = "huggingface/label-files"
filename = "ade20k-id2label.json"
with open(hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset"), "r") as f:
    id2label = json.load(f)

# JSON keys are strings, so convert them to integer label IDs.
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}
num_labels = len(id2label)

print(f"number of labels: {num_labels}")
print({i: id2label[i] for i in range(10)})

In [None]:
# get pretrained model
from transformers import SegformerForSemanticSegmentation

model = SegformerForSemanticSegmentation.from_pretrained(MODEL_STR, id2label=id2label, label2id=label2id)
print(f"number of model parameters: {model.num_parameters()/(10**6):.2f} M")

In [None]:
# feature extractor
# reduce_labels=True drops the background class from the loss computation: label 0 becomes the
# ignore index 255 and all other labels shift down by one.
# See: https://huggingface.co/docs/transformers/model_doc/segformer#segformer
from transformers import SegformerFeatureExtractor
feature_extractor = SegformerFeatureExtractor.from_pretrained(MODEL_STR, reduce_labels=True)
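To see what `reduce_labels=True` does to an annotation map, here is a NumPy sketch of the transformation as we understand it from the Hugging Face SegFormer preprocessing. Treat it as an illustration, not the library's code; `reduce_labels` here is a hypothetical helper.

```python
import numpy as np


def reduce_labels(label_map):
    """Hypothetical helper mirroring the reduce_labels transformation."""
    label_map = label_map.astype(np.int64)  # work on a copy with a wide integer type
    label_map[label_map == 0] = 255         # background becomes the ignore index
    label_map = label_map - 1               # real classes shift down by one (1..150 -> 0..149)
    label_map[label_map == 254] = 255       # keep the ignore index at 255 after the shift
    return label_map


annotation = np.array([[0, 1], [2, 150]], dtype=np.uint8)
print(reduce_labels(annotation))
```

The ignore index 255 is the same value passed as `ignore_index` to the mean IoU metric later in this part.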

### Prepare Dataset

#### Download dataset from the HuggingFace Hub

In [None]:
# Load dataset from Hugging Face
from datasets import load_dataset

DATASET = "scene_parse_150" # name of the dataset on the HuggingFace's datasets repository.

# Load only the first 50 training examples for fast debugging; remove the slice for a full run.
ds = load_dataset(DATASET, split="train[:50]")
ds

#### Display example images

In [None]:
from utils import display_example_images

In [None]:
display_example_images(ds)

Each Hugging Face dataset comes with a `train_test_split` method that we're going to use next. We want 80% of the data to be training data, and 20% held back for testing.

#### Prepare train and test splits

In [None]:
split_ds = ds.train_test_split(test_size=0.2, shuffle=False, seed=201)
train_ds = split_ds["train"]
test_ds = split_ds["test"]

train_ds

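To get a feel for what this dataset consists of, let's run the pretrained model on a random training image and visualize the result. Since the image is chosen with an unseeded `np.random.randint` (the split itself uses `shuffle=False` and is deterministic), it will be different every time you run the cell.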

### Run inference for few examples and visualize results

In [None]:
from utils import visualize_predictions

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use GPU if available, otherwise use a CPU

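def predict(model, image, labels=None):
    # Passing PIL images is most efficient: NumPy arrays and PyTorch tensors are converted to PIL images when resizing.
    # pixel_values has shape (batch_size, num_channels, height, width).
    inputs = feature_extractor(images=image, segmentation_maps=labels, return_tensors="pt")
    model = model.to(device).eval()
    with torch.no_grad():  # inference only: no gradients needed
        if labels is None:
            outputs = model(pixel_values=inputs.pixel_values.to(device))
        else:
            outputs = model(pixel_values=inputs.pixel_values.to(device), labels=inputs.labels.to(device))
    # SegFormer outputs logits at 1/4 of the input resolution; upsample them to the original image
    # size (PIL's image.size is (width, height), while interpolate expects (height, width)).
    upsampled_logits = torch.nn.functional.interpolate(
        outputs.logits.cpu(),
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    segmentation = upsampled_logits.argmax(dim=1)[0]
    # Without labels there is no loss to report, so only the segmentation map is returned.
    if labels is None:
        return segmentation
    return segmentation, outputs.loss.cpu()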

In [None]:
j = np.random.randint(train_ds.num_rows)

random_image = train_ds[j]["image"]
labels = train_ds[j]["annotation"]

segmentation, loss = predict(model=model, image=random_image, labels=labels)

visualize_predictions(image=random_image, predictions=segmentation, loss=loss)

In [None]:
# data augmentation
from torchvision.transforms import ColorJitter

jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)

In [None]:
# train and validate transforms
def train_transforms(example_batch):
    images = [jitter(x) for x in example_batch["image"]]
    labels = [x for x in example_batch["annotation"]]
    inputs = feature_extractor(images, labels)
    return inputs

def val_transforms(example_batch):
    images = [x for x in example_batch["image"]]
    labels = [x for x in example_batch["annotation"]]
    inputs = feature_extractor(images, labels)
    return inputs

# The transforms return only model inputs (pixel values and labels), so the next cell keeps
# untouched copies of the original datasets for visualization.

In [None]:
# run transform on data
from copy import deepcopy

original_train_ds = deepcopy(train_ds)
original_test_ds = deepcopy(test_ds)

train_ds.set_transform(train_transforms)
test_ds.set_transform(val_transforms)

In [None]:
# eval metric from Evaluate lib.
import evaluate

metric = evaluate.load("mean_iou")

In [None]:
# compute metrics

import torch
import numpy as np

def compute_metrics(eval_pred):
    with torch.no_grad():
        logits, labels = eval_pred
        logits_tensor = torch.from_numpy(logits)
        logits_tensor = torch.nn.functional.interpolate(
            logits_tensor,
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        ).argmax(dim=1)

        pred_labels = logits_tensor.detach().cpu().numpy()
        metrics = metric.compute(
            predictions=pred_labels,
            references=labels,
            num_labels=num_labels,
            ignore_index=255,
            reduce_labels=False,
        )
        for key, value in metrics.items():
            if type(value) is np.ndarray:
                metrics[key] = value.tolist()
        return metrics
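For intuition about what `mean_iou` measures, here is a simplified single-image version in NumPy. It is an illustrative sketch, not the `evaluate` implementation (which, among other things, accumulates areas over all images before dividing): for every class, intersection over union is computed on the pixels that are not marked with `ignore_index`, and classes absent from both prediction and reference are skipped.

```python
import numpy as np


def mean_iou(pred, ref, num_labels, ignore_index=255):
    """Simplified mean IoU for a single image (illustrative only)."""
    valid = ref != ignore_index  # pixels marked with ignore_index do not count
    ious = []
    for c in range(num_labels):
        p = (pred == c) & valid
        r = (ref == c) & valid
        union = np.logical_or(p, r).sum()
        if union == 0:
            continue  # class absent from both prediction and reference: skip it
        ious.append(np.logical_and(p, r).sum() / union)
    return float(np.mean(ious))


pred = np.array([[0, 0], [1, 1]])
ref = np.array([[0, 1], [1, 255]])  # the bottom-right pixel is ignored
print(mean_iou(pred, ref, num_labels=2))
```

Here each class overlaps the reference on one of the two pixels in its union, so both per-class IoUs, and therefore their mean, are 0.5.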

## Part 4: Stateless inference - Ray Tasks

In [None]:
import ray

ray.init()

The most naive way to parallelize prediction is to create Ray tasks that load the trained model internally when called. This makes the prediction task "stateless", at the cost of loading the model every single time. This is similar to what serverless solutions like AWS Lambda do, and the pattern can be worthwhile for tiny models, where model loading is not the bottleneck.
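A back-of-the-envelope estimate shows when reloading matters. The function below is a rough sketch that assumes batches split evenly across workers and ignores scheduling and data-transfer overheads; the timings are hypothetical.

```python
import math


def estimate_total_seconds(num_batches, load_s, infer_s, num_workers, stateful):
    """Rough wall-clock estimate, assuming an even split and no other overheads."""
    batches_per_worker = math.ceil(num_batches / num_workers)
    if stateful:
        # Each actor loads the model once, then only runs inference.
        return load_s + batches_per_worker * infer_s
    # Each stateless task reloads the model before every batch.
    return batches_per_worker * (load_s + infer_s)


# Hypothetical numbers: 100 batches, 2 s to load the model, 0.1 s per batch, 4 workers.
stateless = estimate_total_seconds(100, 2.0, 0.1, 4, stateful=False)
stateful = estimate_total_seconds(100, 2.0, 0.1, 4, stateful=True)
print(f"stateless: {stateless:.1f} s, stateful: {stateful:.1f} s")
```

With these numbers the stateless version spends most of its time reloading the model (52.5 s vs. 4.5 s). As the load time shrinks relative to the inference time, the gap closes, which is why the stateless pattern can be acceptable for tiny models.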

In [None]:
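def load_trained_model(model_name=MODEL_STR):
    # Load the pretrained SegFormer model from the Hugging Face Hub.
    return SegformerForSemanticSegmentation.from_pretrained(model_name, id2label=id2label, label2id=label2id)

@ray.remote
def prediction_task(image):
    # Stateless: the model is loaded inside the task on every call.
    model = load_trained_model()
    return predict(model, image)

prediction_refs = [prediction_task.remote(original_test_ds[i]["image"]) for i in range(10)]
predictions = ray.get(prediction_refs)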

In [None]:
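import matplotlib.pyplot as plt

# Show the first test image next to its predicted segmentation mask.
fig, axes = plt.subplots(1, 2, figsize=(15, 10))
axes[0].imshow(original_test_ds[0]["image"])
axes[1].imshow(predictions[0].numpy())
plt.show()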

## Part 5: Stateful inference - Ray Actors

In [None]:
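@ray.remote
class PredictionActor:
    def __init__(self, model):
        # Ray resolves the ObjectRef passed at construction, so each actor holds the model as state.
        self.model = model

    def predict(self, image):
        # Delegates to the module-level predict() function defined earlier.
        return predict(self.model, image)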

In [None]:
model = load_trained_model()
model_ref = ray.put(model)
actors = [PredictionActor.remote(model_ref) for _ in range(5)]

In [None]:
def process_result(prediction):
    print(f"Got prediction of shape: {prediction.shape}")

data = [original_test_ds[i]["image"] for i in range(10)]

# Copy the list so that popping idle actors below leaves `actors` intact for later parts.
idle_actors = actors.copy()
future_to_actor = {}

# Keep at most one pending request per actor: submit work to idle actors and, when all actors
# are busy, wait for one result before submitting more.
# See: https://docs.ray.io/en/master/ray-core/patterns/limit-pending-tasks.html
while data:
    if idle_actors:
        actor = idle_actors.pop()
        future = actor.predict.remote(data.pop())
        future_to_actor[future] = actor
    else:
        # https://docs.ray.io/en/latest/ray-core/package-ref.html#ray-wait
        [ready], _ = ray.wait(list(future_to_actor.keys()), num_returns=1) # num_returns=1 makes sure that you return one item in the ready list
        actor = future_to_actor.pop(ready)
        idle_actors.append(actor)
        process_result(ray.get(ready))

# Process any leftover results at the end.
for future in future_to_actor.keys():
    process_result(ray.get(future))

## Part 6: Stateful inference - Ray ActorPool

In [None]:
from ray.util.actor_pool import ActorPool
actor_pool = ActorPool(actors)

In [None]:
data = [original_test_ds[i]["image"] for i in range(10)]


def actor_call(actor, data_item):
    return actor.predict.remote(data_item)

for result in actor_pool.map_unordered(actor_call, data):
    process_result(result)

## Part 7: Ray AIR Datasets

In [None]:
import numpy as np


class PredictionClass:
    def __init__(self, model):
        self.model = model

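    def __call__(self, batch):
        # Assumes each batch is a list of PIL images (as produced by ray.data.from_items on a list of images).
        # map_batches expects a supported batch type (e.g. a list, NumPy array, or pandas DataFrame),
        # not a torch tensor, so each predicted mask is converted to a NumPy array.
        return [predict(self.model, image).numpy() for image in batch]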


In [None]:
dataset = ray.data.from_items(data)
dataset.show()

# Each row holds one raw PIL image; resizing and normalization still happen inside predict().

In [None]:
results = dataset.map_batches(
    PredictionClass,
    batch_size=1,
    num_gpus=0,
    compute=ray.data.ActorPoolStrategy(min_size=1, max_size=5),
    fn_constructor_args=(model_ref,)
)

results.show(1)

## Part 8: Ray AIR BatchPredictor

In [None]:
from ray.air import Checkpoint
from ray.train.predictor import Predictor
from ray.train.batch_predictor import BatchPredictor
import pandas as pd


# Adapted from https://docs.ray.io/en/latest/ray-air/predictors.html#batch-prediction
# for batch prediction on images.
class CustomPredictor(Predictor):
    def __init__(self, model_name):
        super().__init__()
        self.model = load_trained_model(model_name)

    def _predict_pandas(self, batch):
        # BatchPredictor passes each batch as a pandas DataFrame; we assume the images are in its
        # first column (the column name for a dataset of plain items depends on the Ray version).
        images = batch.iloc[:, 0]
        masks = [predict(self.model, image).numpy() for image in images]
        # Return one row per input image, holding its predicted segmentation mask.
        return pd.DataFrame({"predictions": masks})

    @classmethod
    def from_checkpoint(cls, checkpoint, **kwargs):
        # The checkpoint only stores the Hugging Face model name; the weights are loaded in __init__.
        return cls(checkpoint.to_dict()["model"])


predictor = BatchPredictor(
    checkpoint=Checkpoint.from_dict({"model": MODEL_STR}),
    predictor_cls=CustomPredictor,
    preprocessor=None,
)

results = predictor.predict(dataset)