# SageMaker JumpStart Foundation Models - Inference Latency and Throughput Benchmarking

***
Welcome to Amazon [SageMaker JumpStart](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart.html)! You can use SageMaker JumpStart to solve many machine learning tasks with one click in SageMaker Studio, or through the [SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/overview.html#use-prebuilt-models-with-sagemaker-jumpstart).


In this demo notebook, we show how to run latency and throughput benchmarking on a set of SageMaker JumpStart models. The notebook lets you benchmark a single model against multiple payloads, or multiple models against a single payload.

***

1. [Set up](#1.-Set-up)
2. [Run latency and throughput benchmarking](#2.-Run-latency-and-throughput-benchmarking)
3. [Visualize benchmarking results](#3.-Visualize-benchmarking-results)
4. [Clean up](#4.-Clean-up)

### 1. Set up

***
Before running the notebook, complete the following setup steps.
***

In [8]:
%pip install --upgrade sagemaker ipywidgets --quiet

Note: you may need to restart the kernel to use updated packages.


***
Here, you query the SageMaker Python SDK for a list of all Hugging Face text generation (and text2text) models available in SageMaker JumpStart. Use the Jupyter widget in this cell's output to select any combination of these models to benchmark. By default, only a few models are selected.
***

In [9]:
from IPython.display import display
from ipywidgets import SelectMultiple, Layout
from sagemaker.jumpstart.notebook_utils import list_jumpstart_models
from sagemaker.jumpstart.filters import And, Or

# Retrieve all Hugging Face text generation (and text2text) models available in SageMaker JumpStart.
tasks = ["textgeneration", "textgeneration1", "textgeneration2", "text2text"]
filter_value = And(Or(*[f"task == {task}" for task in tasks]), "framework == huggingface")
text_models = list_jumpstart_models(filter=filter_value)
selected_text_models = [
    "huggingface-textgeneration-falcon-40b-instruct-bf16",
    "huggingface-textgeneration-falcon-40b-bf16",
    "huggingface-textgeneration-falcon-7b-instruct-bf16",
    "huggingface-textgeneration-falcon-7b-bf16",
    # "huggingface-text2text-flan-t5-xxl",
    # "huggingface-textgeneration1-gpt-j-6b",
    # "huggingface-textgeneration2-gpt-neoxt-chat-base-20b-fp16",
    # "huggingface-textgeneration-bloom-1b7",
]
# To benchmark all JumpStart text generation models instead, uncomment the following lines.
# selected_text_models = text_models.copy()
# selected_text_models.remove("huggingface-textgeneration1-bloom-176b-int8")
# selected_text_models.remove("huggingface-textgeneration1-bloomz-176b-fp16")

models_selection = SelectMultiple(
    options=text_models,
    value=selected_text_models,
    description="Models:",
    rows=25,
    layout=Layout(width="100%"),
)
display(models_selection)


***
In the following cell, you select the models and payloads to benchmark. Every payload is benchmarked against every model.
- **MODELS**: The SageMaker JumpStart model IDs selected in the widget above.
- **PAYLOADS**: A dictionary whose keys are unique payload names and whose values are valid payload dictionaries.
***
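Because every payload is sent to every model, a malformed payload fails once per model. A quick local check can catch that before any endpoint is deployed. The following is a minimal sketch, not part of the benchmarking module: `validate_payloads` is a hypothetical helper, and it assumes each payload uses the `text_inputs` and `max_new_tokens` keys shown in the next cell (other models may expect different keys).

```python
def validate_payloads(payloads):
    """Return a list of problems found in a PAYLOADS-style dictionary (empty if none).

    Hypothetical helper: assumes payloads use 'text_inputs' and optionally 'max_new_tokens'.
    """
    problems = []
    for name, payload in payloads.items():
        if not isinstance(payload, dict):
            problems.append(f"{name}: payload must be a dict, got {type(payload).__name__}")
            continue
        # The prompt must be a non-empty string.
        text = payload.get("text_inputs")
        if not isinstance(text, str) or not text.strip():
            problems.append(f"{name}: 'text_inputs' must be a non-empty string")
        # If set, max_new_tokens must be a positive integer (bool is rejected explicitly).
        max_new_tokens = payload.get("max_new_tokens")
        if max_new_tokens is not None and (
            isinstance(max_new_tokens, bool)
            or not isinstance(max_new_tokens, int)
            or max_new_tokens <= 0
        ):
            problems.append(f"{name}: 'max_new_tokens' must be a positive integer")
    return problems


example = {
    "ok": {"text_inputs": "Hello!", "do_sample": True, "max_new_tokens": 25},
    "bad": {"text_inputs": "", "max_new_tokens": 0},
}
print(validate_payloads(example))
```

After defining `PAYLOADS`, you could run `assert not validate_payloads(PAYLOADS)` to stop early on a bad payload.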

In [10]:
MODELS = models_selection.value

PAYLOADS = {
    "simple_short_input": {
        "text_inputs": "Hello!",
        "do_sample": True,
        "max_new_tokens": 25,
    },
    "generate_summary": {
        "text_inputs": (
            "Write a short summary for this text: Amazon Comprehend uses natural language "
            "processing (NLP) to extract insights about the content of documents. It develops "
            "insights by recognizing the entities, key phrases, language, sentiments, and other "
            "common elements in a document. Use Amazon Comprehend to create new products based on "
            "understanding the structure of documents. For example, using Amazon Comprehend you "
            "can search social networking feeds for mentions of products or scan an entire "
            "document repository for key phrases. \nYou can access Amazon Comprehend document "
            "analysis capabilities using the Amazon Comprehend console or using the Amazon "
            "Comprehend APIs. You can run real-time analysis for small workloads or you can start "
            "asynchronous analysis jobs for large document sets. You can use the pre-trained "
            "models that Amazon Comprehend provides, or you can train your own custom models for "
            "classification and entity recognition. \nAll of the Amazon Comprehend features "
            "accept UTF-8 text documents as the input. In addition, custom classification and "
            "custom entity recognition accept image files, PDF files, and Word files as input. \n"
            "Amazon Comprehend can examine and analyze documents in a variety of languages, "
            "depending on the specific feature. For more information, see Languages supported in "
            "Amazon Comprehend. Amazon Comprehend's Dominant language capability can examine "
            "documents and determine the dominant language for a far wider selection of languages."
        ),
        "do_sample": True,
        # "max_length": 500,
        "max_new_tokens": 25,
    },
}

***
The following constants drive the behavior of this notebook:
- **MAX_CONCURRENT_INVOCATIONS_PER_MODEL**: The maximum number of endpoint predictions to request concurrently.
- **MAX_CONCURRENT_BENCHMARKS**: The maximum number of models to concurrently benchmark.
- **RETRY_WAIT_TIME_SECONDS**: The time in seconds to wait between Amazon CloudWatch queries. This is necessary because the endpoint emits CloudWatch metrics periodically, so the benchmark must wait until all samples have been emitted to CloudWatch before publishing benchmarking statistics.
- **MAX_TOTAL_RETRY_TIME_SECONDS**: The maximum amount of time in seconds to wait on Amazon CloudWatch emissions before proceeding without collecting the requested benchmarking metrics.
- **NUM_INVOCATIONS**: The number of endpoint predictions to request per benchmark.
- **SAVE_METRICS_FILE_PATH**: The JSON file used to save the resulting metrics.
- **SM_SESSION**: SageMaker Session object with custom configuration to resolve [SDK rate exceeded and throttling exceptions](https://aws.amazon.com/premiumsupport/knowledge-center/sagemaker-python-throttlingexception/).
***
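Together, RETRY_WAIT_TIME_SECONDS and MAX_TOTAL_RETRY_TIME_SECONDS describe polling under a time budget. The following is a minimal sketch under assumptions, not the benchmarking module's code: `wait_for_metrics` is a hypothetical helper, `fetch` stands in for a CloudWatch query, and `is_complete` might, for example, check that the reported sample count has reached NUM_INVOCATIONS. The `sleep` and `clock` parameters are injectable so the loop can be tested without real waiting.

```python
import time


def wait_for_metrics(fetch, is_complete, retry_wait_seconds, max_total_retry_seconds,
                     sleep=time.sleep, clock=time.monotonic):
    """Poll fetch() until is_complete(result) is true or the time budget runs out.

    Returns (result, complete): the last fetched result and whether it was complete.
    """
    deadline = clock() + max_total_retry_seconds
    while True:
        result = fetch()
        if is_complete(result):
            return result, True
        # Stop if another wait would exceed the total budget.
        if clock() + retry_wait_seconds > deadline:
            # Proceed without the complete set of metrics.
            return result, False
        sleep(retry_wait_seconds)
```

Injecting the clock keeps the retry policy separate from the CloudWatch call itself, so the same loop works for any metric source.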

In [11]:
from pathlib import Path


MAX_CONCURRENT_BENCHMARKS = 50
NUM_INVOCATIONS = 10
SAVE_METRICS_FILE_PATH = Path.cwd() / "latency_benchmarking.json"

### 2. Run latency and throughput benchmarking

***

The benchmarking logic lives in the local `benchmarking.runner` module imported in the next code cell. For each SageMaker JumpStart model ID, it performs the following actions:
- Create a SageMaker JumpStart `Model` object.
- Deploy the model and obtain a `Predictor`.
- Run all benchmarking load tests for each payload defined in the `PAYLOADS` dictionary. This includes:
  - Obtaining latency statistics: serially invoke the endpoint `NUM_INVOCATIONS` times, then use the Amazon CloudWatch [GetMetricStatistics](https://docs.aws.amazon.com/AmazonCloudWatch/latest/APIReference/API_GetMetricStatistics.html) API to retrieve latency statistics for that batch of predictions.
  - Obtaining throughput statistics: concurrently invoke the endpoint `NUM_INVOCATIONS` times to measure client-side throughput.
- Delete the predictor's model and endpoint. If an error occurs while benchmarking a model, this cleanup still runs before the error is raised.

***
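The two measurements above can be sketched as follows. This is a minimal illustration, not the `benchmarking.runner` implementation. `measure_client_throughput` is a hypothetical helper that accepts any `invoke` callable (for example, a predictor's `predict` method) and bounds concurrency the way MAX_CONCURRENT_INVOCATIONS_PER_MODEL is described. `model_latency_query` is a hypothetical helper that builds keyword arguments for the boto3 CloudWatch client's `get_metric_statistics` call on the SageMaker `ModelLatency` metric. The `AllTraffic` variant name is an assumption: it is the default variant name the SageMaker Python SDK uses on deploy.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def measure_client_throughput(invoke, payload, num_invocations, max_concurrent_invocations):
    """Invoke invoke(payload) num_invocations times concurrently; return requests per second."""
    start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=max_concurrent_invocations) as executor:
        # list() waits for every call and re-raises the first invocation error, if any.
        list(executor.map(lambda _: invoke(payload), range(num_invocations)))
    elapsed = time.perf_counter() - start
    return num_invocations / elapsed


def model_latency_query(endpoint_name, start_time, end_time,
                        variant_name="AllTraffic", period_seconds=60):
    """Build kwargs for CloudWatch GetMetricStatistics on the ModelLatency metric."""
    return {
        "Namespace": "AWS/SageMaker",
        "MetricName": "ModelLatency",  # CloudWatch reports this metric in microseconds
        "Dimensions": [
            {"Name": "EndpointName", "Value": endpoint_name},
            {"Name": "VariantName", "Value": variant_name},
        ],
        "StartTime": start_time,
        "EndTime": end_time,
        "Period": period_seconds,
        "Statistics": ["Average", "Minimum", "Maximum", "SampleCount"],
    }
```

With AWS credentials configured, you could pass the query to `boto3.client("cloudwatch").get_metric_statistics(**query)`, using the start and end times of the serial invocation batch.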

***
In the following cell, `Benchmarker.run_multiple_model_ids` runs benchmarking for every model ID in `MODELS`. To avoid deploying endpoints one after another, the Python standard library [concurrent futures](https://docs.python.org/3/library/concurrent.futures.html) module runs up to `MAX_CONCURRENT_BENCHMARKS` executor threads concurrently. As each thread completes, its metrics are appended to a single list. If a thread raises an error instead of returning metrics, the error is recorded in a dictionary rather than re-raised, so benchmarking continues for all other models.
***
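The collect-metrics-and-record-errors pattern described above can be sketched with `ThreadPoolExecutor` and `as_completed`. This is a minimal sketch, not the `Benchmarker` source: `run_models_concurrently` is a hypothetical helper, and `run_one` is assumed to return a list of metric dictionaries for one model ID.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def run_models_concurrently(run_one, model_ids, max_concurrent_benchmarks):
    """Run run_one(model_id) for each model; return (metrics, errors).

    metrics is one flat list of metric dicts; errors maps model ID to the raised exception.
    """
    metrics, errors = [], {}
    with ThreadPoolExecutor(max_workers=max_concurrent_benchmarks) as executor:
        futures = {executor.submit(run_one, model_id): model_id for model_id in model_ids}
        for future in as_completed(futures):
            model_id = futures[future]
            try:
                metrics.extend(future.result())
            except Exception as error:
                # Record the failure instead of re-raising, so other models keep running.
                errors[model_id] = error
    return metrics, errors
```

Because results are gathered with `as_completed`, the order of `metrics` follows completion time rather than the order of `model_ids`; the sorting in the visualization cells makes that order irrelevant.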

In [12]:
from benchmarking.runner import Benchmarker


benchmarker = Benchmarker(payloads=PAYLOADS, max_concurrent_benchmarks=MAX_CONCURRENT_BENCHMARKS)
metrics, errors = benchmarker.run_multiple_model_ids(models=MODELS)

(Model 'huggingface-textgeneration-falcon-40b-instruct-bf16'): Deploying endpoint jumpstart-bm-hf-textgeneration-falcon-4-2023-06-09-04-52-32-299 ...
(Model 'huggingface-textgeneration-falcon-7b-bf16'): Deploying endpoint jumpstart-bm-hf-textgeneration-falcon-7-2023-06-09-04-52-32-372 ...
(Model 'huggingface-textgeneration-falcon-40b-bf16'): Deploying endpoint jumpstart-bm-hf-textgeneration-falcon-4-2023-06-09-04-52-32-379 ...
(Model 'huggingface-textgeneration-falcon-7b-instruct-bf16'): Deploying endpoint jumpstart-bm-hf-textgeneration-falcon-7-2023-06-09-04-52-32-385 ...
-----------------------------------------------!(Model huggingface-textgeneration-falcon-7b-bf16) Benchmarking failed: run_single_predictor() missing 1 required positional argument: 'payloads'
-----------------!(Model huggingface-textgeneration-falcon-7b-instruct-bf16) Benchmarking failed: run_single_predictor() missing 1 required positional argument: 'payloads'
--------------------------!(Model huggingface-textgener

***
Next, save the benchmarked metrics to a JSON file for use in downstream analyses.
***

In [13]:
import json


output = {"models": MODELS, "payloads": PAYLOADS, "metrics": metrics}
with open(SAVE_METRICS_FILE_PATH, "w") as file:
    json.dump(output, file, indent=4, ensure_ascii=False)

### 3. Visualize benchmarking results

***
The benchmarking metrics are now loaded into a normalized pandas DataFrame for visualization. This cell shows the following:
1. The column names of the DataFrame. These are the statistics available to explore.
2. A table with a sample output from each model ID in `MODELS` for each payload in `PAYLOADS`.
3. A table with key latency and throughput statistics for each model ID in `MODELS` and each payload in `PAYLOADS`.
***
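Column names such as `ModelLatency.Average` and `Client.Latency.Average` come from `pd.json_normalize`, which by default flattens nested dictionaries into keys joined with `.`. The following stdlib sketch mirrors that behavior for nested dictionaries only (lists are left as-is). `flatten_record` is a hypothetical helper, and the nested record is a hypothetical shape inferred from the column names used in the next cell, not actual benchmark output.

```python
def flatten_record(record, parent_key="", sep="."):
    """Flatten nested dicts into sep-joined keys, like pandas.json_normalize's default sep='.'."""
    flat = {}
    for key, value in record.items():
        full_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            # Recurse into nested dicts, extending the key path.
            flat.update(flatten_record(value, full_key, sep))
        else:
            flat[full_key] = value
    return flat


record = {
    "ModelID": "example-model",
    "PayloadName": "simple_short_input",
    "ModelLatency": {"Average": 1.5},
    "Client": {"Latency": {"Average": 2.0}},
}
print(list(flatten_record(record)))
```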

In [14]:
import pandas as pd


pd.set_option("display.max_colwidth", None)  # do not truncate long sample outputs
pd.set_option("display.max_rows", 500)

df = pd.json_normalize(metrics)
print("Here are the available statistics: ", list(df.columns))

index_cols = ["PayloadName", "ModelID"]
display_cols = ["PayloadName", "ModelID", "SampleOutput"]
sort_cols = ["PayloadName"]
display(df[display_cols].sort_values(by=sort_cols).set_index(index_cols))

display_cols = [
    "PayloadName",
    "ModelID",
    "Throughput",
    "ModelLatency.Average",
    "Client.Latency.Average",
    "Client.OutputSequenceWords.Average",
    "WordThroughput",
    "Client.LatencyPerOutputWord.Average",
]
sort_cols = ["PayloadName", "Client.LatencyPerOutputWord.Average"]
display(df[display_cols].sort_values(by=sort_cols).set_index(index_cols).round(3))

Here are the available statistics:  []


KeyError: "None of [Index(['PayloadName', 'ModelID', 'SampleOutput'], dtype='object')] are in the [columns]"
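The `KeyError` above occurs because the benchmarks returned no metrics, as the empty list of available statistics shows: `pd.json_normalize` then produces a DataFrame with no columns. A guard before building the DataFrame turns this into a readable message that lists the per-model errors. `require_metrics` is a hypothetical helper; the sketch assumes `errors` maps model IDs to exceptions, as described for `run_multiple_model_ids`.

```python
def require_metrics(metrics, errors):
    """Return metrics unchanged, or raise a readable error listing why no metrics exist."""
    if metrics:
        return metrics
    # Sorting by model ID keeps the report stable across runs.
    details = "\n".join(f"  {model_id}: {error}" for model_id, error in sorted(errors.items()))
    raise RuntimeError(
        "No benchmarking metrics were collected.\n" + (details or "  (no errors were recorded)")
    )
```

In the visualization cell, `df = pd.json_normalize(require_metrics(metrics, errors))` would surface the deployment or benchmarking failures directly instead of a missing-column error.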

***
Finally, plot the results of this latency analysis. For each payload, this cell creates a Plotly figure of average latency per output word against word throughput, that is, the number of output words the model returns per second. When an endpoint serves one request at a time, throughput is the inverse of latency (throughput = 1 / latency), which the `y=1/x` reference curve shows. Endpoints that handle requests concurrently, such as multi-model endpoints and load-balanced endpoints, can increase throughput at a fixed latency. Both metrics are important to consider when defining requirements for model selection.
***
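A small worked example shows why serial requests land on the `y=1/x` curve and concurrent requests land above it. This is a sketch under stated assumptions: `word_metrics` is a hypothetical helper, and its definitions (per-request latency divided by output words, and total output words divided by batch wall-clock time) are assumed from the column names rather than taken from the benchmarking module.

```python
def word_metrics(latencies_s, output_words, wall_clock_s):
    """Return (average latency per output word, word throughput) for a batch of requests.

    latencies_s[i] is the client-side latency of request i in seconds, output_words[i]
    is the number of words it returned, and wall_clock_s is the elapsed batch time.
    """
    per_word = [latency / words for latency, words in zip(latencies_s, output_words)]
    avg_latency_per_word = sum(per_word) / len(per_word)  # seconds per word
    word_throughput = sum(output_words) / wall_clock_s     # words per second
    return avg_latency_per_word, word_throughput


# Four hypothetical requests, each taking 2 s and returning 20 words.
latencies, words = [2.0] * 4, [20] * 4
serial = word_metrics(latencies, words, wall_clock_s=8.0)      # one request at a time
concurrent = word_metrics(latencies, words, wall_clock_s=2.0)  # all four in parallel
print(serial, concurrent)
```

Both runs have the same latency per word (0.1 s). Serially, word throughput is 10 words per second, exactly 1 / 0.1, so the point sits on the reference curve. With four requests in flight, throughput rises to 40 words per second at the same latency, four times above the curve.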

In [None]:
import plotly.express as px
import plotly.graph_objects as go
import numpy as np


# Reference curve y = 1/x: throughput equals the inverse of latency for serial requests.
x_ref = np.linspace(1, 300, 300)

for payload_name in PAYLOADS:
    col_x, col_y = "WordThroughput", "Client.LatencyPerOutputWord.Average"
    df_plot = df[df["PayloadName"] == payload_name]
    fig = px.scatter(df_plot, x=col_x, y=col_y, hover_data=["ModelID"])
    fig.add_trace(go.Scatter(x=x_ref, y=1 / x_ref, name="y=1/x"))
    fig.update_layout(
        xaxis_range=[0.0, df_plot[col_x].max() * 1.1],
        yaxis_range=[0.0, df_plot[col_y].max() * 1.1],
        title=f"Latency per word vs. word throughput for payload {payload_name}",
    )
    fig.show()

### 4. Clean up

***
When you are done with the endpoints, delete them to avoid additional costs. In this demonstration, each benchmarking job deletes its model and endpoint when it finishes, even if benchmarking fails.
***
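If the kernel is interrupted mid-run, endpoints can be left behind. The following is a minimal sketch for finding and deleting them with a boto3 SageMaker client. `delete_benchmark_endpoints` is a hypothetical helper, and the `jumpstart-bm-` prefix is an assumption inferred from the endpoint names in the deployment log above. It defaults to a dry run so you can review the names first.

```python
def delete_benchmark_endpoints(sm_client, name_prefix="jumpstart-bm-", dry_run=True):
    """Delete leftover SageMaker endpoints whose names start with name_prefix.

    sm_client is a boto3 SageMaker client, e.g. boto3.client("sagemaker").
    Returns the endpoint names that were (or, with dry_run=True, would be) deleted.
    """
    names, next_token = [], None
    while True:
        kwargs = {"NameContains": name_prefix}
        if next_token:
            kwargs["NextToken"] = next_token  # boto3 rejects NextToken=None
        response = sm_client.list_endpoints(**kwargs)
        # NameContains is a substring match, so keep only true prefix matches.
        names += [
            endpoint["EndpointName"]
            for endpoint in response["Endpoints"]
            if endpoint["EndpointName"].startswith(name_prefix)
        ]
        next_token = response.get("NextToken")
        if not next_token:
            break
    if not dry_run:
        for name in names:
            sm_client.delete_endpoint(EndpointName=name)
    return names
```

Deleting an endpoint does not remove its endpoint configuration or model; the same client's `delete_endpoint_config` and `delete_model` calls handle those.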