# Speculative Decoding: Train and Benchmark Medusa

Large Language Models (LLMs) are changing our world. However, productionizing them can be slow and expensive. Speculative decoding is a technique that can speed up LLM inference by predicting multiple future tokens in parallel. This can reduce the time required for generating text outputs. However, speculative decoding can be complex to implement. Medusa is a framework that simplifies the speculative decoding process while maintaining its benefits.

Medusa accelerates LLM text generation by adding multiple decoding heads to predict several subsequent tokens in parallel, instead of just the next token. It then uses tree attention to efficiently process multiple token candidates simultaneously and a typical acceptance scheme to select plausible continuations, resulting in about a 2x speedup in generation time. By integrating additional "Medusa heads" with the original model, it allows for efficient token generation without the need for a separate draft model. 

This blog post shows you how to train and benchmark Medusa. 

_Note: This examples was run on aws `g6e.12xlarge` with 4x NVIDIA L40S GPUs with each 48GB Memory._

## Training Medusa

Before training our Medusa we need to better understand our data distribution. One of the most important things is to have a good dataset (with similar distribution to what will be used in production) because Medusa has a much higher hit-rate when the generation is in-domain. 

This means if you are going to train Medusa on a dataset that is very different from the data/user queries you have in production, your speedup will be minimal or non-existent. 

There are 3 different ways to select/prepare data for training Medusa:

1. **Self-distillation**: This is the easiest and most effective way to prepare data for training. You can use the same model to generate the data that you will use to train the model. Essentially, you prompt the model with a similar input to what you will use in production and the model will generate the output.
2. **User/Application data**: If you are able to collect real user queries and model outputs, you can use this data to train Medusa. 
3. **Fine-tuning data**: If you don't have access to user data, you can use the fine-tuning dataset to train Medusa.

In this blog post, we will use the Self-distillation data to train Medusa. 

The dataset or data distribution also plays a key role when evaluating/benchmarking the performance of the Medusa heads. As we learned that Medusa has a much higher hit-rate when the generation is in-domain, it is important to evaluate the Medusa heads on the same data distribution that will be used in production or training. I

