# CSE 234 Programming Assignment 3: Speculative Decoding

## Setup

In [1]:

import os
# Hugging Face environment settings. They are written before `transformers`
# is imported in the next cell.
# NOTE(review): presumably HF_ENDPOINT redirects Hub downloads to the mirror and
# TOKENIZERS_PARALLELISM="false" turns off tokenizer-side parallelism — confirm
# against the Hugging Face docs for the installed versions.
os.environ.update({
    "HF_ENDPOINT": "https://hf-mirror.com",
    "TOKENIZERS_PARALLELISM": "false",
})

In [2]:
import os
import torch
import time
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Tuple, Dict, Optional

In [3]:
# Select the 'high' float32 matmul precision mode globally for this session.
# NOTE(review): per the PyTorch docs this is expected to permit faster,
# reduced-precision float32 matmul kernels on supporting GPUs — confirm the
# exact trade-off for the hardware in use.
torch.set_float32_matmul_precision('high')

## Speculative Decoding

In [4]:

def print_tokens(label: str, tokens: torch.Tensor, tokenizer: AutoTokenizer):
    """
    Print a token sequence as decoded text wrapped in `<label>...</label>`.

    Args:
        label: Tag name used for the opening and closing markers.
        tokens: Sized sequence of token IDs (callers in this file pass both
            tensors and plain lists).
        tokenizer: Object whose `decode` method turns the IDs into text.

    An empty sequence prints the literal `EMPTY`; otherwise newlines in the
    decoded text are shown as `<br>` so each call produces a single line.
    """
    if len(tokens) > 0:
        decoded = tokenizer.decode(tokens, skip_special_tokens=False)
        result = decoded.replace("\n", "<br>")
        print(f"<{label}>{result}</{label}>")
    else:
        print(f"<{label}>EMPTY</{label}>")

