# CSE 234 Programming Assignment 3: Speculative Decoding

## Setup

In [1]:
import os
import torch
import time
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Tuple, Dict, Optional

  from .autonotebook import tqdm as notebook_tqdm


## Speculative Decoding

In [2]:
class SpeculativeDecoder:
    def __init__(self, target_model_name: str, draft_model_name: str, device: str = "cuda"):
        """
        Initialize the speculative decoder with target and draft models.

        Args:
            target_model_name: HuggingFace model ID for the larger target model.
            draft_model_name: HuggingFace model ID for the smaller draft model.
            device: Device to run models on ("cuda" or "cpu").
        """
        self.device = device
        self.target_model, self.target_tokenizer = self.initialize_target_model(target_model_name)
        self.draft_model, self.draft_tokenizer = self.initialize_draft_model(draft_model_name)

        # Ensure tokenizers are compatible
        assert self.target_tokenizer.vocab == self.draft_tokenizer.vocab, "Tokenizers must be compatible"

    def initialize_target_model(self, model_name: str):
        """
        Load the larger target model and its tokenizer, ready for inference.

        The tokenizer's pad token falls back to its EOS token when none is
        defined. The model is loaded with KV caching enabled, moved to
        ``self.device`` and put in evaluation mode. Half precision is used
        unless the device is the CPU, where the model is kept in float32:
        float16 CPU kernels are slow and, on older PyTorch releases, missing
        for some operations.

        Args:
            model_name: HuggingFace model ID of the target model.

        Returns:
            Tuple ``(model, tokenizer)``.
        """
        print(f"Loading target model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # A pad token ID is passed to generate() later on; reuse EOS when the
        # checkpoint does not define one.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # torch.device() accepts both strings and torch.device objects.
        on_cpu = torch.device(self.device).type == "cpu"
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32 if on_cpu else torch.float16,
            use_cache=True,  # Enable KV cache for faster inference
        )

        model = model.to(self.device)
        model.eval()  # Inference only: evaluation mode
        return model, tokenizer

    def initialize_draft_model(self, model_name: str):
        """
        Load the smaller draft model and its tokenizer, ready for inference.

        The tokenizer's pad token falls back to its EOS token when none is
        defined. The model is loaded with KV caching enabled, moved to
        ``self.device`` and put in evaluation mode. Half precision is used
        unless the device is the CPU, where the model is kept in float32:
        float16 CPU kernels are slow and, on older PyTorch releases, missing
        for some operations.

        Args:
            model_name: HuggingFace model ID of the draft model.

        Returns:
            Tuple ``(model, tokenizer)``.
        """
        print(f"Loading draft model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # A pad token ID is passed to generate() when drafting; reuse EOS when
        # the checkpoint does not define one.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # torch.device() accepts both strings and torch.device objects.
        on_cpu = torch.device(self.device).type == "cpu"
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32 if on_cpu else torch.float16,
            use_cache=True,
        )

        model = model.to(self.device)
        model.eval()  # Inference only: evaluation mode

        return model, tokenizer

    def generate_draft_tokens(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                             num_speculative_tokens: int = 10) -> torch.Tensor:
        """
        Generate speculative tokens in one forward call using the draft model.

        Args:
            input_ids: Input token IDs (tensor of shape [1, seq_len]).
            attention_mask: Corresponding attention mask.
            num_speculative_tokens: Number of tokens to speculate.

        Returns:
            Tensor of shape [1, num_speculative_tokens] containing the draft tokens.
        """
        # TODO: Implement draft token generation
        # 1. Use the draft model to generate tokens
        # 2. Extract only the new tokens (not including the input)
        # 3. Return the newly generated tokens
        with torch.no_grad():
        # Generate tokens using the draft model
            outputs = self.draft_model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=num_speculative_tokens,
            do_sample=False,  # Use greedy decoding
            pad_token_id=self.draft_tokenizer.pad_token_id,
            use_cache=True
        )
        
        # Extract only the new tokens (not including the input)
            new_tokens = outputs[:, input_ids.shape[1]:]
            return new_tokens

    def verify_tokens_vectorized(self, input_ids: torch.Tensor, draft_tokens: torch.Tensor,
                               attention_mask: torch.Tensor) -> Tuple[List[int], int]:
        """
        Vectorized verification: verify all draft tokens in one forward pass using the target model.

        Args:
            input_ids: The current input token IDs (shape [1, L]).
            draft_tokens: Draft tokens from the draft model (shape [1, k]).
            attention_mask: The current attention mask for input_ids.

        Returns:
            accepted_tokens: List of accepted token IDs.
            accepted_position: Index of the first rejected token (if all accepted, equals draft_tokens.shape[1]).
        """
        # TODO: Implement efficient verification of draft tokens
        # 1. Run target model on input_ids concatenated with draft_tokens
        # 2. Extract the logits for positions where draft tokens would be predicted
        # 3. Compare target model predictions with draft tokens
        # 4. Determine how many consecutive tokens were accepted before first mismatch
        combined_input = torch.cat([input_ids, draft_tokens], dim=1)
    
    # Create attention mask for combined input
        combined_attention_mask = torch.cat([
        attention_mask,
        torch.ones((1, draft_tokens.shape[1]), device=self.device)
    ], dim=1)
    
        with torch.no_grad():
        # Get logits from target model
            outputs = self.target_model(
            combined_input,
            attention_mask=combined_attention_mask,
            use_cache=True
        )
        
        # Get logits for positions where draft tokens would be predicted
            logits = outputs.logits[:, input_ids.shape[1]-1:-1]
        
        # Get predicted tokens
            predicted_tokens = torch.argmax(logits, dim=-1)
        
        # Compare predictions with draft tokens
            matches = (predicted_tokens == draft_tokens).squeeze(0)
        
        # Find first mismatch
            first_mismatch = torch.where(~matches)[0]
            if len(first_mismatch) == 0:
                return draft_tokens[0].tolist(), draft_tokens.shape[1]
        
            accepted_position = first_mismatch[0].item()
            accepted_tokens = draft_tokens[0, :accepted_position].tolist()
            return accepted_tokens, accepted_position

    def speculative_decode(self, prompt: str, max_tokens: int = 100,
                          num_speculative_tokens: int = 15) -> str:
        """
        Main speculative decoding algorithm with vectorized verification.

        Args:
            prompt: Input text.
            max_tokens: Maximum number of tokens to generate (excluding prompt).
            num_speculative_tokens: Number of tokens to speculate per iteration.

        Returns:
            Generated text.
        """
        # Tokenize prompt
        inputs = self.target_tokenizer(prompt, return_tensors="pt", padding=True)
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)
        prompt_length = input_ids.shape[1]

        # Initialize counters for performance tracking
        total_tokens_generated = prompt_length
        total_draft_tokens_proposed = 0
        total_draft_tokens_accepted = 0
        start_time = time.time()

        # TODO: Implement the core speculative decoding loop
        # 1. Generate draft tokens using the draft model
        # 2. Verify draft tokens using the target model
        # 3. Accept verified tokens and append to the sequence
        # 4. For rejected tokens or if all tokens are accepted, generate a new token with the target model
        # 5. Stop when max_tokens is reached or an EOS token is generated
        while total_tokens_generated - prompt_length < max_tokens:
            # Generate draft tokens
            draft_tokens = self.generate_draft_tokens(
                input_ids, attention_mask, num_speculative_tokens
            )
            total_draft_tokens_proposed += draft_tokens.shape[1]

            # Verify draft tokens
            accepted_tokens, accepted_position = self.verify_tokens_vectorized(
                input_ids, draft_tokens, attention_mask
            )
            total_draft_tokens_accepted += accepted_position

            # Update input_ids with accepted tokens
            if accepted_position > 0:
                input_ids = torch.cat([
                    input_ids,
                    draft_tokens[:, :accepted_position]
                ], dim=1)
                attention_mask = torch.cat([
                    attention_mask,
                    torch.ones((1, accepted_position), device=self.device)
                ], dim=1)
                total_tokens_generated += accepted_position

            # If all draft tokens were accepted and we haven't reached max_tokens,
            # continue with the next batch
            if accepted_position == draft_tokens.shape[1]:
                continue

            # If we have rejected tokens, generate a new token with the target model
            with torch.no_grad():
                outputs = self.target_model(
                    input_ids,
                    attention_mask=attention_mask,
                    use_cache=True
                )
                next_token_logits = outputs.logits[:, -1, :]
                next_token = torch.argmax(next_token_logits, dim=-1)
                
                input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
                attention_mask = torch.cat([
                    attention_mask,
                    torch.ones((1, 1), device=self.device)
                ], dim=1)
                total_tokens_generated += 1

            # Check for EOS token
            if next_token.item() == self.target_tokenizer.eos_token_id:
                break
        # Calculate performance metrics
        elapsed_time = time.time() - start_time
        acceptance_rate = total_draft_tokens_accepted / total_draft_tokens_proposed if total_draft_tokens_proposed > 0 else 0

        print(f"Generated {total_tokens_generated - prompt_length} tokens in {elapsed_time:.2f} seconds")
        print(f"Tokens per second: {(total_tokens_generated - prompt_length) / elapsed_time:.2f}")
        print(f"Draft token acceptance rate: {acceptance_rate:.2%}")

        return self.target_tokenizer.decode(input_ids[0], skip_special_tokens=True)

    def benchmark(self, prompt: str, max_tokens: int = 100,
                  num_runs: int = 3, compare_baseline: bool = True) -> Dict:
        """
        Benchmark the speculative decoder against baseline decoding.

        Args:
            prompt: Input text.
            max_tokens: Maximum number of tokens to generate.
            num_runs: Number of benchmark runs.
            compare_baseline: Whether to compare with baseline (non-speculative) decoding.

        Returns:
            Dictionary with benchmark results.
        """
        results = {
            "speculative": {"times": [], "tokens_per_second": []},
            "baseline": {"times": [], "tokens_per_second": []} if compare_baseline else None
        }

        # Benchmark speculative decoding.
        for _ in range(num_runs):
            start_time = time.time()
            output = self.speculative_decode(prompt, max_tokens=max_tokens)
            elapsed = time.time() - start_time
            prompt_len = len(self.target_tokenizer(prompt)["input_ids"])
            output_tokens = len(self.target_tokenizer.encode(output)) - prompt_len
            tps = output_tokens / elapsed
            results["speculative"]["times"].append(elapsed)
            results["speculative"]["tokens_per_second"].append(tps)

        # Benchmark baseline decoding.
        if compare_baseline:
            for _ in range(num_runs):
                inputs = self.target_tokenizer(prompt, return_tensors="pt", padding=True)
                input_ids = inputs["input_ids"].to(self.device)
                attention_mask = inputs["attention_mask"].to(self.device)
                start_time = time.time()
                with torch.no_grad():
                    output_ids = self.target_model.generate(
                        input_ids,
                        attention_mask=attention_mask,
                        max_length=input_ids.shape[1] + max_tokens,
                        do_sample=False,
                        pad_token_id=self.target_tokenizer.pad_token_id
                    )
                elapsed = time.time() - start_time
                output_tokens = output_ids.shape[1] - input_ids.shape[1]
                tps = output_tokens / elapsed
                results["baseline"]["times"].append(elapsed)
                results["baseline"]["tokens_per_second"].append(tps)

        for method in results.keys():
            if results[method] is not None:
                avg_time = sum(results[method]["times"]) / num_runs
                avg_tps = sum(results[method]["tokens_per_second"]) / num_runs
                results[method]["avg_time"] = avg_time
                results[method]["avg_tokens_per_second"] = avg_tps

        if compare_baseline:
            speedup = results["baseline"]["avg_time"] / results["speculative"]["avg_time"]
            results["speedup"] = speedup
            results["latency_reduction"] = (1 - results["speculative"]["avg_time"] / results["baseline"]["avg_time"]) * 100
            # print(f"Speculative decoding speedup: {speedup:.2f}x")
            # print(f"Latency reduction: {results['latency_reduction']:.2f}%")

        return results