Okay lets get started. 🚀 We will use a smalle modified the [original implementation of Medusa](https://github.com/FasterDecoding/Medusa). The repository includes a training script along side a python package. Lets clone the repository and install the package.

In [None]:
# Install Pytorch, Deepspeed & Hugging Face libraries
%pip install "torch==2.4.0" tensorboard "transformers==4.44.2" "datasets==2.21.0" "accelerate==0.33.0" "deepspeed==0.14.5"
# Download and install Medusa packages
# !git clone https://github.com/philschmid/Medusa
# !cd Medusa && pip install -e .

### Creating Self-distillation data

To create the self-distillation data we will use the same model that we will use to train Medusa. In this blog post we will use [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) and prompts from the Open-Orca dataset. We use Hugging Face TGI to start an endpoint and asynchronously iterate over the dataset to generate our self-distilled responses. We are going to create 15,000 examples for training Medusa.

```bash
docker run --gpus all -ti --shm-size 1g --ipc=host --rm -p 8080:80 \
  -e MODEL_ID=meta-llama/Meta-Llama-3.1-8B-Instruct \
  -e NUM_SHARD=4 \
  -e MAX_INPUT_TOKENS=7168 \
  -e MAX_TOTAL_TOKENS=8192 \
  -e HF_TOKEN=$(cat ~/.cache/huggingface/token) \
  ghcr.io/huggingface/text-generation-inference:sha-b70ae09
```

In [None]:
import asyncio
import json 
from openai import AsyncOpenAI
from tqdm.asyncio import tqdm_asyncio
from datasets import load_dataset

# max concurrency
sem = asyncio.Semaphore(64)

# Initialize the client using the Hugging Face Inference API
client = AsyncOpenAI(
    base_url="http://localhost:8080/v1",
    api_key="-",
)

# Generate Response for input message
async def gen_response(sample):
    # Comment in if you want to see the prompt
    try:
        conv = sample["messages"][:-1]
        response = await client.chat.completions.create(
            model="tgi",
            messages=conv,
            temperature=0,
            max_tokens=1024,
        )
        results = response.choices[0].message.content
        conv.append({"role": "assistant", "content": results})
        # Add the evaluation results to the sample
        return conv
    except Exception as e:
        print(e)
        return None

# Combined async helper method to handle concurrent scoring and
async def gen_data(dataset):
    async def _gen(sample):
        async with sem:
            res = await gen_response(sample)
            progress_bar.update(1)
            return res

    progress_bar = tqdm_asyncio(total=len(dataset), desc="Generating self-distillation", unit="sample")
    tasks = [_gen(text) for text in dataset]
    responses = await tqdm_asyncio.gather(*tasks)
    progress_bar.close()
    # filter out the None responses
    responses = [r for r in responses if r is not None]
    # save the responses to a file
    with open(f"self_distil_train.json", "w") as f:
        json.dump(responses, f)
    return responses

# Load the dataset and select a 15,000k subset
dataset = load_dataset("philschmid/slimorca-deduped-cleaned-corrected-chatml")
sub_data = dataset["train"].shuffle(seed=42).select(range(15000))

# generate the responses
await gen_data(sub_data)

## Train speculative Medusa Heads

Now, we are ready to train our Medusa model. For Medusa there are 2 important hyperparameters:

* `medusa_heads`: controls the number of additional decoding heads added to the language model. These heads predict multiple future tokens in parallel, speeding up inference. 
* `medusa_layers`: The number of layers to use for each Medusa heads. 

We can start our training using the existing `train_legacy.py` script and provide our parameters. 

In [None]:
!torchrun --nproc_per_node=4  Medusa/medusa/train/train_legacy.py \
    --model_name_or_path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --data_path self_distil_train.json \
    --bf16 True \
    --output_dir llama31_instruct \
    --num_train_epochs 5 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --eval_strategy "no" \
    --save_strategy "epoch" \
    --learning_rate 1e-3 \
    --lr_scheduler_type "cosine" \
    --warmup_steps 40 \
    --logging_steps 10 \
    --tf32 True \
    --model_max_length 2048 \
    --lazy_preprocess False \
    --medusa_num_heads 3 \
    --medusa_num_layers 1 \
    --deepspeed ./Medusa/deepspeed.json

Nice our training successfully finished. Now we just need to push our model to the hub and we are ready to benchmark our model. Make sure to replace the `folder` and `repo` with your own values.

In [None]:
# remove the checkpoints for faster upload
!rm -rf llama31_instruct_medusa/checkpoint*
# push model to the hub
!huggingface-cli upload philschmid/llama-3-1-8b-instruct-medusa llama31_instruct_medusa --repo-type=model  

## Benchmark Medusa with Hugging Face Text Generation Inference 

TODO: Text about what TGI is on what it supports 

When using Speculative decoding we want to measure the acceleration of our model through the Medusa heads. "Acceleration" refers to the speedup achieved during speculative decoding. Specifically, acceleration is calculated as the ratio of the total number of tokens (both generated and skipped) to the number of iterations or loops needed to produce those tokens.

Acceleration measures how much faster a model can produce output when it speculates multiple tokens ahead. The formula given is.  
`Acceleration = (Total Tokens (Generated + Skipped)) / (Number of Loops/Iterations)`

acceleration = (17687 + 27101) / 27101 = 1.65

Lets start our TGI container using the `docker run` command. Make sure to replace the `MODEL_ID` with your own model.
```bash
CUDA_VISIBLE_DEVICES=0 docker run --gpus all -ti --shm-size 1g --ipc=host --rm -p 8080:80 \
  -e MODEL_ID=meta-llama/Meta-Llama-3.1-8B-Instruct \
  -e NUM_SHARD=1 \
  -e MAX_INPUT_TOKENS=10000 \
  -e MAX_TOTAL_TOKENS=12444 \
  -e HF_TOKEN=$(cat ~/.cache/huggingface/token) \
  ghcr.io/huggingface/text-generation-inference:sha-b70ae09
```

After our container is running, we can test it with a simple query. 

In [4]:
from datasets import load_dataset 
from transformers import AutoTokenizer

dataset = load_dataset ("Open-Orca/OpenOrca")["train"]
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

seq_len = lambda x: len (tokenizer.encode(x["question"], add_special_tokens=False))
dataset_filtered = dataset.filter(lambda x: 500 < seq_len (x) < 1000, num_proc=32)

def create_conversation(x):
    messages = []
    if x["system_prompt"] is not None:
        messages.append({"role": "system", "content": x["system_prompt"]})
    else:
        messages.append({"role": "system", "content": "You are a helpful assistant."})
    
    messages.append({"role": "user", "content": x["question"]})
  
    return {"messages": messages}

dataset_subset = dataset_filtered.map(create_conversation)
dataset_subset = dataset_subset.shuffle(42).select(range(10000))
dataset_subset = dataset_subset.rename_column("question", "text")
dataset_subset = dataset_subset.remove_columns(["system_prompt", "response"])



Map:   0%|          | 0/267893 [00:00<?, ? examples/s]

In [5]:
dataset_subset[0]

{'id': 't0.216446',
 'text': 'Read the following article and answer the question. Article: "I say, I\'m pleased to see you," said the little man standing by the letter-box. "Oh, hello," I said, remembering he was a new neighbor. "Simpson, isn\'t it?" "Yes, that\'s right." He seemed quite pleased by my ready recognition. "I wonder if you could lend me some money," he continued. "My wife gave me a letter to post, and I\'ve just noticed it isn\'t stamped." "yes, they never are," I said, sympathetically . "It must go tonight--it really must! I\'d get stamps out of the machine," explained Simpson," Only I find I have no small change about me." "I\'m sorry, but I\'m afraid I haven\'t either," I said. "Oh, dear, dear," he said. "Yes, well," I said, intending to move off. But he looked so unhappy standing there with the blue unstamped envelope that I really hadn\'t the heart to desert him. So I took him to my house and found some pennies and gave them to him, who, in the most business like way

In [6]:
dataset_subset.push_to_hub("philschmid/open-orca-10k-guidellm")

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/10 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/philschmid/open-orca-10k-guidellm/commit/d950c66ec3fad23c0dc5da9318aad3d9f8b2d20a', commit_message='Upload dataset', commit_description='', oid='d950c66ec3fad23c0dc5da9318aad3d9f8b2d20a', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/philschmid/open-orca-10k-guidellm', endpoint='https://huggingface.co', repo_type='dataset', repo_id='philschmid/open-orca-10k-guidellm'), pr_revision=None, pr_num=None)

