# Exploration Notebook
This notebook contains our initial exploration and implementation of speculative decoding. We used it to compare how different draft/verifier model pairs perform on code and non-code prompts, and as the starting point for later experiments.

In [None]:
!pip install -q transformers accelerate torch

In [None]:
import torch
import time
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# 1. Pick the verifier/draft checkpoint pair. Other pairs tried in this
#    notebook (swap them in to compare):
#      verifier "gpt2-large"      + draft "distilgpt2"
#      verifier "Qwen/Qwen2.5-7B" + draft "Qwen/Qwen2.5-0.5B"
checkpoint_verifier, checkpoint_draft = "EleutherAI/pythia-12b", "EleutherAI/pythia-70m"

# Only the verifier's tokenizer is loaded; its token ids are fed to BOTH
# models, so the pair is assumed to share a vocabulary -- confirm this when
# adding a new pair.
tokenizer = AutoTokenizer.from_pretrained(checkpoint_verifier)

# 2./3. Load both models; device_map="auto" lets transformers/accelerate
# decide where the weights are placed.
print(f"Loading verifier model: {checkpoint_verifier}...")
verifier_model = AutoModelForCausalLM.from_pretrained(checkpoint_verifier, device_map="auto")

print(f"Loading draft model: {checkpoint_draft}...")
draft_model = AutoModelForCausalLM.from_pretrained(checkpoint_draft, device_map="auto")

# Inference only: switch both models out of training mode.
for _model in (verifier_model, draft_model):
    _model.eval()

print("Models loaded successfully!")

In [None]:
def standard_autoregressive_generation(model, input_ids, max_new_tokens, pad_token_id=None):
    """
    Generate text with plain greedy autoregressive decoding and time it.

    This is the baseline that speculative decoding is compared against.

    Args:
        model: a model exposing a Hugging Face-style ``generate`` method.
        input_ids: prompt token ids. Assumed to be an unpadded batch (every
            position is attended to) -- the notebook passes a single prompt.
        max_new_tokens: maximum number of tokens to generate; ``generate`` may
            stop earlier if the model emits its end-of-sequence token.
        pad_token_id: pad id forwarded to ``generate``. Defaults to the
            notebook-level ``tokenizer.eos_token_id``.

    Returns:
        (output_ids, latency_seconds): the ``generate`` output (prompt
        included) and the wall-clock time spent generating.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # CUDA kernels run asynchronously: drain previously queued work so it is
    # not billed to this measurement, and wait again before stopping the clock.
    if input_ids.is_cuda:
        torch.cuda.synchronize()
    start_time = time.perf_counter()  # monotonic, high-resolution benchmark clock

    output = model.generate(
        input_ids,
        # Single unpadded prompt: attend to every position. Passing the mask
        # explicitly avoids transformers having to guess one when
        # pad_token_id == eos_token_id.
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=max_new_tokens,
        do_sample=False,  # greedy decoding
        pad_token_id=pad_token_id,
    )

    if input_ids.is_cuda:
        torch.cuda.synchronize()
    latency = time.perf_counter() - start_time
    return output, latency

In [None]:
@torch.no_grad()
def speculative_decoding(verifier, draft, input_ids, max_new_tokens, gamma=4, pad_token_id=None):
    """
    Greedy speculative decoding with a small draft model and a large verifier.

    Each round, the draft model greedily proposes up to ``gamma`` tokens.
    The verifier then scores the current sequence plus all proposals in a
    single forward pass. The longest prefix of proposals that agrees with the
    verifier's own greedy choices is accepted. After that, the verifier's next
    greedy token is appended: a correction on a mismatch, or a free "bonus"
    token when every proposal matched. Every round therefore adds at least one
    token. The result is intended to equal greedy decoding with the verifier
    alone, up to floating-point differences between a full forward pass and
    cached decoding.

    NOTE: neither model reuses a KV cache across rounds (each round re-encodes
    the whole sequence), which limits the achievable speedup. Unlike
    ``generate``, this loop does not stop at an end-of-sequence token.

    Args:
        verifier: causal LM called as ``verifier(ids)``; its result must have
            ``.logits``, assumed shaped (batch, seq_len, vocab).
        draft: model exposing a Hugging Face-style ``generate`` method.
        input_ids: prompt token ids; only row 0 is decoded (batch size 1).
        max_new_tokens: number of tokens to generate after the prompt.
        gamma (k): maximum number of tokens the draft model guesses per round.
        pad_token_id: pad id forwarded to ``draft.generate``. Defaults to the
            notebook-level ``tokenizer.eos_token_id``.

    Returns:
        (output_ids, latency_seconds, acceptance_rate). ``output_ids``
        includes the prompt and is trimmed to at most
        ``input_ids.shape[1] + max_new_tokens`` columns. ``acceptance_rate``
        is accepted draft tokens / drafted tokens (0 if nothing was drafted).
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    if input_ids.is_cuda:
        torch.cuda.synchronize()
    start_time = time.perf_counter()

    prompt_len = input_ids.shape[1]
    target_len = prompt_len + max_new_tokens

    total_draft_tokens = 0
    accepted_draft_tokens = 0

    curr_input_ids = input_ids.clone()

    while curr_input_ids.shape[1] < target_len:
        cur_len = curr_input_ids.shape[1]
        # Never draft more tokens than are still needed.
        k = min(gamma, target_len - cur_len)

        # Step 1: Draft model greedily proposes up to k tokens
        # (fewer if it emits its end-of-sequence token).
        draft_outputs = draft.generate(
            curr_input_ids,
            attention_mask=torch.ones_like(curr_input_ids),
            max_new_tokens=k,
            do_sample=False,
            pad_token_id=pad_token_id,
        )
        # generate() returns prompt + new tokens; keep only the new ones, on
        # the same device as the running sequence.
        draft_tokens = draft_outputs[0, cur_len:].to(curr_input_ids.device)
        num_drafted = draft_tokens.shape[0]

        # Step 2: One verifier pass over the sequence plus all drafts.
        verifier_input = torch.cat([curr_input_ids, draft_tokens.unsqueeze(0)], dim=1)
        logits = verifier(verifier_input).logits
        # Logits at position p predict token p + 1. Positions
        # cur_len-1 .. cur_len+num_drafted-1 therefore give the verifier's
        # choice for each drafted slot, plus one extra token after the last draft.
        predicted_tokens = torch.argmax(
            logits[0, cur_len - 1 : cur_len + num_drafted], dim=-1
        ).to(curr_input_ids.device)

        # Step 3: Length of the longest matching prefix. A running product of
        # the match flags stays 1 until the first mismatch; doing this on-device
        # avoids a host sync per token.
        match_flags = (draft_tokens == predicted_tokens[:num_drafted]).long()
        n_matches = int(match_flags.cumprod(dim=0).sum().item())

        total_draft_tokens += num_drafted
        accepted_draft_tokens += n_matches

        # Step 4: Append the accepted drafts plus the verifier's next token
        # (a correction on mismatch, a bonus token when all drafts matched).
        # This guarantees progress even if the draft produced nothing.
        new_tokens = torch.cat(
            [draft_tokens[:n_matches], predicted_tokens[n_matches : n_matches + 1]]
        )
        curr_input_ids = torch.cat([curr_input_ids, new_tokens.unsqueeze(0)], dim=1)

    # The final round can overshoot the budget by one token; trim it (but
    # never cut into the prompt).
    curr_input_ids = curr_input_ids[:, : max(target_len, prompt_len)]

    if curr_input_ids.is_cuda:
        torch.cuda.synchronize()
    latency = time.perf_counter() - start_time

    acceptance_rate = accepted_draft_tokens / total_draft_tokens if total_draft_tokens > 0 else 0

    return curr_input_ids, latency, acceptance_rate