## Test

In [5]:
# Model pair under test: two checkpoints published under the same
# EleutherAI/pythia name, one large (target) and one small (draft).
target_model_name = "EleutherAI/pythia-1.4b-deduped"
draft_model_name = "EleutherAI/pythia-160m-deduped"

# Build the decoder on the GPU when one is available, otherwise on the CPU.
decoder = SpeculativeDecoder(
    target_model_name=target_model_name,
    draft_model_name=draft_model_name,
    device="cpu" if not torch.cuda.is_available() else "cuda",
)

# Prompts used for the benchmark.
test_prompts = [
    "The future of Artificial Intelligence is",
    "Write a short story about a robot learning to feel emotions:",
    "Write the lyrics to the song 'Happy Birthday'.",
]

# Benchmark every prompt and report the averaged timings.
for i, prompt in enumerate(test_prompts):
    print(f"\nBenchmarking Prompt {i + 1}:")
    print(f"Prompt: {prompt}")

    results = decoder.benchmark(prompt=prompt, max_tokens=100, num_runs=3, compare_baseline=True)

    print(f"Average speculative decoding time: {results['speculative']['avg_time']:.2f} seconds")
    print(f"Average speculative tokens per second: {results['speculative']['avg_tokens_per_second']:.2f}")

    # The comparison lines only make sense when a baseline was measured.
    if results["baseline"] is None:
        continue

    print(f"Average baseline decoding time: {results['baseline']['avg_time']:.2f} seconds")
    print(f"Average baseline tokens per second: {results['baseline']['avg_tokens_per_second']:.2f}")
    print(f"Speedup: {results['speedup']:.2f}x")
    print(f"Latency reduction: {results['latency_reduction']:.2f}%")