In [None]:
curl localhost:8080/v1/completions \
    -X POST \
    -d '{
  "model": "tgi",
  "prompt": "hello",
  "stream": false,
  "max_tokens": 250
}' \
    -H 'Content-Type: application/json'

In [None]:
%%bash
curl localhost:8080/v1/chat/completions \
    -X POST \
    -d '{
  "model": "tgi",
  "messages": [
    {
      "role": "user",
      "content": "Write a poem for my three year old"
    }
  ],
  "stream": false,
  "max_tokens": 250
}' \
    -H 'Content-Type: application/json'

For Benchmarking we will use guidellm from NeuralMagic. GuideLLM can be used to simulate real-world inference workloads, GuideLLM helps users gauge the performance, resource needs, and cost implications of deploying LLMs on various hardware configurations. Supporting Hugigng Face dataset from local files or remote for benchmarking. If you are planning to use a Hugging Face dataset, you need to make sure that the dataset includes a `text` field, with the formatted prompt (system + user). 

Here is a simple example on how to create on. (It was used to create the test set):
```python
from datasets import load_dataset
from transformers import AutoTokenizer

# Load the dataset
dataset = load_dataset("philschmid/text-to-sql-dataset-medusa")
tokenizer = AutoTokenizer.from_pretrained("philschmid/code-llama-3-1-8b-text-to-sql")

def create_text_field(samples):
    prompt = tokenizer.apply_chat_template(samples["messages"][0:2], tokenize=False)
    return {"text": prompt}
  
td = dataset["test"].map(create_text_field)
td.push_to_hub("philschmid/text-to-sql-dataset-medusa-test-chatml")
```

In [None]:
dataset[0]

In [3]:
import time
from openai import OpenAI
from tqdm import tqdm
from datasets import load_dataset
from statistics import median

dataset = load_dataset("philschmid/open-orca-250-guidellm",split="train")
# Set up OpenAI API (make sure to use your actual API key)
client = OpenAI(api_key="-",base_url="http://localhost:8080/v1")

def measure_performance(client, dataset):
    start_time = time.time()
    successful_requests = 0
    results = []
    for sample in tqdm(dataset):
        has_ttft = False
        token_count = 0
        start_request = time.time()
        response = client.chat.completions.create(
            model="tgi",
            messages=[{"role": "user", "content": sample["text"]}],
            # messages=sample["messages"][0:2],
            max_tokens=250,
            temperature=0.0,
            stream=True
        )
        start_response = None
        start_request2 = time.time()
        for chunk in response:
            if start_response is None:
                start_response = time.time()
            if not has_ttft:
                ttft = time.time() - start_request
                has_ttft = True
                # print(f"Time to First Token: {ttft:.2f} seconds")
            token_count += 1
        end_request = time.time()
        itl = (end_request - start_response) / (token_count-1) * 1000
        # print(f"Output Token Throughput: {token_count / end_request:.2f} tokens/sec")
        results.append({"time_to_first_token": ttft, "output_token_throughput": token_count / (end_request - start_request2), "inter_token_latency": itl, "generated_tokens": token_count})
    return results

results = measure_performance(client, dataset)

print(f"Avg. Time to First Token: {sum([r['time_to_first_token'] for r in results]) / len(results):.2f} seconds")
print(f"Avg. Inter-Token Latency: {sum([r['inter_token_latency'] for r in results]) / len(results):.2f} ms/token")
print(f"Avg. Output Token Throughput: {sum([r['output_token_throughput'] for r in results]) / len(results):.2f} tokens/sec")
print(f"Avg. Generated Tokens: {sum([r['generated_tokens'] for r in results]) / len(results):.2f} tokens/sec")

# Calculate p50 (median) for each metric
p50_ttft = median([r["time_to_first_token"] for r in results])
p50_itl = median([r["inter_token_latency"] for r in results if r["inter_token_latency"] is not None])
p50_throughput = median([r["output_token_throughput"] for r in results])
p50_generated_tokens = median([r["generated_tokens"] for r in results])

