# Speculative Decoding: Train and Benchmark Medusa

Large Language Models (LLMs) are changing our world. However, productionizing them can be slow and expensive. Speculative decoding is a technique that can speed up LLM inference by predicting multiple future tokens in parallel. This can reduce the time required for generating text outputs. However, speculative decoding can be complex to implement. Medusa is a framework that simplifies the speculative decoding process while maintaining its benefits.

Medusa accelerates LLM text generation by adding multiple decoding heads to predict several subsequent tokens in parallel, instead of just the next token. It then uses tree attention to efficiently process multiple token candidates simultaneously and a typical acceptance scheme to select plausible continuations, resulting in about a 2x speedup in generation time. By integrating additional "Medusa heads" with the original model, it allows for efficient token generation without the need for a separate draft model. 
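Here is a small, framework-free Python sketch that makes the "typical acceptance" step more concrete. It follows the criterion described in the Medusa paper. A speculated token is accepted if the base model assigns it a probability above `min(epsilon, delta * exp(-H))`, where `H` is the entropy of the base model's distribution at that position. Only the longest accepted prefix is kept. The function names and the `epsilon`/`delta` values are illustrative assumptions, not the Medusa package's API.

```python
import math
from typing import Sequence


def entropy(probs: Sequence[float]) -> float:
    """Shannon entropy (in nats) of a probability distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def accepted_prefix_length(
    base_probs: Sequence[Sequence[float]],
    candidate_tokens: Sequence[int],
    epsilon: float = 0.09,  # hard probability threshold (illustrative value)
    delta: float = 0.3,  # scale for the entropy-dependent threshold (illustrative value)
) -> int:
    """Count how many leading candidate tokens pass typical acceptance.

    base_probs[i] is the base model's next-token distribution at the position
    where candidate_tokens[i] was speculated by a Medusa head.
    """
    accepted = 0
    for probs, token in zip(base_probs, candidate_tokens):
        # High-entropy (uncertain) positions get a lower threshold, capped at epsilon.
        threshold = min(epsilon, delta * math.exp(-entropy(probs)))
        if probs[token] <= threshold:
            break  # stop at the first rejected token and keep only the prefix
        accepted += 1
    return accepted


# Toy 3-token vocabulary with one base-model distribution per speculated position.
base_probs = [[0.8, 0.1, 0.1], [0.05, 0.9, 0.05], [0.34, 0.33, 0.33]]
print(accepted_prefix_length(base_probs, [0, 1, 2]))  # 3
print(accepted_prefix_length(base_probs, [0, 0, 2]))  # 1
```

Using an entropy-dependent threshold means that when the base model is uncertain, more candidate tokens count as plausible. When it is confident, only tokens it strongly agrees with are accepted.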

This blog post shows you how to train and benchmark Medusa. 

## Training Medusa

Before training Medusa, we need to understand our data distribution. The most important ingredient is a good dataset whose distribution is similar to what the model will see in production. Medusa has a much higher hit rate (more speculated tokens are accepted) when the generation is in-domain.

This means that if you train Medusa on a dataset that is very different from the data or user queries you have in production, your speedup will be minimal or non-existent.

There are 3 different ways to select/prepare data for training Medusa:

1. **Self-distillation**: This is the easiest and most effective way to prepare training data. You use the same model to generate the data that you will use to train the Medusa heads: you prompt the model with inputs similar to what you will use in production and keep the outputs it generates.
2. **User/Application data**: If you are able to collect real user queries and model outputs, you can use this data to train Medusa. 
3. **Fine-tuning data**: If you don't have access to user data, you can use the fine-tuning dataset to train Medusa.
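The self-distillation option can be sketched in a few lines of Python. This is a minimal sketch that assumes single-turn conversations in the same `messages` format (system, user, assistant) as the dataset used later in this post. `generate_fn` is a hypothetical hook you would implement yourself, for example by calling the fine-tuned model through an inference server.

```python
from typing import Callable, Dict, List

Message = Dict[str, str]


def self_distill(
    conversations: List[List[Message]],
    generate_fn: Callable[[List[Message]], str],
) -> List[List[Message]]:
    """Replace each reference answer with the model's own generation."""
    distilled = []
    for messages in conversations:
        # Keep the prompt: every message before the first assistant turn.
        prompt: List[Message] = []
        for message in messages:
            if message["role"] == "assistant":
                break
            prompt.append(message)
        completion = generate_fn(prompt)  # hypothetical call to the model
        distilled.append(prompt + [{"role": "assistant", "content": completion}])
    return distilled


# Stand-in generator for illustration; replace with a real call to your model.
def fake_model(prompt: List[Message]) -> str:
    return "SELECT count(*) FROM users"


data = [[
    {"role": "system", "content": "You translate questions to SQL."},
    {"role": "user", "content": "How many users are there?"},
    {"role": "assistant", "content": "SELECT COUNT(*) FROM users;"},
]]
print(self_distill(data, fake_model)[0][-1]["content"])
```

Because the heads learn to predict the model's own continuations rather than human-written references, the training targets match what the heads will need to speculate at inference time.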

In this blog post, we will use the fine-tuning data to train Medusa. 

The dataset, or data distribution, also plays a key role when evaluating or benchmarking the Medusa heads. Since Medusa has a much higher hit rate when the generation is in-domain, it is important to evaluate the heads on the same data distribution that will be used in production or that was used for training.

Okay, let's get started. 🚀 We will use a slightly modified version of the [original implementation of Medusa](https://github.com/FasterDecoding/Medusa). The repository includes a training script alongside a Python package. Let's clone the repository and install the package.

```bash
# Install PyTorch, DeepSpeed & Hugging Face libraries
pip install "torch==2.4.0" tensorboard "transformers==4.44.2" "datasets==2.21.0" "accelerate==0.33.0" "deepspeed==0.14.5"
# Download and install the Medusa package
git clone https://github.com/philschmid/Medusa
pip install -e ./Medusa
```

We are going to use the same SFT dataset that was used to fine-tune the model. As the model, we use [philschmid/code-llama-3-1-8b-text-to-sql](https://huggingface.co/philschmid/code-llama-3-1-8b-text-to-sql), a Llama 3.1 8B model fine-tuned with QLoRA for text-to-SQL. The SFT dataset is available at [philschmid/text-to-sql-dataset-medusa](https://huggingface.co/datasets/philschmid/text-to-sql-dataset-medusa).

The dataset includes 10,000 training samples, which were used for SFT, and 2,500 test samples. We will later use those 2,500 unseen test samples to benchmark our Medusa model.

First, let's download the dataset and save the training split to a JSON file.

```python
from datasets import load_dataset
import json

# Load the dataset
dataset = load_dataset("philschmid/text-to-sql-dataset-medusa")

# Save the train dataset as list to disk as JSON
with open("train_dataset.json", "w") as f:
    json.dump(list(dataset["train"]["messages"]), f)
```

Now we are ready to train our Medusa model. Medusa has two important hyperparameters, passed to the training script as `--medusa_num_heads` and `--medusa_num_layers`:

* `medusa_num_heads`: the number of additional decoding heads added to the language model. These heads predict multiple future tokens in parallel, which speeds up inference.
* `medusa_num_layers`: the number of layers used in each Medusa head.
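This PyTorch sketch gives some intuition for what these two parameters control. It mirrors the head architecture described in the Medusa paper. Each head stacks `medusa_num_layers` residual blocks (a linear layer with SiLU activation and a skip connection) on top of the base model's last hidden states, followed by a projection to the vocabulary. The base LM head predicts the next token, and the k-th Medusa head predicts the token k positions further ahead. The class names, the zero initialization, and the toy sizes are illustrative assumptions, not the training script's actual code.

```python
import torch
from torch import nn


class ResBlock(nn.Module):
    """Residual block: x + SiLU(Linear(x))."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        # Zero init (an assumption here) makes each block start as the identity.
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.act(self.linear(x))


class MedusaHeads(nn.Module):
    """num_heads independent heads, each num_layers ResBlocks + a vocab projection."""

    def __init__(self, hidden_size: int, vocab_size: int, num_heads: int, num_layers: int):
        super().__init__()
        self.heads = nn.ModuleList([
            nn.Sequential(
                *[ResBlock(hidden_size) for _ in range(num_layers)],
                nn.Linear(hidden_size, vocab_size, bias=False),
            )
            for _ in range(num_heads)
        ])

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq_len, hidden) from the base model's last layer.
        # Returns logits of shape (num_heads, batch, seq_len, vocab).
        return torch.stack([head(hidden_states) for head in self.heads])


# Toy sizes matching the training command: 3 heads with 1 layer each.
heads = MedusaHeads(hidden_size=64, vocab_size=100, num_heads=3, num_layers=1)
print(heads(torch.randn(2, 5, 64)).shape)  # torch.Size([3, 2, 5, 100])
```

More heads let the model speculate further ahead, but later heads are harder to predict accurately. More layers add capacity and parameters to every head.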

We start training with the existing `train_legacy.py` script, launched with `torchrun` on 4 GPUs, and pass our parameters.

```bash
torchrun --nproc_per_node=4 Medusa/medusa/train/train_legacy.py \
    --model_name_or_path philschmid/code-llama-3-1-8b-text-to-sql \
    --data_path train_dataset.json \
    --bf16 True \
    --output_dir code_llama31 \
    --num_train_epochs 5 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --eval_strategy "no" \
    --save_strategy "epoch" \
    --learning_rate 5e-4 \
    --lr_scheduler_type "cosine" \
    --warmup_steps 40 \
    --logging_steps 10 \
    --tf32 True \
    --model_max_length 2048 \
    --lazy_preprocess False \
    --medusa_num_heads 3 \
    --medusa_num_layers 1 \
    --deepspeed ./Medusa/deepspeed.json
```
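As a quick sanity check on the schedule, this sketch works out the effective batch size and the number of optimizer steps implied by the training command. It assumes that all 10,000 training samples are used, one sample per sequence (no packing), and that a final partial batch still counts as a step. The count the trainer reports may differ slightly.

```python
import math

num_gpus = 4  # --nproc_per_node
per_device_batch_size = 1  # --per_device_train_batch_size
grad_accum_steps = 4  # --gradient_accumulation_steps
num_epochs = 5  # --num_train_epochs
num_samples = 10_000  # size of the training split
warmup_steps = 40  # --warmup_steps

# Samples consumed per optimizer step across all GPUs.
effective_batch_size = num_gpus * per_device_batch_size * grad_accum_steps
steps_per_epoch = math.ceil(num_samples / effective_batch_size)
total_steps = steps_per_epoch * num_epochs
warmup_fraction = warmup_steps / total_steps

print(effective_batch_size, steps_per_epoch, total_steps)  # 16 625 3125
print(f"warmup covers {warmup_fraction:.1%} of training")  # warmup covers 1.3% of training
```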

Nice, our training finished successfully. Now we just need to push the model to the Hugging Face Hub, and we are ready to benchmark it. Make sure to replace `folder` and `repo` with your own values.

```bash
python -m medusa.hf_utils \
    --folder code_llama31_medusa \
    --repo philschmid/code-llama-3-1-8b-text-to-sql-medusa
```

## Benchmark Medusa with Hugging Face Text Generation Inference 

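[Text Generation Inference (TGI)](https://github.com/huggingface/text-generation-inference) is Hugging Face's toolkit for deploying and serving LLMs. It provides features such as continuous batching, tensor parallelism, and an OpenAI-compatible Messages API. It also supports speculative decoding with Medusa heads: when you point TGI at a model repository that contains Medusa heads, it loads and uses them automatically.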

When using speculative decoding, we want to measure how much the Medusa heads accelerate generation. Acceleration is the ratio of the total number of tokens produced to the number of decoding iterations (loops) needed to produce them. The total includes both generated and skipped tokens, where skipped tokens are speculated tokens that were accepted. A higher value means the model produces more output per forward pass:

`Acceleration = (Generated Tokens + Skipped Tokens) / Number of Iterations`

For example, a run that produces 44,788 tokens in total (17,687 + 27,101) over 27,101 iterations has an acceleration of `(17687 + 27101) / 27101 ≈ 1.65`. Each iteration yields about 1.65 tokens on average.
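Here is the same calculation in Python, as a minimal sketch. It assumes each decoding iteration emits exactly one token from the base model plus zero or more accepted (skipped) speculative tokens, so acceleration can also be computed from a per-iteration log of accepted tokens. The function names and the example log are hypothetical.

```python
from typing import Sequence


def acceleration(total_tokens: int, iterations: int) -> float:
    """Average number of tokens produced per decoding iteration."""
    if iterations <= 0:
        raise ValueError("iterations must be positive")
    return total_tokens / iterations


def acceleration_from_log(accepted_per_iteration: Sequence[int]) -> float:
    """Each iteration yields 1 base-model token plus its accepted speculative tokens."""
    iterations = len(accepted_per_iteration)
    total_tokens = iterations + sum(accepted_per_iteration)
    return acceleration(total_tokens, iterations)


# Numbers from the example above.
print(round(acceleration(17687 + 27101, 27101), 2))  # 1.65
# Hypothetical log: 4 iterations that accepted 0, 1, 2 and 0 extra tokens.
print(acceleration_from_log([0, 1, 2, 0]))  # 1.75
```

An acceleration of 1.0 means no speculated token was ever accepted. The upper bound is `1 + medusa_num_heads` if every speculated token were accepted.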



Let's start our TGI container using the `docker run` command.
```bash
docker run --gpus '"device=0"' -ti --shm-size 1g --ipc=host --rm -p 8080:80 \
  -e MODEL_ID=philschmid/code-llama-3-1-8b-text-to-sql-medusa \
  -e NUM_SHARD=1 \
  -e MAX_INPUT_TOKENS=4096 \
  -e MAX_TOTAL_TOKENS=6000 \
  ghcr.io/huggingface/text-generation-inference:sha-b70ae09
```

After our container is running, we can test it with a simple query. 

```bash
curl localhost:8080/v1/chat/completions \
    -X POST \
    -d '{
  "model": "tgi",
  "messages": [
    {
      "role": "user",
      "content": "Write a poem for my three year old"
    }
  ],
  "stream": false,
  "max_tokens": 250
}' \
    -H 'Content-Type: application/json'
```
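If you prefer Python over `curl`, you can send the same request with the standard library. This is a minimal sketch, assuming the container is listening on `localhost:8080` and exposes the OpenAI-compatible `/v1/chat/completions` route used in the `curl` example. The helper names `build_chat_request` and `chat` are our own.

```python
import json
import urllib.request


def build_chat_request(prompt: str, max_tokens: int = 250) -> dict:
    """Build the same JSON body as the curl example."""
    return {
        "model": "tgi",
        "messages": [{"role": "user", "content": prompt}],
        "stream": False,
        "max_tokens": max_tokens,
    }


def chat(prompt: str, base_url: str = "http://localhost:8080/v1", timeout: float = 60.0) -> str:
    """POST a chat completion request and return the generated message text."""
    request = urllib.request.Request(
        f"{base_url}/chat/completions",
        data=json.dumps(build_chat_request(prompt)).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        body = json.load(response)
    return body["choices"][0]["message"]["content"]


if __name__ == "__main__":
    try:
        print(chat("Write a poem for my three year old"))
    except OSError as err:  # URLError and socket timeouts are OSError subclasses
        print(f"Could not reach the TGI server: {err}")
```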

For benchmarking, we use GuideLLM from Neural Magic. GuideLLM simulates real-world inference workloads to help users gauge the performance, resource needs, and cost implications of deploying LLMs on various hardware configurations. It supports Hugging Face datasets, loaded either from local files or from the Hub. If you plan to use a Hugging Face dataset, make sure it includes a `text` field containing the formatted prompt (system + user message).

Here is a simple example of how to create one (this is how the test set used below was created):
```python
from datasets import load_dataset
from transformers import AutoTokenizer

# Load the dataset
dataset = load_dataset("philschmid/text-to-sql-dataset-medusa")
tokenizer = AutoTokenizer.from_pretrained("philschmid/code-llama-3-1-8b-text-to-sql")

def create_text_field(samples):
    prompt = tokenizer.apply_chat_template(samples["messages"][0:2], tokenize=False)
    return {"text": prompt}
  
td = dataset["test"].map(create_text_field)
td.push_to_hub("philschmid/text-to-sql-dataset-medusa-test-chatml")
```

```bash
pip install guidellm
```

GuideLLM sends prompts from our 2,500-sample test set to TGI through its OpenAI-compatible API and benchmarks the server at different levels of concurrency. `--max-seconds 60` limits how long each benchmark runs.

```bash
guidellm \
  --target "http://localhost:8080/v1" \
  --model philschmid/code-llama-3-1-8b-text-to-sql-medusa \
  --data philschmid/text-to-sql-dataset-medusa-test-chatml \
  --data-type transformers \
  --max-seconds 60 \
  --output-path benchmark_medusa.json
```

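Running Llama 3.1 8B with Medusa heads on a single GPU, we get the following results. In the `synchronous` benchmark, requests are sent one at a time. In the `throughput` benchmark, requests are sent concurrently to saturate the server.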

```                                                                                                     
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓ 
┃ Benchmark                 ┃ Requests per Second ┃ Request Latency ┃ Time to First Token ┃ Inter Token Latency ┃ Output Token Throughput ┃ 
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩ 
│ synchronous               │ 0.86 req/sec        │ 1.16 sec        │ 80.81 ms            │ 50.32 ms            │ 18.53 tokens/sec        │ 
│ throughput                │ 3.91 req/sec        │ 12.48 sec       │ 10824.93 ms         │ 77.18 ms            │ 84.42 tokens/sec        │ 
└───────────────────────────┴─────────────────────┴─────────────────┴─────────────────────┴─────────────────────┴─────────────────────────┘  
```

Now, let's compare this to our base model without Medusa. To do so, we use the same `docker run` command but replace the model ID.

```bash
docker run --gpus '"device=0"' -ti --shm-size 1g --ipc=host --rm -p 8080:80 \
  -e MODEL_ID=philschmid/code-llama-3-1-8b-text-to-sql \
  -e NUM_SHARD=1 \
  -e MAX_INPUT_TOKENS=4096 \
  -e MAX_TOTAL_TOKENS=6000 \
  ghcr.io/huggingface/text-generation-inference:sha-b70ae09
```

Then we rerun the benchmark, making sure to change the output path to `benchmark_baseline.json`. This gives the following baseline results:

```
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓ 
┃ Benchmark                 ┃ Requests per Second ┃ Request Latency ┃ Time to First Token ┃ Inter Token Latency ┃ Output Token Throughput ┃ 
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩ 
│ synchronous               │ 0.71 req/sec        │ 1.40 sec        │ 81.08 ms            │ 60.50 ms            │ 15.56 tokens/sec        │ 
│ throughput                │ 3.76 req/sec        │ 13.11 sec       │ 11370.37 ms         │ 81.92 ms            │ 79.66 tokens/sec        │ 
└───────────────────────────┴─────────────────────┴─────────────────┴─────────────────────┴─────────────────────┴─────────────────────────┘  
```
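This short sketch puts the two tables side by side. It computes the ratio of each Medusa metric to its baseline counterpart, using only the numbers reported above. For latencies, it divides the baseline by Medusa, so a value above 1 always favors Medusa. This is plain arithmetic on a single run per configuration, not a statistically robust comparison.

```python
# Metrics copied from the two GuideLLM tables above.
medusa = {
    "synchronous": {"req_per_sec": 0.86, "latency_s": 1.16, "itl_ms": 50.32, "tok_per_sec": 18.53},
    "throughput": {"req_per_sec": 3.91, "latency_s": 12.48, "itl_ms": 77.18, "tok_per_sec": 84.42},
}
baseline = {
    "synchronous": {"req_per_sec": 0.71, "latency_s": 1.40, "itl_ms": 60.50, "tok_per_sec": 15.56},
    "throughput": {"req_per_sec": 3.76, "latency_s": 13.11, "itl_ms": 81.92, "tok_per_sec": 79.66},
}
LOWER_IS_BETTER = {"latency_s", "itl_ms"}


def speedups(medusa_run: dict, baseline_run: dict) -> dict:
    """Per-metric improvement ratio; a value > 1 means Medusa is better."""
    result = {}
    for metric, value in medusa_run.items():
        base = baseline_run[metric]
        ratio = base / value if metric in LOWER_IS_BETTER else value / base
        result[metric] = round(ratio, 2)
    return result


for mode in medusa:
    print(mode, speedups(medusa[mode], baseline[mode]))
# synchronous {'req_per_sec': 1.21, 'latency_s': 1.21, 'itl_ms': 1.2, 'tok_per_sec': 1.19}
# throughput {'req_per_sec': 1.04, 'latency_s': 1.05, 'itl_ms': 1.06, 'tok_per_sec': 1.06}
```

In these runs, Medusa improves output token throughput by roughly 19% when requests are sent one at a time and by roughly 6% when the server is saturated.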