Loading target model: EleutherAI/pythia-1.4b-deduped
Loading draft model: EleutherAI/pythia-160m-deduped

Benchmarking Prompt 1:
Prompt: The future of Artificial Intelligence is
Generated 106 tokens in 1.40 seconds
Tokens per second: 75.70
Draft token acceptance rate: 87.50%
Generated 106 tokens in 1.06 seconds
Tokens per second: 99.74
Draft token acceptance rate: 87.50%
Generated 106 tokens in 1.06 seconds
Tokens per second: 99.84
Draft token acceptance rate: 87.50%
Average speculative decoding time: 1.18 seconds
Average speculative tokens per second: 91.70
Average baseline decoding time: 1.43 seconds
Average baseline tokens per second: 69.73
Speedup: 1.22x
Latency reduction: 18.02%

Benchmarking Prompt 2:
Prompt: Write a short story about a robot learning to feel emotions:
Generated 114 tokens in 1.08 seconds
Tokens per second: 105.93
Draft token acceptance rate: 94.17%
Generated 114 tokens in 1.07 seconds
Tokens per second: 107.01
Draft token acceptance rate: 94.17%
Generated 114 to

## Bonus

In [15]:
# Model pair under test: two checkpoints published under the same
# facebook/opt name, one large (target) and one small (draft).
# The original comment also names "tiiuae/falcon-rw-460m" next to the draft model.
target_model_name = "facebook/opt-350m"
draft_model_name = "facebook/opt-125m"