# Print the p50 values
print(f"p50 Time to First Token: {p50_ttft:.2f} seconds")
print(f"p50 Inter-Token Latency: {p50_itl:.2f} ms/token")
print(f"p50 Output Token Throughput: {p50_throughput:.2f} tokens/sec")
print(f"p50 Generated Tokens: {p50_generated_tokens:.2f} tokens/sec")


100%|██████████| 250/250 [12:48<00:00,  3.07s/it]

Avg. Time to First Token: 0.14 seconds
Avg. Inter-Token Latency: 28.75 ms/token
Avg. Output Token Throughput: 33.28 tokens/sec
Avg. Generated Tokens: 105.40 tokens/sec
p50 Time to First Token: 0.09 seconds
p50 Inter-Token Latency: 28.59 ms/token
p50 Output Token Throughput: 33.43 tokens/sec
p50 Generated Tokens: 86.00 tokens/sec





docker run --gpus all --shm-size 1g --ipc=host --rm -p 8080:8000 \
    -e HUGGING_FACE_HUB_TOKEN=$(cat ~/.cache/huggingface/token) \
    vllm/vllm-openai:latest \
    --model mistralai/Mixtral-8x7B-Instruct-v0.1 \
    --tensor-parallel-size 8

  -e MODEL_ID=text-generation-inference/Mixtral-8x7B-Instruct-v0.1-medusa \
CUDA_VISIBLE_DEVICES=0 
docker run --gpus all -ti --shm-size 1g --ipc=host --rm -p 8080:80 \
  -e MODEL_ID=mistralai/Mixtral-8x7B-Instruct-v0.1 \
  -e NUM_SHARD=8 \
  -e MAX_INPUT_TOKENS=4000 \
  -e MAX_TOTAL_TOKENS=4096 \
  -e HF_TOKEN=$(cat ~/.cache/huggingface/token) \
  ghcr.io/huggingface/text-generation-inference:2.3.1

In [10]:
import asyncio
import time
from openai import AsyncOpenAI
from tqdm.asyncio import tqdm_asyncio
from datasets import load_dataset
from transformers import AutoTokenizer
from statistics import median

# max concurrency
concurrency = 4
sem = asyncio.Semaphore(concurrency)

# Initialize the client using the Hugging Face Inference API
client = AsyncOpenAI(
    base_url="http://localhost:8080/v1",
    api_key="-",
)

# Generate Response for input message
async def gen_response(sample):
    # Comment in if you want to see the prompt
    try:
        prompt = tokenizer.apply_chat_template(sample["messages"], tokenize=False)
        input_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
        has_ttft = False
        token_count = 0
        start_request = time.time()
        response = await client.chat.completions.create(
            model="mistralai/Mixtral-8x7B-Instruct-v0.1",
            messages=sample["messages"],
            temperature=0,
            max_tokens=250,
            stream=True
        )
        start_response = None
        async for chunk in response:
            if start_response is None:
                start_response = time.time()
            if not has_ttft:
                ttft = time.time() - start_request
                has_ttft = True
                # print(f"Time to First Token: {ttft:.2f} seconds")
            token_count += 1
        end_request = time.time()
        itl = (end_request - start_response) / (token_count-1) * 1000
        throughput = token_count / (end_request - start_response)
        # print(f"Output Token Throughput: {token_count / end_request:.2f} tokens/sec")
        return {"time_to_first_token": ttft, "output_token_throughput": throughput, "inter_token_latency": itl, "generated_tokens": token_count, "input_tokens": input_tokens}
    except Exception as e:
        print(e)
        return None

# Combined async helper method to handle concurrent scoring and
async def gen_data(dataset):
    async def _gen(sample):
        async with sem:
            res = await gen_response(sample)
            progress_bar.update(1)
            return res

    progress_bar = tqdm_asyncio(total=len(dataset), desc="Running", unit="sample")
    tasks = [_gen(text) for text in dataset]
    responses = await tqdm_asyncio.gather(*tasks)
    progress_bar.close()
    return responses

# Load the dataset and select a 15,000k subset
# dataset = load_dataset("philschmid/slimorca-deduped-cleaned-corrected-chatml")
# tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
# dataset = load_dataset("Aeala/ShareGPT_Vicuna_unfiltered", split="train")

# def get_prompt(x):
#     for c in x["conversations"]:
#         if c["from"] == "human":
#             return {"messages": [{"role": "user", "content": c["value"]}]}
        
    
    
    
# dataset = dataset.map(get_prompt, remove_columns=["conversations"])
# seq_len = lambda x: len(tokenizer.encode(x["messages"][0]["content"], add_special_tokens=False))
# dataset_filtered = dataset.filter(lambda x: 500 < seq_len(x) < 1000, num_proc=32)
# dataset = dataset_filtered.shuffle(seed=42).select(range(500))
dataset = load_dataset("philschmid/open-orca-10k-guidellm", split="train")
dataset = dataset.select(range(500))
# generate the responses
results = await gen_data(dataset)