In [None]:
def benchmark_speculative_decoding(prompt, max_new_tokens=200, gamma=5):
    """
    Runs and compares standard vs speculative decoding for a given prompt.
    """

    # 1. Prepare Inputs
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs.input_ids

    print(f"PROMPT: {prompt.strip()[:100]}..." if len(prompt) > 100 else f"PROMPT: {prompt.strip()}")
    print(f"Settings: max_new_tokens={max_new_tokens}, gamma={gamma}")

    # 2. Run Standard Decoding (Baseline)
    print("Running Standard Decoding (Baseline)...")
    baseline_output, baseline_time = standard_autoregressive_generation(
        verifier_model, input_ids, max_new_tokens
    )
    baseline_text = tokenizer.decode(baseline_output[0], skip_special_tokens=True)
    print(f"Baseline Time: {baseline_time}s")

    # 3. Run Speculative Decoding
    print("\nRunning Speculative Decoding...")
    spec_output, spec_time, acc_rate = speculative_decoding(
        verifier_model, draft_model, input_ids, max_new_tokens, gamma=gamma
    )

    # 4. Alignment & Comparison
    min_len = min(baseline_output.shape[1], spec_output.shape[1])
    baseline_output_trunc = baseline_output[:, :min_len]
    spec_output_trunc = spec_output[:, :min_len]

    spec_text = tokenizer.decode(spec_output_trunc[0], skip_special_tokens=True)
    print(f"Speculative Time: {spec_time}s")
    print(f"Acceptance Rate: {acc_rate}")

    # 5. Metrics
    speedup = baseline_time / spec_time
    print(f"SPEEDUP: {speedup}x")

    # 6. Validation
    if torch.all(baseline_output_trunc == spec_output_trunc):
        print("SUCCESS: Outputs match exactly!")
    else:
        print("NOTE: Outputs differ.")
        matches = (baseline_output_trunc == spec_output_trunc).sum().item()
        total = baseline_output_trunc.shape[1]
        print(f"Consistency: {matches}/{total} tokens matched ({(matches/total):.1%})")

    # 7. Print Outputs
    print("Baseline Output")
    print(baseline_text)
    print("Speculative Output")
    print(spec_text)

In [None]:
GAMMA = 5

# Benchmark configuration: one natural-language prompt and two code prompts,
# all using the shared draft length GAMMA.
prompt = "The quick brown fox jumps over the"
max_new_tokens = 20
gamma = GAMMA

prompt2 = '''def fibonacci(n):
    """
    Returns the nth number in the fibonacci sequence.
    """
'''
max_new_tokens2 = 100
gamma2 = GAMMA

prompt3 = """
def two_sum(nums, target):
    num_map = {}
    for i, num in enumerate(nums):
"""
max_new_tokens3 = 100
gamma3 = GAMMA

# Run each configuration in order.
for _cfg in (
    (prompt, max_new_tokens, gamma),
    (prompt2, max_new_tokens2, gamma2),
    (prompt3, max_new_tokens3, gamma3),
):
    benchmark_speculative_decoding(*_cfg)