class SpeculativeDecoder:
    def __init__(self, target_model_name: str, draft_model_name: str, device: str = "cuda"):
        """
        Initialize the speculative decoder with target and draft models.

        Args:
            target_model_name: HuggingFace model ID for the larger target model.
            draft_model_name: HuggingFace model ID for the smaller draft model.
            device: Device to run models on ("cuda" or "cpu").
        """
        self.device = device
        self.target_model, self.target_tokenizer = self.initialize_target_model(target_model_name)
        self.draft_model, self.draft_tokenizer = self.initialize_draft_model(draft_model_name)

        print("Running on device:", self.device)

        # Ensure tokenizers are compatible
        assert self.target_tokenizer.vocab == self.draft_tokenizer.vocab, "Tokenizers must be compatible"

    def initialize_target_model(self, model_name: str):
        """
        Load the tokenizer and weights of the larger target model.

        A tokenizer without a pad token falls back to its EOS token. The model
        is loaded with the library's default settings, moved to `self.device`,
        put in eval mode and, when the device string is exactly "cuda",
        wrapped with `torch.compile`.

        Args:
            model_name: HuggingFace model ID of the target model.

        Returns:
            A `(model, tokenizer)` pair.
        """
        print(f"Loading target model: {model_name}")
        target_tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Reuse EOS for padding when the tokenizer ships without a pad token.
        if target_tokenizer.pad_token is None:
            print("No pad token found, setting to eos token")
            target_tokenizer.pad_token = target_tokenizer.eos_token
            target_tokenizer.pad_token_id = target_tokenizer.eos_token_id
        print("Pad token:", target_tokenizer.pad_token)
        print("Pad token ID:", target_tokenizer.pad_token_id)

        target_model = AutoModelForCausalLM.from_pretrained(model_name)
        target_model = target_model.to(self.device)
        target_model.eval()  # inference only

        if self.device != "cuda":
            return target_model, target_tokenizer

        print("Enabling PyTorch 2.0 optimizations")
        return torch.compile(target_model), target_tokenizer

    def initialize_draft_model(self, model_name: str):
        """
        Load the tokenizer and weights of the smaller draft model.

        A tokenizer without a pad token falls back to its EOS token. The
        weights are requested in bfloat16 when the device string is exactly
        "cuda" and in float32 otherwise; the model is then moved to
        `self.device`, put in eval mode and, on "cuda", wrapped with
        `torch.compile`.

        Args:
            model_name: HuggingFace model ID of the draft model.

        Returns:
            A `(model, tokenizer)` pair.
        """
        print(f"Loading draft model: {model_name}")
        draft_tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Reuse EOS for padding when the tokenizer ships without a pad token.
        if draft_tokenizer.pad_token is None:
            print("No pad token found, setting to eos token")
            draft_tokenizer.pad_token = draft_tokenizer.eos_token
            draft_tokenizer.pad_token_id = draft_tokenizer.eos_token_id
        print("Pad token:", draft_tokenizer.pad_token)
        print("Pad token ID:", draft_tokenizer.pad_token_id)

        on_cuda = self.device == "cuda"
        weight_dtype = torch.bfloat16 if on_cuda else torch.float32

        draft_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=weight_dtype,
        )
        draft_model = draft_model.to(self.device)
        draft_model.eval()  # inference only

        if not on_cuda:
            return draft_model, draft_tokenizer

        print("Enabling PyTorch 2.0 optimizations")
        return torch.compile(draft_model), draft_tokenizer

    def generate_draft_tokens(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                             num_speculative_tokens: int = 10) -> torch.Tensor:
        """
        Generate speculative tokens in one forward call using the draft model.

        Args:
            input_ids: Input token IDs (tensor of shape [1, seq_len]).
            attention_mask: Corresponding attention mask.
            num_speculative_tokens: Number of tokens to speculate.

        Returns:
            Tensor of shape [1, num_speculative_tokens] containing the draft tokens.
        """
        # Implement draft token generation
        # 1. Use the draft model to generate tokens
        with torch.no_grad():
            output_ids = self.draft_model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=num_speculative_tokens,
                do_sample=False,
                pad_token_id=self.draft_tokenizer.pad_token_id
            )
        # 2. Extract only the new tokens (not including the input)
        draft_tokens = output_ids[:, input_ids.shape[1]:]

        # 3. Return the newly generated tokens
        return draft_tokens

    def verify_tokens_vectorized(self, input_ids: torch.Tensor, draft_tokens: torch.Tensor,
                               attention_mask: torch.Tensor, debug: bool = False) -> Tuple[List[int], int]:
        """
        Vectorized verification: verify all draft tokens in one forward pass using the target model.

        Args:
            input_ids: The current input token IDs (shape [1, L]).
            draft_tokens: Draft tokens from the draft model (shape [1, k]).
            attention_mask: The current attention mask for input_ids.

        Returns:
            accepted_tokens: List of accepted token IDs.
            accepted_position: Index of the first rejected token (if all accepted, equals draft_tokens.shape[1]).
            new_token_by_target: The new token predicted by the target model after all draft tokens are verified.
        """
        # Implement efficient verification of draft tokens
        # 1. Run target model on input_ids concatenated with draft_tokens
        combined_ids = torch.cat([input_ids, draft_tokens], dim=1)
        combined_attention_mask = torch.cat([ attention_mask, torch.ones_like(draft_tokens) ], dim=1)
        with torch.no_grad():
            outputs = self.target_model(combined_ids, attention_mask=combined_attention_mask)

        # 2. Extract the logits for positions where draft tokens would be predicted
        logits = outputs.logits
        assert logits.size(1) == combined_ids.size(1)
        prediction_logits = logits[0, input_ids.shape[1]-1:-1]
        predicted_tokens = torch.argmax(prediction_logits, dim=-1)

        if debug:
            print_tokens("predicted_tokens", predicted_tokens, self.target_tokenizer)

        # 3. Compare target model predictions with draft tokens
        matches = (predicted_tokens == draft_tokens[0])

        # 4. Determine how many consecutive tokens were accepted before first mismatch
        first_mismatch = torch.nonzero(~matches, as_tuple=False)
        if first_mismatch.numel() > 0:
            # The item() method returns the value of the tensor as a Python number,
            # and it only works for tensors with one element.
            accepted_position = first_mismatch[0].item()
        else:
            # All tokens were accepted
            accepted_position = draft_tokens.shape[1]
        
        # Get accepted tokens
        accepted_tokens = draft_tokens[0, :accepted_position].tolist()

        # Get the new token predicted by the target model after all draft tokens are verified
        position_of_output = input_ids.shape[1] - 1 + accepted_position
        new_token_by_target = torch.argmax(outputs.logits[0, position_of_output], dim=-1).item()
        
        return accepted_tokens, accepted_position, new_token_by_target


    def speculative_decode(self, prompt: str, max_tokens: int = 100,
                          num_speculative_tokens: int = 15, debug: bool = False) -> str:
        """
        Main speculative decoding algorithm with vectorized verification.

        Args:
            prompt: Input text.
            max_tokens: Maximum number of tokens to generate (excluding prompt).
            num_speculative_tokens: Number of tokens to speculate per iteration.

        Returns:
            Generated text.
        """
        # Tokenize prompt
        inputs = self.target_tokenizer(prompt, return_tensors="pt", padding=True)
        input_ids = inputs["input_ids"].to(self.device)

        # A tensor indicating which tokens are actual text (1) and which are padding (0)
        attention_mask = inputs["attention_mask"].to(self.device)
        prompt_length = input_ids.shape[1]

        # Initialize counters for performance tracking
        total_tokens_generated = prompt_length
        total_draft_tokens_proposed = 0
        total_draft_tokens_accepted = 0
        start_time = time.time()

        # Implement the core speculative decoding loop
        # Main speculative decoding loop
        while total_tokens_generated - prompt_length < max_tokens:

            if debug:
                print('--------------------------------');

            # 1. Generate draft tokens using the draft model
            draft_tokens = self.generate_draft_tokens(
                input_ids,
                attention_mask,
                num_speculative_tokens=num_speculative_tokens
            )
            total_draft_tokens_proposed += draft_tokens.shape[1]

            if debug:
                print_tokens("draft_tokens", draft_tokens[0], self.target_tokenizer)

            # 2. Verify draft tokens using the target model
            accepted_tokens, accepted_position, new_token_by_target = self.verify_tokens_vectorized(
                input_ids,
                draft_tokens,
                attention_mask
            )
            total_draft_tokens_accepted += len(accepted_tokens)

            if debug:
                print(f"len(accepted_tokens): {len(accepted_tokens)}")
                print_tokens("accepted", accepted_tokens, self.target_tokenizer)

            # 3. Accept verified tokens and append to the sequence
            if accepted_tokens:
                new_tokens = torch.tensor([accepted_tokens], device=self.device)
                input_ids = torch.cat([input_ids, new_tokens], dim=1)
                attention_mask = torch.cat([
                    attention_mask,
                    torch.ones((1, len(accepted_tokens)), device=self.device)
                ], dim=1)
                total_tokens_generated += len(accepted_tokens)

            # 4. For rejected tokens or if all tokens are accepted, generate a new token with the target model

            input_ids = torch.cat([input_ids, torch.tensor([[new_token_by_target]], device=self.device)], dim=1)
            attention_mask = torch.cat([attention_mask, torch.ones((1, 1), device=self.device)], dim=1)
            total_tokens_generated += 1

            if debug:
                print_tokens("new_token", [new_token_by_target], self.target_tokenizer)

            # 5. Stop when max_tokens is reached or an EOS token is generated
            # Check if any token in the sequence is an EOS token
            if self.target_tokenizer.eos_token_id in input_ids[0]:
                break

        # Calculate performance metrics
        elapsed_time = time.time() - start_time
        acceptance_rate = total_draft_tokens_accepted / total_draft_tokens_proposed if total_draft_tokens_proposed > 0 else 0

        print(f"Generated {total_tokens_generated - prompt_length} tokens in {elapsed_time:.2f} seconds")
        print(f"Tokens per second: {(total_tokens_generated - prompt_length) / elapsed_time:.2f}")
        print(f"Draft token acceptance rate: {acceptance_rate:.2%}")

        return self.target_tokenizer.decode(input_ids[0], skip_special_tokens=True)

    def benchmark(self, prompt: str, max_tokens: int = 100,
                  num_runs: int = 3, compare_baseline: bool = True) -> Dict:
        """
        Benchmark the speculative decoder against baseline decoding.

        Args:
            prompt: Input text.
            max_tokens: Maximum number of tokens to generate.
            num_runs: Number of benchmark runs.
            compare_baseline: Whether to compare with baseline (non-speculative) decoding.

        Returns:
            Dictionary with benchmark results.
        """
        results = {
            "speculative": {"times": [], "tokens_per_second": []},
            "baseline": {"times": [], "tokens_per_second": []} if compare_baseline else None
        }

        # Warm up the model
        _ = self.speculative_decode(prompt, max_tokens=max_tokens)

        # Benchmark speculative decoding.
        for _ in range(num_runs):
            start_time = time.time()
            output = self.speculative_decode(prompt, max_tokens=max_tokens)
            elapsed = time.time() - start_time
            prompt_len = len(self.target_tokenizer(prompt)["input_ids"])
            output_tokens = len(self.target_tokenizer.encode(output)) - prompt_len
            tps = output_tokens / elapsed
            results["speculative"]["times"].append(elapsed)
            results["speculative"]["tokens_per_second"].append(tps)

        # Benchmark baseline decoding.
        if compare_baseline:
            for _ in range(num_runs):
                inputs = self.target_tokenizer(prompt, return_tensors="pt", padding=True)
                input_ids = inputs["input_ids"].to(self.device)
                attention_mask = inputs["attention_mask"].to(self.device)
                start_time = time.time()
                with torch.no_grad():
                    output_ids = self.target_model.generate(
                        input_ids,
                        attention_mask=attention_mask,
                        max_length=input_ids.shape[1] + max_tokens,
                        do_sample=False,
                        pad_token_id=self.target_tokenizer.pad_token_id
                    )
                elapsed = time.time() - start_time
                output_tokens = output_ids.shape[1] - input_ids.shape[1]
                tps = output_tokens / elapsed
                results["baseline"]["times"].append(elapsed)
                results["baseline"]["tokens_per_second"].append(tps)

        for method in results.keys():
            if results[method] is not None:
                avg_time = sum(results[method]["times"]) / num_runs
                avg_tps = sum(results[method]["tokens_per_second"]) / num_runs
                results[method]["avg_time"] = avg_time
                results[method]["avg_tokens_per_second"] = avg_tps

        if compare_baseline:
            speedup = results["baseline"]["avg_time"] / results["speculative"]["avg_time"]
            results["speedup"] = speedup
            results["latency_reduction"] = (1 - results["speculative"]["avg_time"] / results["baseline"]["avg_time"]) * 100
            # print(f"Speculative decoding speedup: {speedup:.2f}x")
            # print(f"Latency reduction: {results['latency_reduction']:.2f}%")

        return results