# print(f"Avg. Time to First Token: {sum([r['time_to_first_token'] for r in results]) / len(results):.2f} seconds")
# print(f"Avg. Inter-Token Latency: {sum([r['inter_token_latency'] for r in results]) / len(results):.2f} ms/token")
# print(f"Avg. Output Token Throughput: {sum([r['output_token_throughput'] for r in results]) / len(results):.2f} tokens/sec")
# print(f"Avg. Generated Tokens: {sum([r['generated_tokens'] for r in results]) / len(results):.2f} tokens")

# Calculate p50 (median) for each metric
p50_ttft = median([r["time_to_first_token"] for r in results if r["time_to_first_token"] is not None])
p50_itl = median([r["inter_token_latency"] for r in results if r["inter_token_latency"] is not None])
p50_throughput = median([r["output_token_throughput"] for r in results if r["output_token_throughput"] is not None]) 
p50_generated_tokens = median([r["generated_tokens"] for r in results if r["generated_tokens"] is not None])
p50_input_tokens = median([r["input_tokens"] for r in results if r["input_tokens"] is not None])

# Print the p50 values
print(f"concurrency: {concurrency}")
print(f"p50 Time to First Token: {p50_ttft:.2f} seconds")
print(f"p50 Inter-Token Latency: {p50_itl:.2f} ms/token")
print(f"p50 Output Token Throughput: {p50_throughput:.2f} tokens/second/user")
print(f"p50 Input Tokens: {p50_input_tokens:.2f} tokens")
print(f"p50 Generated Tokens: {p50_generated_tokens:.2f} tokens")

Running:   0%|          | 0/500 [00:00<?, ?sample/s]No chat template is set for this tokenizer, falling back to a default class-level template. This is very error-prone, because models are often trained with templates different from the class default! Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which point any code depending on them will stop working. We recommend setting a valid chat template before then to ensure that this model continues working without issues.
100%|██████████| 500/500 [12:04<00:00,  1.45s/it]1.67s/sample]
Running: 100%|██████████| 500/500 [12:04<00:00,  1.45s/sample]

concurrency: 4
p50 Time to First Token: 0.07 seconds
p50 Inter-Token Latency: 41.96 ms/token
p50 Output Token Throughput: 24.03 tokens/second/user
p50 Input Tokens: 591.00 tokens
p50 Generated Tokens: 120.00 tokens





# TGI 2.3.1

```bash
concurrency: 4
p50 Time to First Token: 0.36 seconds
p50 Inter-Token Latency: 45.13 ms/token
p50 Output Token Throughput: 22.25 tokens/second/user
p50 Input Tokens: 715.00 tokens
p50 Generated Tokens: 250.00 tokens
```

orca mit prefix caching 2 runs
```bash
concurrency: 4
p50 Time to First Token: 0.07 seconds
p50 Inter-Token Latency: 41.96 ms/token
p50 Output Token Throughput: 24.03 tokens/second/user
p50 Input Tokens: 591.00 tokens
p50 Generated Tokens: 120.00 tokens
```

# TGI 2.2.0 
```bash
concurrency: 4
p50 Time to First Token: 0.39 seconds
p50 Inter-Token Latency: 45.58 ms/token
p50 Output Token Throughput: 22.03 tokens/second/user
p50 Input Tokens: 715.00 tokens
p50 Generated Tokens: 250.00 tokens
```

# vLLM latest
```bash
concurrency: 4
p50 Time to First Token: 0.38 seconds
p50 Inter-Token Latency: 44.82 ms/token
p50 Output Token Throughput: 22.40 tokens/second/user
p50 Input Tokens: 715.00 tokens
p50 Generated Tokens: 251.00 tokens
```


## Text-to-SQL Dataset

base:
```
Avg. Time to First Token: 0.08 seconds  
Avg. Output Token Throughput: 16.31 tokens/sec  
```

medusa
```
Avg. Time to First Token: 0.08 seconds
Avg. Output Token Throughput: 18.43 tokens/sec
```

## Open Orca Medusa Mistral

base:
```
Avg. Time to First Token: 0.11 seconds  
Avg. Output Token Throughput: 17.08 tokens/sec  
```

medusa
```
Avg. Time to First Token: 0.12 seconds
Avg. Output Token Throughput: 22.01 tokens/sec
```

## Open Orca Llama 3

base
```
Avg. Time to First Token: 0.05 seconds
Avg. Inter-Token Latency: 21.82 ms/token
Avg. Output Token Throughput: 44.26 tokens/sec
```