# Build the decoder on the GPU when one is available, otherwise on the CPU.
decoder = SpeculativeDecoder(
    target_model_name=target_model_name,
    draft_model_name=draft_model_name,
    device="cpu" if not torch.cuda.is_available() else "cuda",
)

# Prompts used for the benchmark.
test_prompts = [
    "The future of Artificial Intelligence is",
    "Write a short story about a robot learning to feel emotions:",
    "Write the lyrics to the song 'Happy Birthday'.",
]

# Benchmark every prompt and report the averaged timings.
for i, prompt in enumerate(test_prompts):
    print(f"\nBenchmarking Prompt {i + 1}:")
    print(f"Prompt: {prompt}")

    results = decoder.benchmark(prompt=prompt, max_tokens=100, num_runs=3, compare_baseline=True)

    print(f"Average speculative decoding time: {results['speculative']['avg_time']:.2f} seconds")
    print(f"Average speculative tokens per second: {results['speculative']['avg_tokens_per_second']:.2f}")

    # The comparison lines only make sense when a baseline was measured.
    if results["baseline"] is None:
        continue

    print(f"Average baseline decoding time: {results['baseline']['avg_time']:.2f} seconds")
    print(f"Average baseline tokens per second: {results['baseline']['avg_tokens_per_second']:.2f}")
    print(f"Speedup: {results['speedup']:.2f}x")
    print(f"Latency reduction: {results['latency_reduction']:.2f}%")

