# Speculative Decoding Tutorial
Speculative decoding is a technique for speeding up token generation from an autoregressive language model. The main idea is to use a smaller model, often called the "draft model", to generate K draft tokens autoregressively. The target model, i.e. the model we actually want to generate from, then scores these K tokens in parallel in a single forward pass, and we accept or reject each draft token based on whether it agrees with the target model's own prediction.

With greedy decoding, speculative decoding guarantees that the final output matches the one the target model alone would have produced, so we gain speed without changing the result. With sampling-based decoding strategies such as top-k or nucleus sampling, speculative decoding (often called speculative sampling in this context) guarantees that the output follows the same probability distribution as the target model. **This means you get exactly the target model's output under greedy decoding, and samples drawn from exactly the target model's distribution under sampling. An individual sampled sequence need not be identical to any particular sample drawn from the target model alone.**
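
To make the sampling guarantee concrete, here is a minimal sketch of the accept/reject rule commonly described in the speculative sampling literature: a draft token x is accepted with probability min(1, p(x)/q(x)), and on rejection a replacement is drawn from the normalized positive part of p - q. The function name and the toy distributions `p` (target) and `q` (draft) are illustrative assumptions, not part of this notebook. The sketch computes the exact distribution of one emitted token and checks that it equals `p`.

```python
def speculative_sampling_output_distribution(p, q):
    """Exact distribution of one emitted token under the accept/reject rule.

    p: target-model probabilities, q: draft-model probabilities (same vocabulary).
    """
    # Probability that the draft proposes token x AND the proposal is accepted:
    # q(x) * min(1, p(x) / q(x)).
    accepted = [qx * min(1.0, px / qx) if qx > 0 else 0.0 for px, qx in zip(p, q)]
    reject_prob = 1.0 - sum(accepted)
    # On rejection, resample from the normalized positive part of (p - q).
    residual = [max(0.0, px - qx) for px, qx in zip(p, q)]
    total = sum(residual)
    if total > 0:
        residual = [r / total for r in residual]
    return [a + reject_prob * r for a, r in zip(accepted, residual)]


# Toy distributions over a 3-token vocabulary (assumed values for illustration).
p = [0.6, 0.3, 0.1]
q = [0.2, 0.5, 0.3]
out = speculative_sampling_output_distribution(p, q)
assert all(abs(o - px) < 1e-9 for o, px in zip(out, p))
```

The rest of this notebook only implements the greedy case, where the rule reduces to an exact token comparison.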

This blog/notebook shows how to implement speculative decoding using PyTorch and the Hugging Face Transformers library without using a pre-built function for speculative decoding.

Let's install the necessary libraries and set up the environment for speculative decoding.

In [3]:
!uv add transformers accelerate hf_xet

Resolved 75 packages in 3.41s
Prepared 1 package in 370ms
Installed 1 package in 27ms
 + hf-xet==1.2.0


In [1]:
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig



## Draft and Target Models
Let's load the draft and target models. In this blog we will use SmolLM2-360M as our draft model and SmolLM2-1.7B as our target model. The key requirements for speculative decoding are that the draft model should:

1. Be significantly smaller than the target model to ensure faster token generation.
2. Use the same tokenizer as the target model. This ensures that there's a 1-1 mapping between the tokens generated by the draft model and those scored by the target model.

_Note: Some algorithms reconcile differences in tokenizers between the draft and target models, but they are out of scope for this blog. For more on this, check out [Universal Assisted Generation: Faster Decoding with Any Assistant Model](https://huggingface.co/blog/universal_assisted_generation)._
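
One way to check the shared-tokenizer requirement is to compare the two vocabularies directly. The helper below is a minimal sketch: `same_vocabulary` is a hypothetical name, and `ToyTokenizer` is a stand-in that only mimics the `get_vocab()` method of Hugging Face tokenizers so the example runs without downloading anything. With real tokenizers you would pass `draft_tokenizer` and `target_tokenizer` instead.

```python
def same_vocabulary(tokenizer_a, tokenizer_b) -> bool:
    """Return True when both tokenizers map the same token strings to the same ids."""
    return tokenizer_a.get_vocab() == tokenizer_b.get_vocab()


class ToyTokenizer:
    """Hypothetical stand-in exposing only get_vocab(), for illustration."""

    def __init__(self, vocab):
        self._vocab = dict(vocab)

    def get_vocab(self):
        return dict(self._vocab)


shared = {"<eos>": 0, "Jack": 1, "Jill": 2}
swapped = {"<eos>": 0, "Jack": 2, "Jill": 1}
print(same_vocabulary(ToyTokenizer(shared), ToyTokenizer(shared)))   # True
print(same_vocabulary(ToyTokenizer(shared), ToyTokenizer(swapped)))  # False
```

Equal vocabularies are a necessary condition; comparing the encodings of a few real prompts is still a useful sanity check.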

In [None]:
draft_model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-360M-Instruct", device_map="auto")
draft_tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-360M-Instruct")

target_model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-1.7B-Instruct", device_map="auto")
target_tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-1.7B-Instruct")

Let's create a couple of prompts to test speculative decoding. Both prompts require a fairly long completion, and the answers are easy to verify for correctness. Their knowledge-intensive nature also makes it likely that the draft model will make mistakes, allowing us to see how speculative decoding handles rejections.

In [None]:
prompts = [
    "The 50 states of the USA in alphabetical order are: ", 
    "The countries of South America in alphabetical order are: ",
]

In this tutorial we'll use the draft model to generate 8 tokens at a time. After every 8 tokens generated by the draft model, the target model scores them in parallel and accepts or rejects each one based on whether it matches the target model's own prediction.

We'll also focus on greedy decoding for now, i.e. both models always pick the token with the highest probability at each step.

In [None]:
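import platform

import torch

draft_model.eval()
target_model.eval()
num_draft_tokens = 8
# The draft model proposes `num_draft_tokens` tokens per iteration; `max_new_tokens`
# is the GenerationConfig field that actually limits the number of generated tokens.
greedy_gen = GenerationConfig(num_beams=1, do_sample=False, max_new_tokens=num_draft_tokens)
if torch.cuda.is_available():
    device_type = "cuda"
elif platform.system() == "Darwin" and getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device_type = "mps"
else:
    device_type = "cpu"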

Some parameters are on the meta device because they were offloaded to the disk.


Let's sample a prompt and encode it with both the draft and target tokenizers. We'll compare the encoded tokens to confirm that both tokenizers produce the same token IDs for the same prompt, as a quick check that both models use the same tokenizer.

In [None]:
import random
prompt = random.sample(prompts, 1)[0]
inputs = target_tokenizer.encode(prompt, return_tensors="pt").to(device_type)
draft_inputs = draft_tokenizer.encode(prompt, return_tensors="pt").to(device_type)
assert torch.equal(inputs, draft_inputs)

Let's run the prompt through both the draft and target models to see what their outputs look like. We can observe that the two models produce slightly different outputs. Remember that our implementation of speculative decoding should ensure that the final output matches that of the target model exactly.

In [None]:
draft_output_tokens = draft_model.generate(inputs=inputs, generation_config=GenerationConfig(num_beams=1, do_sample=False, max_new_tokens=256, use_cache=True))[0]
draft_tokenizer.decode(draft_output_tokens)

'The 50 states of the USA in alphabetical order are: \nAlabama, Alaska, Arizona, Arkansas, California, Colorado, Connecticut, Delaware, Florida, Georgia, Hawaii, Idaho, Illinois, Indiana, Iowa, Kansas, Kentucky, Louisiana, Maine, Maryland, Massachusetts, Michigan, Minnesota, Mississippi, Missouri, Montana, Nebraska, Nevada, New Hampshire, New Jersey, New Mexico, Ohio, Oklahoma, Oregon, Pennsylvania, Rhode Island, South Carolina, Tennessee, Texas, Utah, Vermont, Virginia, Washington, West Virginia, Wisconsin, and Wyoming.<|im_end|>'

In [None]:
target_output_tokens = target_model.generate(inputs=inputs, generation_config=GenerationConfig(num_beams=1, do_sample=False, max_new_tokens=256, use_cache=True))[0]
target_tokenizer.decode(target_output_tokens)

'The 50 states of the USA in alphabetical order are: \nAlabama, Alaska, Arizona, Arkansas, California, Colorado, Connecticut, Delaware, Florida, Georgia, Hawaii, Idaho, Illinois, Indiana, Iowa, Kansas, Kentucky, Louisiana, Maine, Maryland, Massachusetts, Michigan, Minnesota, Mississippi, Missouri, Montana, Nebraska, Nevada, New Hampshire, New Jersey, New Mexico, New York, North Carolina, North Dakota, Ohio, Oklahoma, Oregon, Pennsylvania, Rhode Island, South Carolina, South Dakota, Tennessee, Texas, Utah, Vermont, Virginia, Washington, West Virginia, Wisconsin, Wyoming.<|im_end|>'

In [None]:
assert not torch.equal(draft_output_tokens,target_output_tokens)

To gauge potential speedups from speculative decoding, let's use `%%timeit` to see how long it takes the draft and target models to produce a response for our prompt.

In [None]:
%%timeit
target_output_tokens = target_model.generate(inputs=inputs, generation_config=GenerationConfig(num_beams=1, do_sample=False, max_new_tokens=256, use_cache=True))[0]

21.9 s ± 730 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
%%timeit
draft_output_tokens = draft_model.generate(inputs=inputs, generation_config=GenerationConfig(num_beams=1, do_sample=False, max_new_tokens=256, use_cache=True))[0]

2.47 s ± 104 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
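
To reason about how these timings could translate into a speedup, the sketch below uses a simple cost model. It rests on assumptions that are not measured in this notebook: each draft token is accepted independently with probability `alpha`, scoring K+1 positions costs about as much as one target step, and `cost_ratio` is the cost of one draft step relative to one target step. The function names and the example values are hypothetical.

```python
def expected_tokens_per_iteration(alpha: float, k: int) -> float:
    """Expected tokens appended per iteration: accepted drafts plus one target token.

    Assumes each draft token is accepted independently with probability alpha,
    which gives 1 + alpha + alpha**2 + ... + alpha**k.
    """
    if alpha >= 1.0:
        return float(k + 1)
    return (1.0 - alpha ** (k + 1)) / (1.0 - alpha)


def estimated_speedup(alpha: float, k: int, cost_ratio: float) -> float:
    """Estimated speedup over decoding with the target model alone.

    One iteration costs k draft steps plus one target pass, measured in units
    of one target step.
    """
    return expected_tokens_per_iteration(alpha, k) / (k * cost_ratio + 1.0)


# cost_ratio=0.1 and the acceptance rates below are assumed values, not measurements.
for alpha in (0.5, 0.8, 0.9):
    print(alpha, round(estimated_speedup(alpha, k=8, cost_ratio=0.1), 2))
```

Under this model a low acceptance rate can make speculative decoding slower than plain decoding, because every iteration still pays for K draft steps.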


## Implementing Speculative Decoding
In this tutorial we'll assume a batch size of 1 for simplicity.

Here are the steps to the algorithm:

1. Generate K draft tokens using the draft model.
2. Run a single forward pass of the target model to obtain the probability scores of tokens at each of the K draft positions. The same pass also yields scores for position K+1.
3. For each position i (1 <= i <= K), check whether the token with the highest probability under the target model matches the draft token at position i.
4. Identify the first position j where the draft token does not match the target model's highest-probability token. Together, steps 3 and 4 check for a discrepancy between the greedy outputs of the draft and target models.
5. After the previous step we know that:
    - Draft tokens 1 to j-1 would have been generated by the target model as well, so we can accept these tokens.
    - Although draft token j is wrong, we know what the correct token is: the token with the highest probability under the target model at position j.

   So we can prepare our next input sequence by appending the accepted tokens (1 to j-1) and the correct token at position j to our existing input sequence.
6. If all K tokens were accepted, we append all K tokens to our input sequence, followed by the token the target model predicts for position K+1.
7. Repeat the process until the desired sequence length is reached or a stop token is generated.
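
The accept/reject logic in steps 3 to 6 can be sketched on plain Python lists before involving any model. `verify_draft` is a hypothetical helper name, and the token strings are illustrative only.

```python
def verify_draft(draft_tokens, target_predictions):
    """Return the tokens to append after one verification step.

    draft_tokens: the K tokens proposed by the draft model.
    target_predictions: K + 1 greedy target predictions, where entry i is the
    target's choice given the prompt plus draft_tokens[:i].
    """
    assert len(target_predictions) == len(draft_tokens) + 1
    for j, (draft, target) in enumerate(zip(draft_tokens, target_predictions)):
        if draft != target:
            # First mismatch: keep the matching prefix, then the target's correction.
            return list(draft_tokens[:j]) + [target]
    # Every draft token matched: keep all K and add the target's token for K + 1.
    return list(draft_tokens) + [target_predictions[-1]]


print(verify_draft(["went", "down", "a"], ["went", "up", "to", "fetch"]))
# ['went', 'up']
print(verify_draft(["went", "up", "the"], ["went", "up", "the", "hill"]))
# ['went', 'up', 'the', 'hill']
```

Either way, the returned list is never empty, so each verification step makes progress.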


#### Why can't we accept tokens at j+1 from the target model?

Let's assume that our prompt is `Jack and Jill` and we generate 3 draft tokens using the draft model: `went down a`.

1. Target model tokens with highest prob scores:
   - Position 1: `went` (matches draft)
   - Position 2: `up` (does not match draft `down`)
   - Position 3: `to` (does not match draft `a`)
2. We accept the token `up` at position 2 because the prefix `Jack and Jill went` would have been generated by the target model as well, and auto-regressive generation of the next token depends only on the prefix.
3. However, we cannot accept the token `to` at position 3. The target model scored that position with the draft model's tokens as input, so the prefix it saw was `Jack and Jill went down`. Had we generated auto-regressively with the target model alone, the prefix would have been `Jack and Jill went up`, so the prediction at position 3 was conditioned on the wrong prefix.
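
The same argument can be shown with a toy lookup table standing in for the target model. `TOY_TARGET` and `target_next` are hypothetical, and the continuation `the` after `went up` is an assumed value chosen for illustration.

```python
# Hypothetical greedy "target model": maps a prefix to its single next token.
TOY_TARGET = {
    ("Jack", "and", "Jill"): "went",
    ("Jack", "and", "Jill", "went"): "up",
    ("Jack", "and", "Jill", "went", "down"): "to",
    ("Jack", "and", "Jill", "went", "up"): "the",
}


def target_next(prefix):
    return TOY_TARGET[tuple(prefix)]


toy_prompt = ["Jack", "and", "Jill"]
draft = ["went", "down", "a"]

# During scoring, position i is conditioned on the prompt plus the first i draft tokens.
scored = [target_next(toy_prompt + draft[:i]) for i in range(len(draft))]
print(scored)  # ['went', 'up', 'to']

# What the target would emit at position 3 after its own token `up`.
print(target_next(toy_prompt + ["went", "up"]))  # the
```

The score at position 3 (`to`) answers a question about the prefix ending in `down`, which the target model would never have produced.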

#### No forward pass through the target model gets wasted
If no draft tokens were accepted, we still know what the correct token at position 1 is after the forward pass, so we can append this token to our input sequence and move forward. We already discussed the cases where the first j-1 tokens or all K tokens are accepted. In every case, each iteration appends at least one token from the target model.
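
As a sanity check of both claims (identical output and at least one new token per target pass), here is a small end-to-end sketch with toy deterministic "models" over the digits 0 to 9. `target_next`, `draft_next`, and the arithmetic inside them are hypothetical stand-ins, not real language models.

```python
def target_next(prefix):
    """Hypothetical deterministic 'target model' over tokens 0-9."""
    return (7 * sum(prefix) + len(prefix)) % 10


def draft_next(prefix):
    """Hypothetical 'draft model': disagrees with the target at every third position."""
    token = target_next(prefix)
    return (token + 1) % 10 if len(prefix) % 3 == 0 else token


def greedy_decode(next_token, prompt, n):
    seq = list(prompt)
    for _ in range(n):
        seq.append(next_token(seq))
    return seq


def toy_speculative_decode(prompt, n, k):
    seq, passes = list(prompt), 0
    while len(seq) - len(prompt) < n:
        draft = greedy_decode(draft_next, seq, k)[len(seq):]
        # One "forward pass": the target's choice at each of the k + 1 positions.
        scored = [target_next(seq + draft[:i]) for i in range(k + 1)]
        passes += 1
        accepted = 0
        while accepted < k and draft[accepted] == scored[accepted]:
            accepted += 1
        # Accepted draft tokens plus exactly one token from the target.
        seq += draft[:accepted] + [scored[accepted]]
    return seq[: len(prompt) + n], passes


toy_prompt = [1, 2]
output, passes = toy_speculative_decode(toy_prompt, n=20, k=4)
assert output == greedy_decode(target_next, toy_prompt, 20)
assert passes <= 20  # every pass contributes at least one token
print(passes, "target passes for 20 tokens")
```

The first assertion mirrors the correctness check performed on the real models at the end of this notebook.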

In [None]:
def speculative_decoding(prompt, max_new_tokens, gen_config) -> torch.LongTensor:
    """Greedy speculative decoding for a single prompt (batch size 1)."""
    inputs = target_tokenizer.encode(prompt, return_tensors="pt").to(device_type)
    original_prompt_len = inputs.shape[1]
    eos_token_id = target_tokenizer.eos_token_id
    do_stop = False
    while not do_stop:
        prompt_len = inputs.shape[1]
        # Step 1: draft up to K tokens. For simplicity the KV cache is not reused
        # across iterations: rejected tokens would leave stale cache entries.
        draft_output = draft_model.generate(
            inputs=inputs,
            generation_config=gen_config,
            return_dict_in_generate=True,
        )
        draft_tokens = draft_output.sequences[:, prompt_len:]
        validation_inputs = torch.cat([inputs, draft_tokens], dim=1)
        # Step 2: score all draft positions with a single target forward pass.
        with torch.no_grad():
            target_output = target_model(validation_inputs, return_dict=True)
        # The logits at index t predict the token at index t + 1, so slicing from
        # prompt_len - 1 gives predictions for the K draft positions plus one more.
        logits = target_output.logits[:, prompt_len - 1:]
        model_predicted_tokens = torch.argmax(logits, dim=-1)

        # Steps 3-4: find the first position where the draft and target disagree.
        mismatches = torch.argwhere(model_predicted_tokens[:, :-1] != draft_tokens)
        if mismatches.shape[0] > 0:
            # Step 5: keep the matching prefix and the target's correction.
            mismatch = int(mismatches[0, 1])
            matched_draft_tokens = draft_tokens[:, :mismatch]
            corrected_token = model_predicted_tokens[:, mismatch:mismatch + 1]
            inputs = torch.cat([inputs, matched_draft_tokens, corrected_token], dim=1)
        else:
            # Step 6: all draft tokens accepted, plus the target's token for K + 1.
            inputs = torch.cat([inputs, draft_tokens, model_predicted_tokens[:, -1:]], dim=1)

        # Step 7: stop on EOS in the generated part or when the token budget is used up.
        generated = inputs[0, original_prompt_len:]
        if (generated == eos_token_id).any() or generated.shape[0] >= max_new_tokens:
            do_stop = True

    # Trim to the token budget, then cut right after the first generated EOS token.
    end = min(inputs.shape[1], original_prompt_len + max_new_tokens)
    eos_positions = torch.argwhere(inputs[0, original_prompt_len:end] == eos_token_id)
    if eos_positions.numel() > 0:
        end = original_prompt_len + int(eos_positions[0, 0]) + 1
    return inputs[0, :end]

In [36]:
speculative_decoding_output_tokens = speculative_decoding(prompt, 256, greedy_gen)

In [37]:
assert torch.equal(target_output_tokens, speculative_decoding_output_tokens)