medusa
```
Avg. Time to First Token: 0.05 seconds
Avg. Inter-Token Latency: 22.12 ms/token
Avg. Output Token Throughput: 43.76 tokens/sec
``

In [None]:
!pip install guidellm

GuideLLM will use the OpenAI API and run through our test dataset 2,500 queries with different concurrent requests. 

```bash
guidellm \
  --target "http://localhost:8080/v1" \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --data philschmid/open-orca-250-guidellm \
  --data-type transformers \
  --rate-type constant \
  --rate 1 --rate 2 --rate 4 --rate 8 --rate 16 --rate 32 --rate 64 \
  --max-seconds 120 \
  --output-path benchmark_base.json
```

Running Llama 3.1 8b with Medusa heads on a single GPU we get the following results: 

```                                                                                                     
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓ 
┃ Benchmark                  ┃ Requests per Second ┃ Request Latency ┃ Time to First Token ┃ Inter Token Latency ┃ Output Token Throughput ┃ 
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩ 
│ asynchronous@1.00 req/sec  │ 0.94 req/sec        │ 7.34 sec        │ 452.72 ms           │ 68.03 ms            │ 94.99 tokens/sec        │ 
│ asynchronous@2.00 req/sec  │ 0.54 req/sec        │ 14.06 sec       │ 3934.82 ms          │ 65.29 ms            │ 83.25 tokens/sec        │ 
│ asynchronous@4.00 req/sec  │ 1.19 req/sec        │ 33.81 sec       │ 26663.24 ms         │ 68.54 ms            │ 123.47 tokens/sec       │ 
│ asynchronous@8.00 req/sec  │ 1.22 req/sec        │ 49.76 sec       │ 41699.83 ms         │ 68.86 ms            │ 142.06 tokens/sec       │ 
│ asynchronous@16.00 req/sec │ 1.22 req/sec        │ 50.24 sec       │ 41898.69 ms         │ 69.09 ms            │ 147.20 tokens/sec       │   
│ asynchronous@32.00 req/sec │ 0.89 req/sec        │ 64.67 sec       │ 59062.20 ms         │ 46.43 ms            │ 106.99 tokens/sec       │ 
│ asynchronous@64.00 req/sec │ 0.92 req/sec        │ 49.96 sec       │ 39648.46 ms         │ 69.34 ms            │ 136.81 tokens/sec       │ 
└────────────────────────────┴─────────────────────┴─────────────────┴─────────────────────┴─────────────────────┴─────────────────────────┘   
llama 
│ │ ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓ │ │
│ │ ┃              ┃ Requests per ┃ Request      ┃ Time to      ┃ Inter Token  ┃ Output Token ┃ │ │
│ │ ┃ Benchmark    ┃ Second       ┃ Latency      ┃ First Token  ┃ Latency      ┃ Throughput   ┃ │ │
│ │ ┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩ │ │
│ │ │ asynchronou… │ 0.98 req/sec │ 4.23 sec     │ 53.15 ms     │ 25.27 ms     │ 161.42       │ │ │
│ │ │ req/sec      │              │              │              │              │ tokens/sec   │ │ │
│ │ │ asynchronou… │ 1.21 req/sec │ 15.45 sec    │ 5803.01 ms   │ 32.03 ms     │ 364.71       │ │ │
│ │ │ req/sec      │              │              │              │              │ tokens/sec   │ │ │
│ │ │ asynchronou… │ 1.35 req/sec │ 7.02 sec     │ 2633.77 ms   │ 26.28 ms     │ 225.47       │ │ │
│ │ │ req/sec      │              │              │              │              │ tokens/sec   │ │ │
│ │ │ asynchronou… │ 2.17 req/sec │ 22.94 sec    │ 15782.63 ms  │ 35.00 ms     │ 442.87       │ │ │
│ │ │ req/sec      │              │              │              │              │ tokens/sec   │ │ │
│ │ │ asynchronou… │ 2.39 req/sec │ 39.75 sec    │ 35062.38 ms  │ 34.10 ms     │ 328.25       │ │ │
│ │ │ req/sec      │              │              │              │              │ tokens/sec   │ │ │
│ │ │ asynchronou… │ 2.41 req/sec │ 15.14 sec    │ 9682.72 ms   │ 31.59 ms     │ 416.41       │ │ │
│ │ │ req/sec      │              │              │              │              │ tokens/sec   │ │ │
│ │ │ asynchronou… │ 2.55 req/sec │ 26.56 sec    │ 18058.08 ms  │ 36.50 ms     │ 593.67       │ │ │
│ │ │ req/sec      │              │              │              │              │ tokens/sec   │ │ │
│ │ └──────────────┴──────────────┴──────────────┴──────────────┴──────────────┴──────────────┘ │ │
```

Now, lets compare this to our base model without Medusa. There for we use the same `docker run` command but replace our Model id. 

```Bash
CUDA_VISIBLE_DEVICES=0 docker run --gpus all -ti --shm-size 1g --ipc=host --rm -p 8080:80 \
  -e MODEL_ID=philschmid/llama-3-1-8b-instruct-medusa \
  -e NUM_SHARD=1 \
  -e MAX_INPUT_TOKENS=4096 \
  -e MAX_TOTAL_TOKENS=6000 \
  -e HF_TOKEN=$(cat ~/.cache/huggingface/token) \
  ghcr.io/huggingface/text-generation-inference:sha-b70ae09
