# Exercise: Implementing and Tuning Speculative Decoding

**Welcome to the Exercise!**

In this exercise, we will implement a simplified, greedy version of a speculative decoding loop from scratch. This will give you a hands-on understanding of how the technique accelerates inference.

**Our Goal:**
1.  Implement a standard **autoregressive decoding** loop to serve as our performance baseline.
2.  Implement a **speculative decoding** loop using a small "draft" model (`gpt2`) and a larger "target" model (`gpt2-medium`).
3.  Run experiments by varying the number of draft tokens (`K`) to find the **optimal value**: the one that gives the lowest wall-clock time.
4.  Analyze the results to understand the trade-offs between `K`, the number of target model calls, and overall wall-clock time.

## 1. Environment Setup

First, let's install the necessary libraries.

In [1]:
!pip install transformers torch accelerate pandas



## 2. Imports and Experiment Configuration

Next, we'll import our libraries and configure the key parameters for the exercise, including the models we'll use and the different values of `K` (draft length) we want to test.

In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import pandas as pd
import gc

# --- Configuration ---
TARGET_MODEL_NAME = "gpt2-medium"
DRAFT_MODEL_NAME = "gpt2"

PROMPT_TEXT = "The future of artificial intelligence is rapidly transforming the world by"
MAX_TOTAL_TOKENS = 60      # Total new tokens to generate in each run
K_VALUES_TO_TEST = [1, 2, 3, 4, 5, 8, 10] # Different draft lengths to experiment with

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


## 3. Load Models and Tokenizer

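Let's load our two models: the small, fast `gpt2` (about 124M parameters) as our **Draft Model** and the larger, more accurate `gpt2-medium` (about 355M parameters) as our **Target Model**. Both use the same GPT-2 tokenizer, so we load it once. A shared vocabulary is essential here: the target model verifies the draft model's token IDs directly.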

In [3]:
print(f"Loading Target Model: {TARGET_MODEL_NAME}...")
target_model = AutoModelForCausalLM.from_pretrained(TARGET_MODEL_NAME).to(device)
target_model.eval()

print(f"Loading Draft Model: {DRAFT_MODEL_NAME}...")
draft_model = AutoModelForCausalLM.from_pretrained(DRAFT_MODEL_NAME).to(device)
draft_model.eval()

tokenizer = AutoTokenizer.from_pretrained(DRAFT_MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

Loading Target Model: gpt2-medium...
Loading Draft Model: gpt2...


## 4. Step 1: Establish a Baseline

Before we test our new method, we need to know how the standard approach performs. We'll implement a simple **autoregressive decoding loop** that uses only the slow target model to generate one token at a time. We will measure its wall-clock time and count how many times it has to call the target model.

In [4]:
def run_baseline_generation(target_model, tokenizer, prompt_text, max_tokens):
    """Generates text using standard autoregressive decoding and measures performance."""
    input_ids = tokenizer.encode(prompt_text, return_tensors="pt").to(device)
    target_passes = 0
    
    start_time = time.perf_counter()
    with torch.no_grad():
        # model.generate() implements this loop (with a KV cache); we write it out
        # without caching for clarity, so each step re-processes the whole sequence
        for _ in range(max_tokens):
            outputs = target_model(input_ids)
            target_passes += 1
            
            # Greedy choice: the highest-scoring token at the last position
            next_token_id = torch.argmax(outputs.logits[:, -1, :], dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token_id], dim=1)
            
            if next_token_id.item() == tokenizer.eos_token_id:
                break
    end_time = time.perf_counter()
    
    return end_time - start_time, target_passes

print("--- Running Baseline Autoregressive Generation ---")
baseline_time, baseline_passes = run_baseline_generation(target_model, tokenizer, PROMPT_TEXT, MAX_TOTAL_TOKENS)
print(f"Baseline Time: {baseline_time:.4f} s")
print(f"Baseline Target Model Passes: {baseline_passes}")

--- Running Baseline Autoregressive Generation ---
Baseline Time: 1.1746 s
Baseline Target Model Passes: 60
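The baseline above is timed once, right after the models are loaded, with no warm-up, and every speculative run below is timed the same way. For more trustworthy comparisons between configurations, a common pattern is to warm up, synchronize, and take the median of several runs. The following is a minimal, framework-agnostic sketch; `median_time` is a hypothetical helper name, and `sync` is whatever barrier your backend provides.

```python
import statistics
import time
from typing import Callable, Optional


def median_time(fn: Callable[[], object], repeats: int = 5, warmup: int = 1,
                sync: Optional[Callable[[], None]] = None) -> float:
    """Return the median wall-clock time of fn() in seconds.

    warmup: untimed calls first (absorb one-off costs such as kernel setup).
    sync:   optional barrier called before each clock read that ends a
            measurement, so work queued on an accelerator is included.
    """
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()  # make sure warm-up work is finished before timing starts
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)
```

In the notebook you could pass, for example, `lambda: run_baseline_generation(target_model, tokenizer, PROMPT_TEXT, MAX_TOTAL_TOKENS)` as `fn` and `torch.cuda.synchronize` as `sync` when running on a GPU.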


## 5. Step 2: Implement the Speculative Decoding Loop

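Now for the core of the exercise. We will implement the speculative decoding logic in a single, well-commented function. This function takes `K` (the draft length) as an argument and performs three phases:
1.  **Draft:** The small model greedily generates `K` candidate tokens.
2.  **Verify:** The large model scores all `K` tokens in a single forward pass.
3.  **Accept:** We keep the longest prefix of draft tokens that matches the target's own greedy (argmax) predictions. At the first mismatch, we discard the rest of the draft and append the target's prediction for that position instead.

Because every token we keep is the target's own argmax choice, the generated text should match the baseline's greedy output (up to rare floating-point differences); only the number of target passes changes.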
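The three phases, and the index alignment the notebook code relies on (the logits at position `p` score the token at position `p + 1`), can be tried out without any model. Below is a minimal pure-Python sketch under stated assumptions: `target_next` and `draft_next` are hypothetical toy "models" (deterministic functions of the token list), and `target_pass` simulates one forward pass by returning the greedy next token at every position. Unlike the notebook's loop, this sketch also caps each draft at the remaining token budget and takes the standard "bonus" token: when all drafted tokens match, the target's prediction at the last position is accepted for free.

```python
from typing import Callable, List, Tuple

# A "model" here is any function mapping a token list to its greedy next token.
NextToken = Callable[[List[int]], int]


def greedy_decode(next_token: NextToken, prompt: List[int], max_new: int) -> List[int]:
    """Baseline: one model call per generated token."""
    seq = list(prompt)
    for _ in range(max_new):
        seq.append(next_token(seq))
    return seq


def target_pass(next_token: NextToken, seq: List[int]) -> List[int]:
    """Simulate ONE target forward pass: preds[p] is the greedy token after
    seq[: p + 1], just as logits at position p score the token at p + 1."""
    return [next_token(seq[: p + 1]) for p in range(len(seq))]


def speculative_decode(draft_next: NextToken, target_next: NextToken,
                       prompt: List[int], max_new: int, k: int) -> Tuple[List[int], int]:
    """Greedy speculative decoding; returns (tokens, number of target passes)."""
    seq = list(prompt)
    limit = len(prompt) + max_new
    passes = 0
    while len(seq) < limit:
        # Draft: propose up to k tokens, never more than the remaining budget
        k_step = min(k, limit - len(seq))
        draft: List[int] = []
        for _ in range(k_step):
            draft.append(draft_next(seq + draft))
        # Verify: a single target pass over context + draft
        preds = target_pass(target_next, seq + draft)
        passes += 1
        # Accept: longest matching prefix, then a correction or a bonus token
        n = len(seq)
        accepted: List[int] = []
        for i, tok in enumerate(draft):
            expected = preds[n + i - 1]  # target's choice for draft position i
            if tok != expected:
                accepted.append(expected)  # correction; rest of draft discarded
                break
            accepted.append(tok)
        else:
            accepted.append(preds[-1])  # all matched: free bonus token
        seq.extend(accepted)
    return seq[:limit], passes  # the bonus token can overshoot by one


# Hypothetical toy models over a 7-token vocabulary.
def target_next(seq: List[int]) -> int:
    return (seq[-1] * 3 + len(seq)) % 7


def draft_next(seq: List[int]) -> int:
    guess = target_next(seq)
    # The draft disagrees with the target whenever the context length is a multiple of 5
    return (guess + 1) % 7 if len(seq) % 5 == 0 else guess


prompt = [1, 2]
reference = greedy_decode(target_next, prompt, 20)
output, passes = speculative_decode(draft_next, target_next, prompt, 20, k=4)
print(output == reference, passes)  # prints: True 5
```

Because each kept token is the target's own greedy choice, `speculative_decode` reproduces `greedy_decode` exactly while calling the target 5 times instead of 20. The bonus token is why a pass here can yield up to `k + 1` tokens; the notebook's loop omits it, so there a pass yields at most `K`.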

In [5]:
def run_speculative_decoding(draft_model, target_model, tokenizer, prompt_text, max_tokens, k):
    """Runs a greedy speculative decoding loop for a given k and measures performance."""
    input_ids = tokenizer.encode(prompt_text, return_tensors="pt").to(device)
    prompt_len = input_ids.shape[1]
    target_len = prompt_len + max_tokens
    target_passes = 0
    
    start_time = time.perf_counter()
    with torch.no_grad():
        # Each iteration adds up to k tokens, so the last one can overshoot
        # the max_tokens budget by up to k - 1 tokens
        while input_ids.shape[1] < target_len:
            # 1. Draft Phase: the small model greedily generates up to k candidate tokens.
            #    No attention_mask is passed; with batch size 1 and no padding the resulting
            #    warning is harmless (attention_mask=torch.ones_like(input_ids) would silence it).
            draft_ids = draft_model.generate(
                input_ids, max_new_tokens=k, do_sample=False, pad_token_id=tokenizer.eos_token_id
            )
            draft_candidates = draft_ids[:, input_ids.shape[1]:]
            num_drafted = draft_candidates.shape[1]
            if num_drafted == 0:
                break  # The draft model produced no new tokens

            # 2. Verification Phase: one target forward pass over context + draft
            verification_input = torch.cat([input_ids, draft_candidates], dim=1)
            target_logits = target_model(verification_input).logits
            target_passes += 1

            # 3. Acceptance Logic: keep the longest draft prefix that matches the target's argmax
            num_matched = 0
            for i in range(num_drafted):
                # Logits at position p predict the token at p + 1, so draft token i
                # (at position input_ids.shape[1] + i) is scored one position earlier
                verification_logit_idx = input_ids.shape[1] + i - 1
                target_pred_id = torch.argmax(target_logits[:, verification_logit_idx, :], dim=-1)
                
                if draft_candidates[0, i] == target_pred_id.item():
                    num_matched += 1
                else:
                    break  # Mismatch found: discard this and all later draft tokens
            
            # Accept all matched tokens
            accepted_tokens = draft_candidates[:, :num_matched]
            input_ids = torch.cat([input_ids, accepted_tokens], dim=1)
            
            # If there was a mismatch, append the target's own token for that position.
            # (When every draft token matches, this loop does not take the extra "bonus"
            # token predicted at the last position of target_logits.)
            if num_matched < num_drafted:
                correction_logit_idx = input_ids.shape[1] - 1
                correction_id = torch.argmax(target_logits[:, correction_logit_idx, :], dim=-1, keepdim=True)
                input_ids = torch.cat([input_ids, correction_id], dim=1)
            
            if tokenizer.eos_token_id in input_ids[0]:
                break

    end_time = time.perf_counter()
    
    # Tokens generated per target pass (accepted draft tokens plus corrections)
    total_generated_tokens = input_ids.shape[1] - prompt_len
    avg_accepted_per_pass = total_generated_tokens / target_passes if target_passes > 0 else 0
    
    return end_time - start_time, target_passes, avg_accepted_per_pass

## 6. Step 3: Run the Experiment with Varying K

Now we can create a simple loop that calls our `run_speculative_decoding` function for each value of `K` we want to test. We'll store the results to analyze later.

In [6]:
results_log = []

print("--- Running Speculative Decoding Experiment ---")
for k in K_VALUES_TO_TEST:
    print(f"Testing with K = {k}...")
    spec_time, spec_passes, avg_accepted = run_speculative_decoding(
        draft_model, target_model, tokenizer, PROMPT_TEXT, MAX_TOTAL_TOKENS, k
    )
    results_log.append({
        "K": k,
        "Time (s)": spec_time,
        "Target Passes": spec_passes,
        "Avg. Accepted Tokens": avg_accepted
    })
    # Clean up GPU memory between runs
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


--- Running Speculative Decoding Experiment ---
Testing with K = 1...
Testing with K = 2...
Testing with K = 3...
Testing with K = 4...
Testing with K = 5...
Testing with K = 8...
Testing with K = 10...


## 7. Step 4: Consolidate and Analyze Results

Finally, let's put all our data into a clean table and analyze the results to find the optimal `K`.

In [7]:
df_results = pd.DataFrame(results_log)

print("--- Speculative Decoding Experiment Results Summary ---")
print(f"Baseline Performance: Time={baseline_time:.2f}s, Target Passes={baseline_passes}")
print(df_results.to_string())

--- Speculative Decoding Experiment Results Summary ---
Baseline Performance: Time=1.17s, Target Passes=60
    K  Time (s)  Target Passes  Avg. Accepted Tokens
0   1  1.861169             60              1.000000
1   2  1.211026             33              1.818182
2   3  1.060563             24              2.541667
3   4  1.047493             20              3.100000
4   5  1.021241             17              3.529412
5   8  1.186374             14              4.714286
6  10  1.204962             12              5.000000


### Guiding Questions for Analysis

1.  **Target Passes**: How did the number of expensive target model passes change as `K` increased? Was it always fewer than the baseline?
2.  **Wall-Clock Time**: How did the total generation time change with `K`? Was there an optimal `K` value that resulted in the fastest time? Why do you think time might start to increase again for very large `K`?
3.  **Average Accepted Tokens**: How did the average number of tokens accepted per verification step change with `K`? What does this metric tell you about the efficiency of the process?
4.  **Trade-offs & Conclusion**: What are the trade-offs of choosing a small `K` versus a large `K`? Based on your results, what would be the best `K` to use for this specific draft/target model pair?

### Sample Analysis

1.  **Target Passes**: As `K` increased, the number of target model passes fell steadily, from 60 at `K=1` (the same as the baseline, because in this implementation each pass can then add only one token) to 12 at `K=10`, a 5x reduction. A larger `K` lets the target verify more draft tokens in a single forward pass, so fewer verification steps are needed to generate the full sequence.

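2.  **Wall-Clock Time**: Time fell as `K` went from 1 to 5, reaching its minimum at `K=5` (1.02 s versus 1.17 s for the baseline, about 1.15x faster), with `K=4` close behind. Only `K=3` to `K=5` beat the baseline; `K=1`, `K=2`, `K=8`, and `K=10` were all slower. Beyond `K=5`, time rose again because each extra draft token costs one more sequential draft-model step, while the chance that it survives verification keeps shrinking: every draft token after the first mismatch is thrown away. Eventually the drafting cost outweighs the savings from fewer target passes. Keep in mind that each configuration was timed once without a warm-up, and that the loop can overshoot the 60-token budget by up to `K-1` tokens (the `K=8` run produced about 14 × 4.71 ≈ 66 tokens), so small differences between neighboring `K` values should not be over-interpreted.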

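3.  **Average Accepted Tokens**: The average number of tokens produced per verification step (accepted draft tokens plus the target's correction, if any) rose steadily with `K`, from 1.0 to 5.0. It measures how much the method reduces target passes: a value above 1.0 means fewer target passes than the baseline, which produces exactly one token per pass. It is not a wall-clock measure, though, because it ignores the cost of running the draft model: `K=2` averaged 1.82 tokens per pass and was still slightly slower than the baseline. Relative to `K`, the metric falls, from about 91% of the maximum possible `K` tokens per pass at `K=2` to 50% at `K=10`, which shows that longer drafts increasingly go to waste.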

4.  **Trade-offs & Conclusion**:
    *   **Small `K` (e.g., 1-2)**: Draft tokens are very likely to be accepted, but too few tokens are gained per pass to pay for the extra draft-model calls. In this run both were slower than the baseline, and `K=1` cannot reduce the number of target passes at all in this implementation.
    *   **Large `K` (e.g., 8-10)**: Target passes drop sharply, but the draft model is more likely to diverge from the target partway through, so much of the sequential drafting work is discarded. Here `K=8` and `K=10` cut target passes to 14 and 12, yet both were slower than the baseline.
    *   **Optimal `K` (e.g., 4-5)**: This is the "sweet spot." It is large enough to accept several tokens per pass but small enough that most drafted tokens are kept and the draft overhead does not dominate.

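For this `gpt2`/`gpt2-medium` pairing on this hardware, `K=5` gave the best wall-clock time (about 1.15x faster than the baseline) and `K=4` was nearly as fast, so a **`K` value of 4 or 5 is the best choice** among the values tested. Re-running each configuration several times is advisable before settling on one of the two.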
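The trade-off above can be made quantitative with a simple cost model. It assumes each draft token is accepted independently with a fixed probability `alpha`, and that one draft step costs a fixed fraction `c` of a target pass; both are simplifying assumptions, and the `alpha` and `c` values below are hypothetical, not measured from this notebook. Under them, a pass yields at least `j` tokens only if the first `j - 1` drafts were accepted, so the expected tokens per pass form a geometric sum, and the speedup is that expectation divided by the cost of one iteration (`k * c + 1`). The model ignores per-call overheads (such as the setup cost of each `generate()` call), so treat it as a qualitative guide to why an interior optimum for `K` exists.

```python
def expected_tokens_per_pass(alpha: float, k: int, bonus: bool = False) -> float:
    """Expected tokens per target pass, assuming each draft token is accepted
    independently with probability alpha (0 <= alpha < 1).

    A pass yields at least j tokens only if the first j - 1 drafts were accepted,
    so the expectation is sum_{j=1}^{n} alpha**(j-1) = (1 - alpha**n) / (1 - alpha).
    bonus=False: n = k, matching the notebook's loop (no extra token after a full match).
    bonus=True:  n = k + 1, the variant that also keeps the target's bonus token.
    """
    n = k + 1 if bonus else k
    return (1 - alpha ** n) / (1 - alpha)


def estimated_speedup(alpha: float, k: int, c: float, bonus: bool = False) -> float:
    """Tokens per iteration divided by the iteration's cost, measured in target
    passes: k draft steps at relative cost c, plus one target pass."""
    return expected_tokens_per_pass(alpha, k, bonus) / (k * c + 1)


def best_k(alpha: float, c: float, k_max: int = 16, bonus: bool = False) -> int:
    """The K in 1..k_max that maximizes the estimated speedup."""
    return max(range(1, k_max + 1), key=lambda k: estimated_speedup(alpha, k, c, bonus))


# Hypothetical values: 80% per-token acceptance, one draft step = 0.3 target passes.
for k in (1, 2, 3, 4, 5, 8, 10):
    print(k, round(estimated_speedup(0.8, k, 0.3), 2))
print("best K:", best_k(0.8, 0.3))
```

With these made-up inputs the estimate peaks at `K=5` and falls off on either side, the same qualitative shape as the measured times; the real optimum depends on the actual acceptance rate and overheads of the model pair.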