Loading target model: facebook/opt-350m
Loading draft model: facebook/opt-125m

Benchmarking Prompt 1:
Prompt: The future of Artificial Intelligence is
Generated 106 tokens in 1.35 seconds
Tokens per second: 78.44
Draft token acceptance rate: 87.50%
Generated 106 tokens in 0.81 seconds
Tokens per second: 131.12
Draft token acceptance rate: 87.50%
Generated 106 tokens in 0.78 seconds
Tokens per second: 136.21
Draft token acceptance rate: 87.50%
Average speculative decoding time: 0.98 seconds
Average speculative tokens per second: 115.15
Average baseline decoding time: 1.04 seconds
Average baseline tokens per second: 95.97
Speedup: 1.06x
Latency reduction: 5.93%

Benchmarking Prompt 2:
Prompt: Write a short story about a robot learning to feel emotions:
Generated 106 tokens in 0.78 seconds
Tokens per second: 135.76
Draft token acceptance rate: 87.50%
Generated 106 tokens in 0.78 seconds
Tokens per second: 135.49
Draft token acceptance rate: 87.50%
Generated 106 tokens in 0.78 seconds
Tok

Analysis:
1. Optimizations Implemented
1.1 Model Initialization Optimizations
Used float16 for both models to reduce memory usage and improve computation speed
Enabled KV caching (use_cache=True) to reuse previously computed key-value pairs
Set models to evaluation mode to disable dropout and other training features
1.2 Vectorized Verification
Implemented an efficient verification method that evaluates all draft tokens in a single forward pass
Used tensor concatenation to combine input and draft tokens, avoiding multiple forward passes
Optimized the comparison of target model predictions with draft tokens using vectorized operations
1.3 Model Pair Selection
Experimented with multiple model pairs from the same family to ensure tokenizer compatibility
Found that smaller size differences between target and draft models improve token acceptance rates
Demonstrated that models from the same architecture family perform significantly better
2. Performance Results and Analysis
2.1 Pythia Model Pair
Using EleutherAI/pythia-1.4b-deduped (target) + EleutherAI/pythia-160m-deduped (draft):
Average speedup: 1.22x - 1.38x across different prompts
Token acceptance rates: 87.50% - 94.17%
Consistent performance across all tested prompts
2.2 OPT Model Pair
Using facebook/opt-350m (target) + facebook/opt-125m (draft):
Average speedup: 1.06x - 1.52x
Token acceptance rates: 87.50% - 96.19%
Best performance on narrative/creative tasks (highest speedup on the "Happy Birthday" prompt)
2.3 Parameter Impact Analysis
Number of speculative tokens: Using 15 tokens provided a good balance between over-generation and efficiency
Model size ratio: ~8.75x ratio (Pythia 1.4B vs 160M) yielded excellent acceptance rates
Domain alignment: Models trained on similar data demonstrated higher acceptance rates
3. Challenges and Solutions
3.1 Tokenizer Compatibility
Challenge: Ensuring tokenizers between target and draft models are compatible
Solution: Used models from the same family and added explicit compatibility check
3.2 Low Acceptance Rates
Challenge: Some model pairs (e.g., CodeGen) had poor alignment, causing slowdowns
Solution: Prioritized model pairs with demonstrated compatibility based on benchmarks
3.3 Verification Logic
Challenge: Efficiently determining which tokens to accept from draft sequence
Solution: Implemented vectorized verification that correctly identifies the first token mismatch
4. Conclusion
My speculative decoding implementation achieved significant speedups (up to 1.52x) with high token acceptance rates (up to 96.19%) using well-matched model pairs. The OPT model pair demonstrated the best overall performance, with the Pythia pair showing excellent consistency across different tasks.
The key factors for successful speculative decoding are:
Model pairs from the same architecture family
Appropriate size differentials between target and draft models
Efficient verification logic that minimizes computational overhead
Optimized model configurations (precision, caching)
These results demonstrate that speculative decoding offers a practical approach to accelerating text generation without sacrificing output quality.
