# Demo: Speculative Decoding - A Step-by-Step Look at the Logic

**Welcome!**

In this demo, we'll pull back the curtain on **Speculative Decoding**, one of the most clever techniques for speeding up LLM inference. Rather than build a full, optimized loop, we'll walk through a **single, detailed step** to understand the core logic.

**The Two Players:**
1.  **The Draft Model (`gpt2`):** A small, fast model. Think of it as a scout that runs ahead and quickly suggests a path.
2.  **The Target Model (`gpt2-medium`):** A larger, more accurate, but slower model. This is the general who verifies the scout's path.

**Our Goal:** To see how the general (target model) can efficiently verify the scout's suggestions (draft tokens) and accept multiple tokens for the cost of just one slow operation.

## 1. Environment Setup

First, let's install the necessary libraries.

In [1]:
!pip install transformers torch accelerate



## 2. Imports and Configuration

Now, we'll import our libraries and configure the demo. We'll set the names for our two models and decide how many tokens our fast draft model should propose in this step (`K_DRAFT_TOKENS`).

In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

TARGET_MODEL_NAME = "gpt2-medium"
DRAFT_MODEL_NAME = "gpt2"
K_DRAFT_TOKENS = 5 # Let's have the draft model propose 5 tokens

INITIAL_CONTEXT_TEXT = "There are different ways to optimize LLM inference. One"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


## 3. Load the Models (Our "Players")

Let's load both the small draft model and the larger target model. We'll also load the tokenizer once: `gpt2` and `gpt2-medium` use the same tokenizer, so token IDs from the two models can be compared directly.

In [3]:
tokenizer = AutoTokenizer.from_pretrained(DRAFT_MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Loading Target Model ('The General'): {TARGET_MODEL_NAME}...")
target_model = AutoModelForCausalLM.from_pretrained(TARGET_MODEL_NAME).to(device)
target_model.eval() # Set to evaluation mode

print(f"Loading Draft Model ('The Scout'): {DRAFT_MODEL_NAME}...")
draft_model = AutoModelForCausalLM.from_pretrained(DRAFT_MODEL_NAME).to(device)
draft_model.eval() # Set to evaluation mode

print("Models and tokenizer loaded.")

Loading Target Model ('The General'): gpt2-medium...
Loading Draft Model ('The Scout'): gpt2...
Models and tokenizer loaded.


## 4. The Speculative Decoding Step

Let's begin our detailed, single-step walkthrough.

### Step 4.1: The Scout Runs Ahead (Draft Generation)

First, our small, fast `gpt2` draft model takes the current context and generates `K_DRAFT_TOKENS` candidate tokens. It still produces them one at a time, but each step is cheap because the model is small.

In [4]:
print(f"--- Step 1: Draft Model generates {K_DRAFT_TOKENS} candidate tokens ---")
# Tokenize the prompt; the tokenizer also returns the attention mask that generate() expects
context_inputs = tokenizer(INITIAL_CONTEXT_TEXT, return_tensors="pt").to(device)
current_context_ids = context_inputs["input_ids"]

with torch.no_grad():
    # Use the draft model's generate function for simplicity
    draft_output_ids = draft_model.generate(
        current_context_ids,
        attention_mask=context_inputs["attention_mask"],
        max_new_tokens=K_DRAFT_TOKENS,
        do_sample=False,  # greedy decoding, so the draft is deterministic
        pad_token_id=tokenizer.eos_token_id
    )

# Isolate just the newly generated draft tokens
draft_candidate_ids = draft_output_ids[:, current_context_ids.shape[1]:]

print(f"Initial Context: '{INITIAL_CONTEXT_TEXT}'")
print(f"Draft Model's Proposal: '{tokenizer.decode(draft_candidate_ids[0], skip_special_tokens=True)}'")

--- Step 1: Draft Model generates 5 candidate tokens ---
Initial Context: 'There are different ways to optimize LLM inference. One'
Draft Model's Proposal: ' is to use a simple'
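
Under the hood, greedy generation is just a loop: run the model, take the most likely next token, append it, and repeat. Here is a minimal sketch of that loop. The lookup table `TOY_DRAFT_TABLE` and the helper names are hypothetical stand-ins for the draft model, not part of `transformers`.

```python
# Toy stand-in for the draft model: maps the last token to its "most likely" successor.
# The table and function names are hypothetical, for illustration only.
TOY_DRAFT_TABLE = {"One": "is", "is": "to", "to": "use", "use": "a", "a": "simple"}


def toy_next_token(context):
    """Return the toy model's greedy pick for the token that follows `context`."""
    return TOY_DRAFT_TABLE.get(context[-1], "<eos>")


def draft_greedy(context, k):
    """Propose k tokens autoregressively: each pick is appended before the next call."""
    context = list(context)
    proposal = []
    for _ in range(k):
        token = toy_next_token(context)  # one (cheap) model call per token
        proposal.append(token)
        context.append(token)
    return proposal


print(draft_greedy(["One"], 5))  # ['is', 'to', 'use', 'a', 'simple']
```

The point to notice is that drafting costs `k` sequential calls. That is only acceptable because the draft model is small.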


### Step 4.2: The General Reviews the Plan (Target Model Verification)

Now for the clever part. We take the **original context + the draft tokens** and feed this entire sequence to the large target model in a **single forward pass**.

Because a causal Transformer computes next-token logits for every input position in one forward pass, this single pass tells us what the target model would have predicted after *each* position in the input. This is far cheaper than calling the target model `K` times, once per token.

In [5]:
print(f"--- Step 2: Target Model verifies the draft in a single pass ---")
# Combine context and draft for the verification input
verification_input_ids = torch.cat([current_context_ids, draft_candidate_ids], dim=1)

with torch.no_grad():
    target_verification_logits = target_model(verification_input_ids).logits

# Get the target model's top-1 prediction for each position
# Note: The logits at index `t-1` predict the token at index `t`
# So we look at logits from the start of the draft sequence onwards
start_of_draft_in_logits = current_context_ids.shape[1] - 1
end_of_draft_in_logits = verification_input_ids.shape[1] - 1

target_preferred_ids = torch.argmax(target_verification_logits[:, start_of_draft_in_logits:end_of_draft_in_logits, :], dim=-1)

print(f"Verification Input: '{tokenizer.decode(verification_input_ids[0])}'")
print(f"Target Model's Preferences: '{tokenizer.decode(target_preferred_ids[0], skip_special_tokens=True)}'")

--- Step 2: Target Model verifies the draft in a single pass ---
Verification Input: 'There are different ways to optimize LLM inference. One is to use a simple'
Target Model's Preferences: ' is to use a single'
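
The index arithmetic in this cell is easy to get off by one. The sketch below uses plain Python integers and a hypothetical helper name to spell out which logit positions line up with which draft tokens, assuming a context of `n_context` tokens followed by `k_draft` draft tokens.

```python
def verification_positions(n_context, k_draft):
    """Map each draft token to the logits index that predicts it.

    With a causal LM, logits[t] is the prediction for the token at index t + 1.
    The input is the context (indices 0..n-1) followed by the draft (n..n+k-1).
    """
    draft_token_indices = list(range(n_context, n_context + k_draft))
    predicting_logit_indices = [t - 1 for t in draft_token_indices]
    # The logits at the very last input position predict a token that is not
    # in the input at all: the target's own continuation after the draft.
    extra_logit_index = n_context + k_draft - 1
    return predicting_logit_indices, extra_logit_index


# Example with illustrative sizes: 12 context tokens and 5 draft tokens.
positions, extra = verification_positions(12, 5)
print(positions)  # [11, 12, 13, 14, 15]
print(extra)      # 16
```

The first list corresponds to the slice `start_of_draft_in_logits:end_of_draft_in_logits`. The extra index is the last row of the logits, which the same forward pass produces at no additional cost.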


### Step 4.3: The Verdict (Comparison and Acceptance)

Now we compare the scout's path with the general's preferred path, one position at a time. We stop at the first mismatch, because the target's later predictions were conditioned on a draft token it would not have chosen.

| Position | Draft Model's Token | Target Model's Preference | Result   |
|----------|-----------------------|-----------------------------|----------|
| 1        | ...                   | ...                         | Match?   |
| 2        | ...                   | ...                         | Match?   |
| ...      | ...                   | ...                         | ...      |

In [6]:
print("--- Step 3: Comparing draft against target preferences ---")
num_matched_tokens = 0
for i in range(draft_candidate_ids.shape[1]):
    draft_token = draft_candidate_ids[0, i]
    target_token = target_preferred_ids[0, i]
    
    print(f"Pos {i+1}: Draft ('{tokenizer.decode(draft_token)}') vs Target ('{tokenizer.decode(target_token)}')")
    
    if draft_token == target_token:
        print("  ✅ Match!")
        num_matched_tokens += 1
    else:
        print("  ❌ Mismatch! Halting comparison.")
        break

print(f"\nNumber of matched tokens: {num_matched_tokens}")

--- Step 3: Comparing draft against target preferences ---
Pos 1: Draft (' is') vs Target (' is')
  ✅ Match!
Pos 2: Draft (' to') vs Target (' to')
  ✅ Match!
Pos 3: Draft (' use') vs Target (' use')
  ✅ Match!
Pos 4: Draft (' a') vs Target (' a')
  ✅ Match!
Pos 5: Draft (' simple') vs Target (' single')
  ❌ Mismatch! Halting comparison.

Number of matched tokens: 4
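
Stripped of the printing, the comparison is a "longest matching prefix" count. Here is a minimal sketch on plain Python lists; the function name is hypothetical and the token strings are copied from the output above.

```python
def count_matching_prefix(draft_tokens, target_tokens):
    """Count how many leading draft tokens equal the target's picks."""
    matched = 0
    for draft_token, target_token in zip(draft_tokens, target_tokens):
        if draft_token != target_token:
            break  # everything after the first mismatch is discarded
        matched += 1
    return matched


draft = [" is", " to", " use", " a", " simple"]
target = [" is", " to", " use", " a", " single"]
print(count_matching_prefix(draft, target))  # 4
```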


### Step 4.4: The Outcome (Constructing the Final Output)

Based on the comparison, we can now form our final output for this step. The rule is:

**`final_tokens = [all matched tokens] + [the target's 'correct' token at the mismatch point]`**

If all tokens matched, the second part is a free "extension" token: the target's prediction for the position right after the last draft token, read from the final position of the same forward pass.

In [7]:
print("--- Step 4: Constructing the final accepted sequence for this step ---")

# 1. Take all the tokens that matched
accepted_ids = draft_candidate_ids[0, :num_matched_tokens]

# 2. Take the target's token at the next position
if num_matched_tokens < target_preferred_ids.shape[1]:
    # Mismatch: use the target's correction at that position
    next_token = target_preferred_ids[0, num_matched_tokens].unsqueeze(0)
else:
    # Every draft token matched: the logits at the last input position
    # give the target's next token for free
    next_token = torch.argmax(target_verification_logits[0, -1, :], dim=-1).unsqueeze(0)
final_accepted_ids = torch.cat([accepted_ids, next_token], dim=0)

print(f"Matched Tokens Accepted: '{tokenizer.decode(accepted_ids)}'")
print(f"Correction/Extension Token: '{tokenizer.decode(next_token)}'")

# Update our full context
new_context_ids = torch.cat([current_context_ids, final_accepted_ids.unsqueeze(0)], dim=1)

print("\n--- SUMMARY OF THIS STEP ---")
print(f"Tokens generated this step: {len(final_accepted_ids)} -> '{tokenizer.decode(final_accepted_ids)}'")
print(f"Target Model expensive calls: 1")
print(f"New Context: '{tokenizer.decode(new_context_ids[0])}'")

--- Step 4: Constructing the final accepted sequence for this step ---
Matched Tokens Accepted: ' is to use a'
Correction/Extension Token: ' single'

--- SUMMARY OF THIS STEP ---
Tokens generated this step: 5 -> ' is to use a single'
Target Model expensive calls: 1
New Context: 'There are different ways to optimize LLM inference. One is to use a single'
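
The acceptance rule can be written as one small function. This sketch works on plain Python lists and assumes the target supplies `K + 1` picks: one per draft position, plus one for the position after the draft. The function name and the `"<next>"` placeholder are hypothetical.

```python
def accept_step(draft_tokens, target_picks):
    """Return the tokens emitted by one speculative step under greedy matching.

    `target_picks` is assumed to hold len(draft_tokens) + 1 entries: the target's
    pick at each draft position, plus its pick for the position after the draft.
    """
    if len(target_picks) != len(draft_tokens) + 1:
        raise ValueError("expected one target pick per draft token plus one extra")
    matched = 0
    while matched < len(draft_tokens) and draft_tokens[matched] == target_picks[matched]:
        matched += 1
    # target_picks[matched] is the correction on a mismatch,
    # or the extension token when every draft token matched.
    return draft_tokens[:matched] + [target_picks[matched]]


draft = [" is", " to", " use", " a", " simple"]
# "<next>" is a placeholder for the target's sixth pick, which is unused here
# because the comparison stops at the mismatch.
picks = [" is", " to", " use", " a", " single", "<next>"]
print(accept_step(draft, picks))  # [' is', ' to', ' use', ' a', ' single']
```

Either way, a step emits at least one token, so the loop always makes progress even when the very first draft token is rejected.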


## 5. Conclusion

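And that's one full iteration of the core loop! In this single step, we generated **5 tokens** (4 accepted draft tokens plus 1 correction from the target) but paid the high computational cost of the `gpt2-medium` model only **once**, alongside five cheap `gpt2` passes.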
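
To see how the single step extends to a full loop, here is a toy sketch in plain Python. Both "models" are hypothetical arithmetic functions rather than real networks, and the target's single forward pass is simulated by scoring each prefix in turn. The check at the end compares the result with ordinary greedy decoding.

```python
def target_next(tokens):
    """Toy 'target model': greedy next token given the sequence so far."""
    return (tokens[-1] * 3 + len(tokens)) % 7


def draft_next(tokens):
    """Toy 'draft model': agrees with the target except when the
    sequence length is a multiple of 4."""
    guess = target_next(tokens)
    return (guess + 1) % 7 if len(tokens) % 4 == 0 else guess


def greedy_decode(prompt, n_new):
    """Baseline: one target call per generated token."""
    tokens = list(prompt)
    for _ in range(n_new):
        tokens.append(target_next(tokens))
    return tokens


def speculative_decode(prompt, n_new, k):
    """Greedy speculative decoding; returns (tokens, number of target passes)."""
    tokens = list(prompt)
    target_passes = 0
    while len(tokens) - len(prompt) < n_new:
        # 1. Draft k tokens autoregressively with the cheap model.
        draft = []
        for _ in range(k):
            draft.append(draft_next(tokens + draft))
        # 2. One target "forward pass" yields a pick after every prefix.
        #    (Simulated with k + 1 calls; a real Transformer does this at once.)
        target_passes += 1
        picks = [target_next(tokens + draft[:i]) for i in range(k + 1)]
        # 3. Keep the matching prefix, then append the target's own pick.
        matched = 0
        while matched < k and draft[matched] == picks[matched]:
            matched += 1
        tokens.extend(draft[:matched] + [picks[matched]])
    return tokens[: len(prompt) + n_new], target_passes


output, passes = speculative_decode([1], n_new=20, k=5)
print(output == greedy_decode([1], 20))  # True
print(passes)                            # 5
```

In this toy setup the draft is wrong at every fourth position, so each pass yields three accepted tokens plus one correction: 20 tokens in 5 target passes instead of 20.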

The potential for a 2-3x speedup comes from this simple fact: if the small draft model is a reasonably good guesser, the target model can verify and accept several of its tokens in a single pass, breaking the one-token-at-a-time bottleneck of traditional autoregressive generation. And because a draft token is kept only when it matches the target model's own greedy choice, the result is exactly what `gpt2-medium` would have generated by itself with greedy decoding: there is **no loss in output quality**.
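
A back-of-the-envelope sketch shows how much a "reasonably good guesser" helps. It assumes each draft token is accepted independently with the same probability, which is a simplification, and the function name is hypothetical.

```python
def expected_tokens_per_step(acceptance_rate, k):
    """Expected tokens emitted per target pass, assuming each draft token is
    accepted independently with probability `acceptance_rate`.

    A step emits the accepted draft tokens plus one target token, which
    averages out to 1 + a + a^2 + ... + a^k.
    """
    return sum(acceptance_rate ** i for i in range(k + 1))


for rate in (0.5, 0.8, 0.9):
    print(rate, round(expected_tokens_per_step(rate, 5), 2))
```

With `k = 5`, this gives roughly 2.0, 3.7, and 4.7 tokens per target pass for acceptance rates of 0.5, 0.8, and 0.9. These are tokens per pass, not wall-clock speedups: the draft model's own passes and the slightly longer verification input still cost time.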

In [9]:
# Clean up models from memory
del target_model
del draft_model
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print("\nCleaned up models and emptied CUDA cache.")


Cleaned up models and emptied CUDA cache.
