# Scaling Inference

<img src="../../_static/assets/Generic/ray_logo.png" width="20%" loading="lazy">

## About this notebook

### Is it right for you?

This notebook introduces patterns and architectures for scaling model inference with Ray. It is right for you if:

* you work on model inference problems or want to scale your existing inference pipelines
* <ToDo\>

### Prerequisites

For this notebook you should have:

* practical Python and machine learning experience
* familiarity with Ray and Ray AIR, equivalent to completing these training modules:
  * [Overview of Ray](https://github.com/ray-project/ray-educational-materials/blob/main/Introductory_modules/Overview_of_Ray.ipynb)
  * [Introduction to Ray AIR](https://github.com/ray-project/ray-educational-materials/blob/main/Introductory_modules/Introduction_to_Ray_AIR.ipynb)
  * [Ray Core](https://github.com/ray-project/ray-educational-materials/tree/main/Ray_Core)
* familiarity with batch inference problems in ML

### Learning objectives

Upon completion of this notebook, you will know about:

* Inference patterns
* Architectures for scaling inference with Ray

### What will you do?

<ToDo/>

## Part 1: Patterns and architectures for scalable inference

<ToDo, part 1>

ML model lifecycle - diagram with focus

### Patterns for resilient serving

Book chapters
* Stateless Serving Function - implement with Ray Tasks
* Batch Serving - implement with Actors

### Ray architectures for scalable inference

* Stateless inference - Ray Task
* Stateful inference - Ray Actors
* Ray ActorPool - a utility library that extends the previous approach to a pool of actors
* Ray AIR Datasets
* Ray AIR BatchPredictor

## Part 2: Notes on data and model

### Data

### Model

## Part 3: Vanilla implementation

*(This implementation is inspired by the [Semantic segmentation](https://huggingface.co/docs/transformers/tasks/semantic_segmentation) guide. Date accessed: Nov 4th, 2022.)*

In [None]:
# Load dataset from Hugging Face
from datasets import load_dataset

ds = load_dataset("scene_parse_150", split="train[:50]")

In [None]:
# Split the data into train and test sets
ds = ds.train_test_split(test_size=0.2)
train_ds = ds["train"]
test_ds = ds["test"]

In [None]:
# print one example
train_ds[0]["image"]

In [None]:
# Download the mappings from the Hub and create the id2label and label2id dictionaries

import json
from huggingface_hub import hf_hub_download

repo_id = "huggingface/label-files"
filename = "ade20k-id2label.json"
# Download the label file and close the file handle after reading it
label_file = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
with open(label_file, "r") as f:
    id2label = json.load(f)

id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}
num_labels = len(id2label)

In [None]:
## Preprocessing and augmentations
from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained("nvidia/mit-b0", reduce_labels=True)

## Augmentations
from torchvision.transforms import ColorJitter

jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)


In [None]:
# train and validate transforms
def train_transforms(example_batch):
    images = [jitter(x) for x in example_batch["image"]]
    labels = [x for x in example_batch["annotation"]]
    inputs = feature_extractor(images, labels)
    return inputs

def val_transforms(example_batch):
    images = [x for x in example_batch["image"]]
    labels = [x for x in example_batch["annotation"]]
    inputs = feature_extractor(images, labels)
    return inputs

In [None]:
# Register the transforms; they are applied on the fly when examples are accessed
train_ds.set_transform(train_transforms)
test_ds.set_transform(val_transforms)

In [None]:
# get pretrained model
from transformers import AutoModelForSemanticSegmentation

pretrained_model_name = "nvidia/mit-b0"
model = AutoModelForSemanticSegmentation.from_pretrained(
    pretrained_model_name, id2label=id2label, label2id=label2id
)

In [None]:
# instantiate TrainingArguments
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="segformer-b0-scene-parse-150",
    learning_rate=6e-5,
    num_train_epochs=5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    save_total_limit=3,
    evaluation_strategy="steps",
    save_strategy="steps",
    save_steps=20,
    eval_steps=20,
    logging_steps=1,
    eval_accumulation_steps=5,
    remove_unused_columns=False,
    push_to_hub=False,
)

In [None]:
# eval metric from Evaluate lib.
import evaluate

metric = evaluate.load("mean_iou")

In [None]:
# compute metrics

import torch
import numpy as np

def compute_metrics(eval_pred):
    with torch.no_grad():
        logits, labels = eval_pred
        logits_tensor = torch.from_numpy(logits)
        logits_tensor = torch.nn.functional.interpolate(
            logits_tensor,
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        ).argmax(dim=1)

        pred_labels = logits_tensor.detach().cpu().numpy()
        metrics = metric.compute(
            predictions=pred_labels,
            references=labels,
            num_labels=num_labels,
            ignore_index=255,
            reduce_labels=False,
        )
        # Convert NumPy arrays to lists so the metrics can be logged and serialized
        for key, value in metrics.items():
            if isinstance(value, np.ndarray):
                metrics[key] = value.tolist()
        return metrics

In [None]:
# create Trainer
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    compute_metrics=compute_metrics,
)

In [None]:
# run model training
trainer.train()

In [None]:
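# Inference: check an example image from the test set.
# test_ds has val_transforms attached, so its examples no longer expose the
# "image" column; with_format(None) returns a copy without the transform.
image = test_ds.with_format(None)[0]["image"]
image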

In [None]:
# process img
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use GPU if available, otherwise use a CPU
encoding = feature_extractor(image, return_tensors="pt")
pixel_values = encoding.pixel_values.to(device)


In [None]:
# Inference: get logits (gradients are not needed at prediction time)
model.eval()
with torch.no_grad():
    outputs = model(pixel_values=pixel_values)
logits = outputs.logits.cpu()

In [None]:
# Upsample the logits to the original image size (height, width)
upsampled_logits = torch.nn.functional.interpolate(
    logits,
    size=image.size[::-1],
    mode="bilinear",
    align_corners=False,
)

pred_seg = upsampled_logits.argmax(dim=1)[0]


In [None]:
# visualize pred
import matplotlib.pyplot as plt

color_seg = np.zeros((pred_seg.shape[0], pred_seg.shape[1], 3), dtype=np.uint8)
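# NOTE: ade_palette() is not defined in this notebook. It must return the ADE20K
# color palette as a list of [R, G, B] values, one per class (see the guide linked above).
palette = np.array(ade_palette())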
for label, color in enumerate(palette):
    color_seg[pred_seg == label, :] = color
color_seg = color_seg[..., ::-1]  # convert to BGR

img = np.array(image) * 0.5 + color_seg * 0.5  # plot the image with the segmentation map
img = img.astype(np.uint8)

plt.figure(figsize=(15, 10))
plt.imshow(img)
plt.show()

## Part 4: Stateless inference - Ray Tasks

## Part 5: Stateful inference - Ray Actors

## Part 6: Stateful inference - Ray ActorPool

## Part 7: Ray AIR Datasets

## Part 8: Ray AIR BatchPredictor