```

And then rerun our benchmark, make sure to change the output path to `benchmark_baseline.json`. 

``` 
│ │ ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓ │ │
│ │ ┃              ┃ Requests per ┃ Request      ┃ Time to      ┃ Inter Token  ┃ Output Token ┃ │ │
│ │ ┃ Benchmark    ┃ Second       ┃ Latency      ┃ First Token  ┃ Latency      ┃ Throughput   ┃ │ │
│ │ ┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩ │ │
│ │ │ asynchronou… │ 0.95 req/sec │ 4.62 sec     │ 55.36 ms     │ 28.74 ms     │ 151.05       │ │ │
│ │ │ req/sec      │              │              │              │              │ tokens/sec   │ │ │
│ │ │ asynchronou… │ 1.63 req/sec │ 14.91 sec    │ 8740.49 ms   │ 24.01 ms     │ 416.56       │ │ │
│ │ │ req/sec      │              │              │              │              │ tokens/sec   │ │ │
│ │ │ asynchronou… │ 1.84 req/sec │ 11.89 sec    │ 8128.82 ms   │ 28.95 ms     │ 239.15       │ │ │
│ │ │ req/sec      │              │              │              │              │ tokens/sec   │ │ │
│ │ │ asynchronou… │ 1.87 req/sec │ 17.96 sec    │ 15337.43 ms  │ 28.47 ms     │ 171.53       │ │ │
│ │ │ req/sec      │              │              │              │              │ tokens/sec   │ │ │
│ │ │ asynchronou… │ 2.08 req/sec │ 15.62 sec    │ 11674.92 ms  │ 21.92 ms     │ 370.78       │ │ │
│ │ │ req/sec      │              │              │              │              │ tokens/sec   │ │ │
│ │ │ asynchronou… │ 2.68 req/sec │ 22.02 sec    │ 18314.22 ms  │ 27.13 ms     │ 364.16       │ │ │
│ │ │ req/sec      │              │              │              │              │ tokens/sec   │ │ │
│ │ │ asynchronou… │ 3.21 req/sec │ 20.85 sec    │ 16561.06 ms  │ 29.53 ms     │ 465.71       │ │ │
│ │ │ req/sec      │              │              │              │              │ tokens/sec   │ │ │
│ │ └──────────────┴──────────────┴──────────────┴──────────────┴──────────────┴──────────────┘ 
```

lm_eval --model vllm \
  --tasks gsm8k_cot_llama \
  --model_args pretrained=meta-llama/Meta-Llama-3.1-8B-Instruct,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.8,data_parallel_size=4  \
  --apply_chat_template \
  --batch_size auto \
  --fewshot_as_multiturn

# Evaluation Harness with OpenAI APIs

Evaluate `Meta-Llama-3.1-8B-Instruct` on the `gsm8k_cot_llama` task using the OpenAI API. The [expected performance for GMS8K is around ~0.84](https://ai.meta.com/blog/meta-llama-3-1/). [API Docs](https://github.com/EleutherAI/lm-evaluation-harness/blob/88ea85b4e54d0554e6051da71e30bf955a614954/docs/API_guide.md?plain=1#L29)

Installation:
```
pip install lm_eval[ifeval, hf_transfer] openai
```

Tasks: 
* [gms8k](https://github.com/EleutherAI/lm-evaluation-harness/tree/88ea85b4e54d0554e6051da71e30bf955a614954/lm_eval/tasks/gsm8k) 
* [Asdiv](https://github.com/EleutherAI/lm-evaluation-harness/blob/88ea85b4e54d0554e6051da71e30bf955a614954/lm_eval/tasks/asdiv/README.md)
* [ifeval](https://github.com/EleutherAI/lm-evaluation-harness/blob/88ea85b4e54d0554e6051da71e30bf955a614954/lm_eval/tasks/ifeval/README.md)

_Note: Use this task with `--fewshot_as_multiturn` and `--apply_chat_template`` to run correctly with Llama Instruct models._

  --tasks gsm8k_cot_llama,asdiv_cot_llama,ifeval \
```
lm_eval --model local-chat-completions \
  --tasks ifeval \
  --model_args model=meta-llama/Meta-Llama-3.1-8B-Instruct,base_url=http://localhost:8000/v1/chat/completions,num_concurrent=32,max_retries=3,tokenized_requests=False \
  --apply_chat_template \
  --fewshot_as_multiturn
```

### TGI (2.2.0 sha-b70ae09)

Run command
``` 
CUDA_VISIBLE_DEVICES=0 docker run --gpus all -ti --shm-size 1g --ipc=host --rm -p 8000:80 \
  -e MODEL_ID=meta-llama/Meta-Llama-3.1-8B-Instruct \
  -e NUM_SHARD=1 \
  -e MAX_INPUT_TOKENS=10000 \
  -e MAX_TOTAL_TOKENS=12444 \
  -e HF_TOKEN=$(cat ~/.cache/huggingface/token) \
  --entrypoint /bin/bash \
  ghcr.io/huggingface/text-generation-inference:sha-b70ae09
```

