# Scaling Inference

<img src="../../_static/assets/Generic/ray_logo.png" width="20%" loading="lazy">

## About this notebook

### Is it right for you?

This is an introductory notebook that gives a broad overview of the Ray project. It is right for you if:

* you work on model inference problems or want to scale your existing inference pipelines
* <ToDo\>

### Prerequisites

For this notebook you should have:

* practical Python and machine learning experience
* familiarity with Ray and Ray AIR. Equivalent to completing these training modules:
  * [Overview of Ray](https://github.com/ray-project/ray-educational-materials/blob/main/Introductory_modules/Overview_of_Ray.ipynb)
  * [Introduction to Ray AIR](https://github.com/ray-project/ray-educational-materials/blob/main/Introductory_modules/Introduction_to_Ray_AIR.ipynb)
  * [Ray Core](https://github.com/ray-project/ray-educational-materials/tree/main/Ray_Core)
* familiarity with the batch inference problem in ML

### Learning objectives

Upon completion of this notebook, you will know about:

* Inference patterns
* Architectures for scaling inference with Ray

### What will you do?

<ToDo/>

## Part 1: Patterns and architectures for scalable inference

The end goal for machine learning models is to generate performant predictions on unseen data. In this module, you will parallelize inference using Ray Core's API as well as the high-level abstractions available in Ray AI Runtime (AIR).

|<img src="../../_static/assets/Scaling_inference/ml_pipeline.png" width="70%" loading="lazy">|
|:--|
|Typical end-to-end machine learning pipeline.|

### Patterns for resilient serving

**Stateless serving function with Ray tasks**

In production environments, you typically prioritize low latency. However, loading complex models into memory can be expensive, and processing requests sequentially limits throughput. *Stateless serving* allows an ML system to handle high-volume requests by:

1. exporting the model's mathematical core into a language-agnostic format
2. restoring the architecture and weights of a trained model in a stateless function (e.g., a Ray task)
3. deploying these Ray tasks into a framework that provides a REST endpoint

This way, the end client doesn't need to know anything about machine learning to be able to autoscale the web endpoints or manage the web application.
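The stateless-serving steps above can be sketched in plain Python. This is a toy illustration, not Ray code: JSON stands in for a language-agnostic export format such as ONNX, `export_model` and `serve_request` are hypothetical names, and the "model" is a two-weight linear function. The point is that `serve_request` keeps no state between calls: it restores the model from the exported artifact every time, so any number of copies can run side by side.

```python
import json

# Step 1 (hypothetical): export the model's "mathematical core" (here, the weights of
# a linear model) into a language-agnostic format. JSON stands in for e.g. ONNX.
def export_model(weights, bias):
    return json.dumps({"weights": weights, "bias": bias})

# Step 2: a stateless serving function. It restores the model from the exported
# artifact on every call and keeps nothing between calls, so its output depends
# only on its inputs.
def serve_request(exported_model, features):
    model = json.loads(exported_model)
    score = sum(w * x for w, x in zip(model["weights"], features)) + model["bias"]
    return {"prediction": score}

exported = export_model(weights=[0.5, -1.0], bias=2.0)
# Step 3 would put serve_request behind a REST endpoint; here we simply call it.
print(serve_request(exported, [4.0, 1.0]))  # {'prediction': 3.0}
```

Because each call is independent, scaling out is a matter of running more copies of the function; the price is the restore step on every request.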

**Batch serving with Ray actors**

When your deployed model takes too long to return immediate results, online prediction may not be the right approach. Some situations also require predictions over large volumes of data, such as curating personalized playlists. In these cases you can use *batch inference*: an asynchronous approach that groups observations into batches and generates their predictions in advance, processing a high volume of samples efficiently.

Setting up distributed batch inference with Ray involves:

1. creating a number of replicas of your model; in Ray, these replicas are represented as Actors (i.e., stateful processes) that can be assigned to GPUs and hold instantiated model objects

2. feeding data into these model replicas in parallel, and retrieving the inference results
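A plain-Python sketch of these two steps, assuming a hypothetical `ModelReplica` class in place of a Ray actor and a trivial "model" that doubles its input. Each replica loads its model once in `__init__` (its state), and the batches are spread round-robin across the replicas. Unlike with Ray, everything here runs sequentially in one process; the sketch only shows how state and data are divided.

```python
from itertools import cycle

class ModelReplica:
    """Hypothetical stand-in for a Ray actor that holds an instantiated model."""
    loads = 0  # counts how many times any replica loaded the model

    def __init__(self):
        ModelReplica.loads += 1
        self.model = lambda x: 2 * x  # toy "model", loaded once per replica

    def predict_batch(self, batch):
        return [self.model(x) for x in batch]

def make_batches(items, batch_size):
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

# Step 1: create a fixed number of replicas.
replicas = [ModelReplica() for _ in range(3)]

# Step 2: feed the batches to the replicas (round-robin) and collect the results.
batches = make_batches(list(range(10)), batch_size=4)
results = [replica.predict_batch(batch) for replica, batch in zip(cycle(replicas), batches)]

print(results)             # [[0, 2, 4, 6], [8, 10, 12, 14], [16, 18]]
print(ModelReplica.loads)  # 3: one load per replica, however many batches there are
```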

### Ray architectures for scalable inference

#### Stateless inference - Ray Task

A Ray task is *stateless*: its output (e.g. predictions) depends only on its inputs (e.g. an image), and it keeps nothing in memory between calls. Online inference with tasks therefore means loading the model for every request and serving results synchronously.

#### Stateful inference - Ray Actors

Stateless inference may work well for smaller models, but for larger models you want to avoid reloading the model on every call. Instead, you can place the trained model in the Ray object store once and launch a number of Actors that each hold it. You then dispatch work to these actors and process their results in a streaming way, as each one becomes available.

|<img src="../../_static/assets/Scaling_inference/actors.png" width="70%" loading="lazy">|
|:--|
|Using Ray Actors for Batch Inference.|
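The "dispatch to idle actors, process results as they arrive" loop that Part 5 implements with Ray can be mirrored with Python's standard `concurrent.futures` module. This sketch is an analogy, not Ray code: each single-thread executor plays the role of one actor (it runs one call at a time), and `wait(..., return_when=FIRST_COMPLETED)` plays the role of `ray.wait(..., num_returns=1)`. The `slow_square` function is a hypothetical stand-in for model inference.

```python
import time
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait

def slow_square(x):
    """Hypothetical stand-in for a model's predict call."""
    time.sleep(0.01 * (x % 3))  # uneven latencies, so results finish out of order
    return x * x

# Each single-thread executor plays the role of one actor: one request at a time.
workers = [ThreadPoolExecutor(max_workers=1) for _ in range(3)]
idle_workers = list(workers)
future_to_worker = {}
data = list(range(10))
results = []

while data:
    if idle_workers:
        # Dispatch the next item to an idle worker.
        worker = idle_workers.pop()
        future = worker.submit(slow_square, data.pop())
        future_to_worker[future] = worker
    else:
        # Block until at least one in-flight call finishes, like ray.wait(..., num_returns=1).
        done, _ = wait(future_to_worker, return_when=FIRST_COMPLETED)
        for future in done:
            idle_workers.append(future_to_worker.pop(future))
            results.append(future.result())

# Drain the calls that are still in flight.
for future in future_to_worker:
    results.append(future.result())

for worker in workers:
    worker.shutdown()

print(sorted(results))  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
```

Results arrive in completion order rather than submission order, which is why the sketch sorts them before printing.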

#### Ray ActorPool - utility building on the previous approach

Ray provides a convenient [ActorPool utility](https://docs.ray.io/en/latest/ray-core/package-ref.html#ray-util-actorpool) that wraps a list of actors like the one above, so you don't have to manage the futures yourself.

|<img src="../../_static/assets/Scaling_inference/actor_pool.png" width="70%" loading="lazy">|
|:--|
|Using Actor Pools for Batch Inference.|

#### Ray AIR Datasets

Ray Datasets allows for parallel reading and preprocessing of source data along with autoscaling of the ActorPool. As a part of Ray AIR, you specify what you want done through a set of declarative key-value arguments rather than concerning yourself with how to instruct Ray to scale.

#### Ray AIR BatchPredictor

Ray AIR's [`BatchPredictor`](https://docs.ray.io/en/latest/ray-air/package-ref.html#batch-predictor) takes in a [`Checkpoint`](https://docs.ray.io/en/latest/ray-air/package-ref.html#checkpoint), which represents the saved model. This high-level abstraction offers simple, composable APIs for preprocessing data in batches with [BatchMapper](https://docs.ray.io/en/latest/ray-air/package-ref.html#generic-preprocessors) and for instantiating a distributed predictor from checkpoint data.

|<img src="../../_static/assets/Scaling_inference/air_batchpredictor.png" width="70%" loading="lazy">|
|:--|
|Using Ray AIR for Batch Inference.|

## Part 2: Notes on data and model

### Data

Image segmentation takes a scene and classifies its objects [into semantic categories](https://docs.google.com/spreadsheets/d/1se8YEtb2detS7OuPE86fXGyD269pMycAWe2mtKUj2W8/edit?usp=sharing) pixel by pixel. The [MIT ADE20K Dataset](http://sceneparsing.csail.mit.edu/) is the largest open-source dataset for scene parsing, and in this notebook you will scale inference over the image regions depicted in its samples.

|<img src="../../_static/assets/Scaling_inference/scene.png" width="70%" loading="lazy">|
|:--|
|[From left to right: test image, ground truth, predicted result.](https://github.com/CSAILVision/semantic-segmentation-pytorch) *Date accessed: November 10, 2022*|
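Concretely, a segmentation model outputs one score per category for every pixel, and the predicted label map is the per-pixel argmax over the category axis. `upsample_logits` in Part 3 does this with `argmax(dim=1)` on a `(batch, categories, height, width)` tensor. A small NumPy sketch with made-up scores for a 2×2 image and 3 hypothetical categories (no batch axis, so the category axis is 0):

```python
import numpy as np

# Made-up logits with shape (num_categories, height, width) = (3, 2, 2).
logits = np.array([
    [[2.0, 0.1], [0.3, 0.0]],  # scores for category 0
    [[0.5, 1.5], [0.2, 0.1]],  # scores for category 1
    [[0.1, 0.2], [0.9, 3.0]],  # scores for category 2
])

# Pick the highest-scoring category independently for each pixel.
label_map = logits.argmax(axis=0)
print(label_map)
# [[0 1]
#  [2 2]]
```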

[**Dataset Highlights**](https://arxiv.org/pdf/1608.05442.pdf)

- 20k annotated, scene-centric training images
- 2k validation images
- 150 total categories such as person, car, bed, sky, and more

### Model

[SegFormer Paper](https://arxiv.org/pdf/2105.15203.pdf)

SegFormer consists of a hierarchical Transformer encoder and a lightweight all-MLP decode head, and achieves strong results on semantic segmentation benchmarks such as ADE20K and Cityscapes. The hierarchical Transformer is first pre-trained on ImageNet-1k; a decode head is then added, and the whole model is fine-tuned on a downstream dataset.
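According to the SegFormer paper, the four encoder stages produce feature maps at 1/4, 1/8, 1/16 and 1/32 of the input resolution, and the MLP decode head predicts a mask at 1/4 of the input resolution. That is why the inference code in Part 3 interpolates the logits back to the original image size before taking the argmax. The arithmetic below uses a 512×512 input as an example (not necessarily the size the feature extractor produces) and assumes height and width are divisible by 32; `segformer_shapes` is a hypothetical helper.

```python
# Feature-map strides of SegFormer's four hierarchical encoder stages (from the paper).
STAGE_STRIDES = (4, 8, 16, 32)
DECODER_STRIDE = 4  # the all-MLP head predicts at 1/4 of the input resolution

def segformer_shapes(height, width, num_labels=150):
    """Spatial sizes of the encoder stages and the (labels, H, W) shape of the logits."""
    stages = [(height // s, width // s) for s in STAGE_STRIDES]
    logits = (num_labels, height // DECODER_STRIDE, width // DECODER_STRIDE)
    return stages, logits

stages, logits = segformer_shapes(512, 512)
print(stages)  # [(128, 128), (64, 64), (32, 32), (16, 16)]
print(logits)  # (150, 128, 128): must be upsampled to 512x512 before the argmax
```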

The abstract from the paper is the following:

"We present SegFormer, a simple, efficient yet powerful semantic segmentation framework which unifies Transformers with lightweight multilayer perception (MLP) decoders. SegFormer has two appealing features: 1) SegFormer comprises a novel hierarchically structured Transformer encoder which outputs multiscale features. It does not need positional encoding, thereby avoiding the interpolation of positional codes which leads to decreased performance when the testing resolution differs from training. 2) SegFormer avoids complex decoders. The proposed MLP decoder aggregates information from different layers, and thus combining both local attention and global attention to render powerful representations. We show that this simple and lightweight design is the key to efficient segmentation on Transformers. We scale our approach up to obtain a series of models from SegFormer-B0 to SegFormer-B5, reaching significantly better performance and efficiency than previous counterparts. For example, SegFormer-B4 achieves 50.3% mIoU on ADE20K with 64M parameters, being 5x smaller and 2.2% better than the previous best method. Our best model, SegFormer-B5, achieves 84.0% mIoU on Cityscapes validation set and shows excellent zero-shot robustness on Cityscapes-C."

|<img src="../../_static/assets/Scaling_inference/segformer_architecture.png" width="70%" loading="lazy">|
|:--|
|SegFormer architecture taken from [original paper](https://arxiv.org/pdf/2105.15203.pdf). *Date accessed: November 10, 2022*|


## Part 3: Vanilla implementation

*(This implementation is inspired by the [Semantic segmentation](https://huggingface.co/docs/transformers/tasks/semantic_segmentation) guide. Date accessed: Nov 4th, 2022.)*

We start by downloading the SceneParse150 dataset using Hugging Face's `datasets` library. Specifically, we use the `load_dataset` utility, which references the dataset by name (`"scene_parse_150"`). Note that this can take a couple of minutes. We also pass `split="train[:50]"`, which loads only the first 50 examples of the `train` split to keep the notebook fast:

In [None]:
# Load dataset from Hugging Face
from datasets import load_dataset

ds = load_dataset("scene_parse_150", split="train[:50]")

Each Hugging Face dataset comes with a `train_test_split` method that we're going to use next. We want 80% of the data to be training data, and 20% held back for testing.
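With the 50 examples loaded above, `test_size=0.2` should leave 40 training and 10 test examples. The plain-Python sketch below (a hypothetical helper, not the `datasets` implementation, whose rounding and shuffling logic may differ) shows what a shuffled 80/20 split does. Passing a seed makes the split reproducible; `Dataset.train_test_split` also accepts a `seed` argument for the same purpose.

```python
import random

def train_test_split(items, test_size=0.2, seed=None):
    """Hypothetical helper: shuffle a copy of `items`, then hold back `test_size` of them."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    n_test = round(len(shuffled) * test_size)
    return {"train": shuffled[n_test:], "test": shuffled[:n_test]}

splits = train_test_split(range(50), test_size=0.2, seed=0)
print(len(splits["train"]), len(splits["test"]))  # 40 10
```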

In [None]:
# make tr/test splits
split_ds = ds.train_test_split(test_size=0.2)
train_ds = split_ds["train"]
test_ds = split_ds["test"]

To get a feel for what this dataset consists of, let's display the first training image. Since the train-test split is randomized, you will likely see a different image every time you run the split.

In [None]:
train_ds[0]["image"]

Next, we download the mappings from the Hub (using the `huggingface_hub` library) and create two dictionaries, `id2label` and `label2id`. We use these dictionaries to map bidirectionally between label IDs (int) and the labels (string) of the respective images. There are a total of 150 labels, or categories, in this dataset, and we print 10 of them at the end. Since we're parsing scenes and want to classify segments of natural images, you can see that the labels describe entities such as walls, floors, roads or grass.
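JSON object keys are always strings, which is why the next cell converts the keys of `id2label` back to `int` before inverting the dictionary. The same steps on a small made-up excerpt (the three labels are chosen for illustration and are not meant to reproduce the real label file):

```python
import json

# A made-up excerpt in the same shape as the label file: {"<id>": "<label>"}.
raw = json.loads('{"0": "wall", "1": "building", "2": "sky"}')

id2label = {int(k): v for k, v in raw.items()}  # JSON keys arrive as strings
label2id = {v: k for k, v in id2label.items()}  # invert for label -> id lookups

print(id2label[2], label2id["wall"])  # sky 0
```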

In [None]:
import json
from huggingface_hub import hf_hub_download

repo_id = "huggingface/label-files"
filename = "ade20k-id2label.json"
label_file = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
with open(label_file, "r") as f:  # close the file once the JSON is parsed
    id2label = json.load(f)

id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}
num_labels = len(id2label)
print(list(id2label.values())[:10])

In [None]:
## Preprocessing and augmentations
from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained("nvidia/mit-b0", reduce_labels=True)

In [None]:
## Augmentations
from torchvision.transforms import ColorJitter

jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)

In [None]:
# train and validate transforms
def train_transforms(example_batch):
    images = [jitter(x) for x in example_batch["image"]]
    labels = [x for x in example_batch["annotation"]]
    inputs = feature_extractor(images, labels)
    return inputs

def val_transforms(example_batch):
    images = [x for x in example_batch["image"]]
    labels = [x for x in example_batch["annotation"]]
    inputs = feature_extractor(images, labels)
    return inputs

# Inputs are pixel values - but we still need original images
# ToDo: check if transform can return images, labels, inputs

In [None]:
# run transform on data
from copy import deepcopy

original_train_ds = deepcopy(train_ds)
original_test_ds = deepcopy(test_ds)

train_ds.set_transform(train_transforms)
test_ds.set_transform(val_transforms)

In [None]:
# get pretrained model
from transformers import AutoModelForSemanticSegmentation

pretrained_model_name = "nvidia/mit-b0"
model = AutoModelForSemanticSegmentation.from_pretrained(
    pretrained_model_name, id2label=id2label, label2id=label2id
)

In [None]:
# instantiate TrainingArguments
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="segformer-b0-scene-parse-150",
    learning_rate=6e-5,
    num_train_epochs=5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    save_total_limit=3,
    evaluation_strategy="steps",
    save_strategy="steps",
    save_steps=20,
    eval_steps=20,
    logging_steps=1,
    eval_accumulation_steps=5,
    remove_unused_columns=False,
    push_to_hub=False,
)

In [None]:
# eval metric from Evaluate lib.
import evaluate

metric = evaluate.load("mean_iou")

In [None]:
# compute metrics

import torch
import numpy as np

def compute_metrics(eval_pred):
    with torch.no_grad():
        logits, labels = eval_pred
        logits_tensor = torch.from_numpy(logits)
        logits_tensor = torch.nn.functional.interpolate(
            logits_tensor,
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        ).argmax(dim=1)

        pred_labels = logits_tensor.detach().cpu().numpy()
        metrics = metric.compute(
            predictions=pred_labels,
            references=labels,
            num_labels=num_labels,
            ignore_index=255,
            reduce_labels=False,
        )
        for key, value in metrics.items():
            if type(value) is np.ndarray:
                metrics[key] = value.tolist()
        return metrics
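For intuition about what `mean_iou` reports: for each category, IoU is the number of pixels where prediction and reference both have that category, divided by the number of pixels where either of them has it; mean IoU averages these values over categories. The NumPy sketch below is a simplified, per-image version and not the `evaluate` implementation (which accumulates areas over the whole dataset). Like the call above, it ignores pixels whose reference equals `ignore_index`, and it skips categories absent from both maps.

```python
import numpy as np

def mean_iou(pred, ref, num_labels, ignore_index=255):
    """Simplified mean IoU for one image; categories absent from both maps are skipped."""
    valid = ref != ignore_index  # drop pixels marked as "ignore" in the reference
    pred, ref = pred[valid], ref[valid]
    ious = []
    for label in range(num_labels):
        intersection = np.sum((pred == label) & (ref == label))
        union = np.sum((pred == label) | (ref == label))
        if union > 0:
            ious.append(intersection / union)
    return float(np.mean(ious))

pred = np.array([[0, 0], [1, 1]])
ref = np.array([[0, 1], [1, 255]])  # the bottom-right pixel is ignored
print(mean_iou(pred, ref, num_labels=2))  # 0.5
```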

In [None]:
# create Trainer
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    compute_metrics=compute_metrics,
)

In [None]:
# run model training (TODO: optionally load a pre-trained model here instead, e.g. to save time).
# ToDo: shall we put model on S3 or HF?

run_training = False

if run_training:
    trainer.train()

In [None]:
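if run_training:
    # Trainer.save_model writes the weights to disk but returns None, so keep the path instead
    trainer.save_model(training_args.output_dir)
    MODEL = training_args.output_dir
else:
    # "nvidia/mit-b0" contains only the pre-trained encoder: the decode head is randomly
    # initialized, so its predictions are not meaningful until the model is fine-tuned
    MODEL = pretrained_model_name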

def load_trained_model(model_path=MODEL):
    return AutoModelForSemanticSegmentation.from_pretrained(
        model_path, id2label=id2label, label2id=label2id
    )

model = load_trained_model()

In [None]:
# inference
image = original_test_ds[0]["image"]
image

In [None]:
# todo - pixel values
pixels = test_ds[0]["pixel_values"]
pixels

In [None]:
# this and following cells are re-usable functions.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use GPU if available, otherwise use a CPU

def encode_pixels(image):
    encoding = feature_extractor(image, return_tensors="pt")
    pixel_values = encoding.pixel_values.to(device)
    return pixel_values


In [None]:
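def compute_logits(model, pixel_values):
    # inference only: skip gradient tracking and run on whichever device holds the model
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values.to(model.device))
    return outputs.logits.cpu()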


In [None]:
# rescale the logits to the original image size and take the most likely label per pixel
def upsample_logits(logits, size):
    upsampled_logits = torch.nn.functional.interpolate(
        logits,
        size=size,
        mode="bilinear",
        align_corners=False,
    )

    return upsampled_logits.argmax(dim=1)[0]

def predict(model, image):
    pixel_values = encode_pixels(image)
    logits = compute_logits(model, pixel_values)
    # use this image's own size: PIL gives (width, height), interpolate expects (height, width)
    return upsample_logits(logits, size=image.size[::-1])

prediction = predict(model, image)
prediction

In [None]:
# visualize pred
# palette was copied from here: https://github.com/tensorflow/models/blob/3f1ca33afe3c1631b733ea7e40c294273b9e406d/research/deeplab/utils/get_dataset_colormap.py#L51
# (accessed: nov 9, 2022)

from palette import create_ade20k_label_colormap as ade_palette

# ToDo: figure out how to visualize grid of images -> results preview for learners
# TODO: this only seems to work for the 0th image
# ToDo: make sure that both pred and visualization works here
def prepare_for_visualisation(image, prediction):
    # paint each pixel with the palette color of its predicted label
    color_seg = np.zeros((prediction.shape[0], prediction.shape[1], 3), dtype=np.uint8)
    palette = np.array(ade_palette())
    for label, color in enumerate(palette):
        color_seg[prediction == label, :] = color
    # keep RGB channel order: both the PIL image and matplotlib's imshow use RGB

    # convert to RGB first, since some ADE20K images are grayscale
    img = np.array(image.convert("RGB")) * 0.5 + color_seg * 0.5  # blend the image with the segmentation map
    img = img.astype(np.uint8)
    return img

In [None]:
import matplotlib.pyplot as plt

# `image` is a PIL image (decoded from JPEG)
img = prepare_for_visualisation(image, prediction)

plt.figure(figsize=(15, 10))
plt.imshow(img)
plt.show()

## Part 4: Stateless inference - Ray Tasks

In [None]:
import ray

ray.init()

The most naive way to parallelize prediction is to create Ray tasks that load the trained model internally when called. This makes the prediction task "stateless", but at the cost of loading the model every single time. This is akin to what serverless solutions like AWS Lambda do, and the pattern can be worthwhile for tiny models, where the application isn't bottlenecked by the model loading step.
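A back-of-the-envelope cost model makes this trade-off concrete. The timings below are made-up placeholders, not measurements of SegFormer, and the functions are hypothetical helpers that add up total compute time across all workers (not wall-clock time): with stateless tasks every request pays the loading cost, while with a pool of stateful actors each actor pays it once.

```python
def stateless_total_seconds(num_requests, load_s, infer_s):
    # Every task loads the model before running inference.
    return num_requests * (load_s + infer_s)

def stateful_total_seconds(num_requests, num_actors, load_s, infer_s):
    # Each actor loads the model once; requests then only pay for inference.
    return num_actors * load_s + num_requests * infer_s

# Placeholder timings: 2 s to load the model, 0.1 s per prediction, 1000 requests.
print(stateless_total_seconds(1000, load_s=2.0, infer_s=0.1))               # 2100.0
print(stateful_total_seconds(1000, num_actors=5, load_s=2.0, infer_s=0.1))  # 110.0
```

When `load_s` is tiny compared to `infer_s`, the two totals converge, which is the "tiny model" case where stateless tasks are acceptable.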

In [None]:
@ray.remote
def prediction_task(image):
    model = load_trained_model()
    return predict(model, image)

prediction_refs = [prediction_task.remote(original_test_ds[i]["image"]) for i in range(10)]
predictions = ray.get(prediction_refs)

In [None]:
img = prepare_for_visualisation(original_test_ds[0]["image"], predictions[0])

plt.figure(figsize=(15, 10))
plt.imshow(img)
plt.show()

## Part 5: Stateful inference - Ray Actors

In [None]:
import ray

@ray.remote
class PredictionActor:
    def __init__(self, model):
        self.model = model

    def predict(self, image):
        return predict(self.model, image)

In [None]:
model = load_trained_model()
model_ref = ray.put(model)
actors = [PredictionActor.remote(model_ref) for _ in range(5)]

In [None]:
def process_result(prediction):
    print(f"Got prediction of shape: {prediction.shape}")

data = [original_test_ds[i]["image"] for i in range(10)]

# copy the list: the loop below pops actors from it, and `actors` is reused in Part 6
idle_actors = actors.copy()
future_to_actor = {}

# keep every actor busy: hand the next item to an idle actor, otherwise wait for the first
# in-flight prediction to finish (see https://docs.ray.io/en/master/ray-core/patterns/limit-pending-tasks.html)
while data:
    if idle_actors:
        actor = idle_actors.pop()
        future = actor.predict.remote(data.pop())
        future_to_actor[future] = actor
    else:
        # https://docs.ray.io/en/latest/ray-core/package-ref.html#ray-wait
        [ready], _ = ray.wait(list(future_to_actor.keys()), num_returns=1) # num_returns=1 makes sure that you return one item in the ready list
        actor = future_to_actor.pop(ready)
        idle_actors.append(actor)
        process_result(ray.get(ready))

# Process any leftover results at the end.
for future in future_to_actor.keys():
    process_result(ray.get(future))

## Part 6: Stateful inference - Ray ActorPool utility

In [None]:
from ray.util.actor_pool import ActorPool
actor_pool = ActorPool(actors)

In [None]:
data = [original_test_ds[i]["image"] for i in range(10)]


def actor_call(actor, data_item):
    return actor.predict.remote(data_item)

for result in actor_pool.map_unordered(actor_call, data):
    process_result(result)

## Part 7: Ray AIR Datasets

In [None]:
import numpy as np


class PredictionClass:
    def __init__(self, model):
        self.model = model

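    def __call__(self, batch):
        # assumes `batch` is a list of PIL images; the exact batch format depends on the Ray version.
        # map_batches must be able to turn the return value into dataset blocks, so return
        # NumPy arrays (one label map per image) instead of torch tensors.
        return [predict(self.model, image).cpu().numpy() for image in batch]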


In [None]:
dataset = ray.data.from_items(data)
dataset.show()

# ToDo: pre-processing -> dataset of batches of images, not individual images. Probably it has to be batches of pandas.

In [None]:
results = dataset.map_batches(
    PredictionClass,
    batch_size=1,
    num_gpus=0,
    compute=ray.data.ActorPoolStrategy(min_size=1, max_size=5),
    fn_constructor_args=(model_ref,)
)

results.show(1)

## Part 8: Ray AIR BatchPredictor

In [None]:
from ray.air import Checkpoint
from ray.train.predictor import Predictor
from ray.train.batch_predictor import BatchPredictor
import pandas as pd

# Temporary hack: `_predict_pandas` below ignores its input batch and always predicts on this one image.
image = original_test_ds[0]["image"]


# https://docs.ray.io/en/latest/ray-air/predictors.html#batch-prediction
# adapt it to batch prediction on images
class CustomPredictor(Predictor):
    def __init__(self, model_name):
        super().__init__()
        self.model = load_trained_model(model_name)

    def _predict_pandas(self, batch):
        # TODO: figure out how to make this run on pandas properly.
        # ... why are we forced to use pandas, though?
        prediction = predict(self.model, image)

        # ToDo: can we work with numpy here?
        # implement post processing to cast prediction to pandas

        return pd.DataFrame(prediction.numpy())

    @classmethod
    def from_checkpoint(cls, checkpoint, **kwargs):
        return cls(checkpoint.to_dict()["model"])

predictor = BatchPredictor(
    checkpoint=Checkpoint.from_dict({"model": MODEL}),
    predictor_cls=CustomPredictor,
    preprocessor=None,
)

results = predictor.predict(dataset)