In [5]:

import gc

# Model pair for the demo: the target model verifies what the draft model proposes.
target_model_name = "EleutherAI/pythia-1.4b-deduped"
draft_model_name = "EleutherAI/pythia-160m-deduped"

# Build the decoder on the GPU when one is available, otherwise on the CPU.
decoder = SpeculativeDecoder(
    target_model_name,
    draft_model_name,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# Run a garbage-collection pass before decoding.
gc.collect()

print("> Speculative decoding...")
# Left as the cell's final expression so the notebook displays the returned text.
decoder.speculative_decode(
    prompt="What is the capital of China?",
    max_tokens=100,
    num_speculative_tokens=15,
    debug=True,
)


Loading target model: EleutherAI/pythia-1.4b-deduped
No pad token found, setting to eos token
Pad token: <|endoftext|>
Pad token ID: 0
Enabling PyTorch 2.0 optimizations
Loading draft model: EleutherAI/pythia-160m-deduped
No pad token found, setting to eos token
Pad token: <|endoftext|>
Pad token ID: 0
Enabling PyTorch 2.0 optimizations
Running on device: cuda
> Speculative decoding...
--------------------------------
<draft_tokens><br><br>The first of the "1"s is a (1)</draft_tokens>
len(accepted_tokens): 3
<accepted><br><br>The</accepted>
<new_token> capital</new_token>
--------------------------------
<draft_tokens> of China is the first, the most significant and the most significant of the</draft_tokens>
len(accepted_tokens): 3
<accepted> of China is</accepted>
<new_token> Beijing</new_token>
--------------------------------
<draft_tokens>,<br><br>The city,<br><br>The capital,<br><br>The city</draft_tokens>
len(accepted_tokens): 0
<accepted>EMPTY</accepted>
<new_token>.</new_token>

'What is the capital of China?\n\nThe capital of China is Beijing.\n\nWhat is the capital of the United States?\n\nWashington, D.C.\n\nWhat is the capital of France?\n\nParis\n\nWhat is the capital of Germany?\n\nBerlin\n\nWhat is the capital of Italy?\n\nRome\n\nWhat is the capital of Japan?\n\nTokyo\n\nWhat is the capital of Russia?\n\nMoscow\n\nWhat is the capital of Spain'

In [6]:

# Sanity check: let each model continue the same prompt on its own, with the
# library's default generation settings, and print the raw decoded result.
for label, model, tokenizer in (
    ("target", decoder.target_model, decoder.target_tokenizer),
    ("draft", decoder.draft_model, decoder.draft_tokenizer),
):
    inputs = tokenizer("What is the capital of China?", return_tensors="pt", padding=True).to(decoder.device)
    # Pass the pad token explicitly, as the decoder's own `generate` calls do,
    # so `generate` does not have to fall back to the EOS id and warn about it.
    tokens = model.generate(**inputs, pad_token_id=tokenizer.pad_token_id)
    print(f"<{label}>{tokenizer.decode(tokens[0], skip_special_tokens=False)}</{label}>")

Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


<target>What is the capital of China?

The capital of China is Beijing.

What is the capital of the United States?</target>
<draft>What is the capital of China?

The first of the "1"s is a (1), and a (1</draft>


## Test

In [7]:
# Model pair under test: the target model is benchmarked with and without the
# draft model's help.
target_model_name = "EleutherAI/pythia-1.4b-deduped"
draft_model_name = "EleutherAI/pythia-160m-deduped"

# Build the decoder on the GPU when one is available, otherwise on the CPU.
decoder = SpeculativeDecoder(
    target_model_name,
    draft_model_name,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# Prompts used for the comparison.
test_prompts = [
    "The future of Artificial Intelligence is",
    "Write a short story about a robot learning to feel emotions:",
    "Write the lyrics to the song 'Happy Birthday'.",
]

# Benchmark every prompt and report the averaged numbers.
for i, prompt in enumerate(test_prompts):
    print(f"\nBenchmarking Prompt {i+1}:")
    print(f"Prompt: {prompt}")

    results = decoder.benchmark(prompt, max_tokens=100, num_runs=3, compare_baseline=True)

    print(f"Average speculative decoding time: {results['speculative']['avg_time']:.2f} seconds")
    print(f"Average speculative tokens per second: {results['speculative']['avg_tokens_per_second']:.2f}")

    # Baseline figures exist only when the baseline comparison was run.
    if results["baseline"] is None:
        continue

    print(f"Average baseline decoding time: {results['baseline']['avg_time']:.2f} seconds")
    print(f"Average baseline tokens per second: {results['baseline']['avg_tokens_per_second']:.2f}")
    print(f"Speedup: {results['speedup']:.2f}x")
    print(f"Latency reduction: {results['latency_reduction']:.2f}%")

Loading target model: EleutherAI/pythia-1.4b-deduped
No pad token found, setting to eos token
Pad token: <|endoftext|>
Pad token ID: 0
Enabling PyTorch 2.0 optimizations
Loading draft model: EleutherAI/pythia-160m-deduped
No pad token found, setting to eos token
Pad token: <|endoftext|>
Pad token ID: 0
Enabling PyTorch 2.0 optimizations
Running on device: cuda

Benchmarking Prompt 1:
Prompt: The future of Artificial Intelligence is
Generated 103 tokens in 1.06 seconds
Tokens per second: 96.82
Draft token acceptance rate: 62.00%
Generated 103 tokens in 1.06 seconds
Tokens per second: 97.46
Draft token acceptance rate: 62.00%
Generated 103 tokens in 1.06 seconds
Tokens per second: 97.06
Draft token acceptance rate: 62.00%
Generated 103 tokens in 1.06 seconds
Tokens per second: 97.10
Draft token acceptance rate: 62.00%
Average speculative decoding time: 1.06 seconds
Average speculative tokens per second: 96.22
Average baseline decoding time: 1.12 seconds
Average baseline tokens per second

## Bonus

In [8]:
# Bonus: run the same benchmark with a different target/draft model pair.
# TODO: replace both `...` placeholders with HuggingFace model IDs before
# running this cell. As written, the Ellipsis object itself is passed to
# SpeculativeDecoder and model loading fails (see the OSError recorded in
# this cell's output).
# NOTE(review): the decoder checks that both tokenizers have the same
# vocabulary, so the chosen pair presumably has to share a tokenizer.
target_model_name = ...  # Larger target model
draft_model_name = ...   # Smaller draft model


# Initialize speculative decoder
decoder = SpeculativeDecoder(
    target_model_name=target_model_name,
    draft_model_name=draft_model_name,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

# Test prompts
test_prompts = [
    "The future of Artificial Intelligence is",
    "Write a short story about a robot learning to feel emotions:",
    "Write the lyrics to the song 'Happy Birthday'."
]

# Run benchmark on test prompts
for i, prompt in enumerate(test_prompts):
    print(f"\nBenchmarking Prompt {i+1}:")
    print(f"Prompt: {prompt}")

    # Three timed runs per prompt, compared against plain target-model decoding.
    results = decoder.benchmark(
        prompt=prompt,
        max_tokens=100,
        num_runs=3,
        compare_baseline=True
    )

    print(f"Average speculative decoding time: {results['speculative']['avg_time']:.2f} seconds")
    print(f"Average speculative tokens per second: {results['speculative']['avg_tokens_per_second']:.2f}")

    # The baseline entry is None when the comparison was not requested.
    if results["baseline"] is not None:
        print(f"Average baseline decoding time: {results['baseline']['avg_time']:.2f} seconds")
        print(f"Average baseline tokens per second: {results['baseline']['avg_tokens_per_second']:.2f}")
        print(f"Speedup: {results['speedup']:.2f}x")
        print(f"Latency reduction: {results['latency_reduction']:.2f}%")

Loading target model: Ellipsis


OSError: Ellipsis is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `huggingface-cli login` or by passing `token=<your_token>`