Result: 
| Tasks           | Version | Filter           | n-shot | Metric                  |   | Value  |   | Stderr |
|-----------------|---------|------------------|--------|-------------------------|---|--------|---|--------|
| gsm8k_cot_llama | 3       | flexible-extract | 8      | exact_match             | ↑ | 0.2206 | ± | 0.0114 |
|                 |         | strict-match     | 8      | exact_match             | ↑ | 0.2153 | ± | 0.0113 |
| asdiv_cot_llama | 1       | flexible-extract | 8      | exact_match             | ↑ | 0.1961 | ± | 0.0083 |
|                 |         | strict-match     | 8      | exact_match             | ↑ | 0.1931 | ± | 0.0082 |
| ifeval          | 4       | none             | 0      | inst_level_loose_acc    | ↑ | 0.6439 | ± | N/A    |
|                 |         | none             | 0      | inst_level_strict_acc   | ↑ | 0.6295 | ± | N/A    |
|                 |         | none             | 0      | prompt_level_loose_acc  | ↑ | 0.5508 | ± | 0.0214 |
|                 |         | none             | 0      | prompt_level_strict_acc | ↑ | 0.5342 | ± | 0.0215 |

Runtime: 
* IFEval: 3:24 min
* GSM8K: 06:08 min
* Asdiv: 10:27 min

## TGI (sha-ce85efa)


Run command
``` 
docker run --gpus all -ti --shm-size 1g --ipc=host --rm -p 8000:80 \
  -e MODEL_ID=meta-llama/Meta-Llama-3.1-8B-Instruct \
  -e NUM_SHARD=4 \
  -e MAX_INPUT_TOKENS=10000 \
  -e MAX_TOTAL_TOKENS=12444 \
  -e HF_TOKEN=$(cat ~/.cache/huggingface/token) \
  ghcr.io/huggingface/text-generation-inference:2.2.0
```


Result: 
|     Tasks     |Version|     Filter     |n-shot|        Metric         |   |Value |   |Stderr|
|---------------|------:|----------------|-----:|-----------------------|---|-----:|---|------|
|asdiv_cot_llama|      1|flexible-extract|     8|exact_match            |↑  |0.8243|±  |0.0079|
|               |       |strict-match    |     8|exact_match            |↑  |0.8213|±  |0.0080|
|gsm8k_cot_llama|      3|flexible-extract|     8|exact_match            |↑  |0.8537|±  |0.0097|
|               |       |strict-match    |     8|exact_match            |↑  |0.8506|±  |0.0098|
|ifeval         |      4|none            |     0|inst_level_loose_acc   |↑  |0.8501|±  |   N/A|
|               |       |none            |     0|inst_level_strict_acc  |↑  |0.8189|±  |   N/A|
|               |       |none            |     0|prompt_level_loose_acc |↑  |0.7837|±  |0.0177|
|               |       |none            |     0|prompt_level_strict_acc|↑  |0.7412|±  |0.0188|


Runtime: 17:54 min


### vLLM (0.6.1)

Run command
```
CUDA_VISIBLE_DEVICES=0 docker run --gpus all -ti --shm-size 1g --ipc=host --rm -p 8000:8000 \
    -e HUGGING_FACE_HUB_TOKEN=$(cat ~/.cache/huggingface/token) \
    vllm/vllm-openai:v0.6.1 \
    --model meta-llama/Meta-Llama-3.1-8B-Instruct --dtype auto
```

Result:
|     Tasks     |Version|     Filter     |n-shot|        Metric         |   |Value |   |Stderr|
|---------------|------:|----------------|-----:|-----------------------|---|-----:|---|------|
|asdiv_cot_llama|      1|flexible-extract|     8|exact_match            |↑  |0.8221|±  |0.0080|
|               |       |strict-match    |     8|exact_match            |↑  |0.8178|±  |0.0080|
|gsm8k_cot_llama|      3|flexible-extract|     8|exact_match            |↑  |0.8529|±  |0.0098|
|               |       |strict-match    |     8|exact_match            |↑  |0.8431|±  |0.0100|
|ifeval         |      4|none            |     0|inst_level_loose_acc   |↑  |0.8561|±  |   N/A|
|               |       |none            |     0|inst_level_strict_acc  |↑  |0.8201|±  |   N/A|
|               |       |none            |     0|prompt_level_loose_acc |↑  |0.7967|±  |0.0173|
|               |       |none            |     0|prompt_level_strict_acc|↑  |0.7468|±  |0.0187|

Runtime: 
* IFEval: 3:35 min
* GSM8K: 04:04 min
* Asdiv: 06:05 min
