# Speculative decoding skeleton (exercise)

Reference: [../speculative_decoding.ipynb](../speculative_decoding.ipynb). Paper: https://arxiv.org/abs/2211.17192

**Goals:**
- **generate_vanilla** (optional): Autoregressive loop — one forward pass per token; take the argmax of the last position's logits as the next token; append until max_tokens or EOS.
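- **generate_specdec_greedy**: (1) Draft phase: run the draft model autoregressively to get `num_draft_tokens` candidate tokens. (2) Target verification: one forward pass of the target on (current sequence + draft tokens). (3) Accept/reject: compare the target's next-token predictions with the draft tokens and find the first disagreement; accept the agreeing prefix plus one target token (the target's correction at the mismatch, or a bonus token if every draft token matched); repeat. This costs one target forward pass per block instead of one per generated token.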

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Setup: load the draft (0.5B) and target (3B) models, same as the reference
# notebook. Swap in smaller checkpoints if needed for testing.
DRAFT_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
TARGET_NAME = "Qwen/Qwen2.5-3B-Instruct"
# device_map="auto" delegates weight placement to transformers, so decoding
# code should move inputs to `model.device` rather than assume a fixed GPU.
# NOTE(review): no torch_dtype is passed, so weights load in the library's
# default precision — confirm both models fit in available memory.
target = AutoModelForCausalLM.from_pretrained(TARGET_NAME, device_map="auto", trust_remote_code=True)
draft = AutoModelForCausalLM.from_pretrained(DRAFT_NAME, device_map="auto", trust_remote_code=True)
# One tokenizer (the draft's) serves both models. Speculative decoding compares
# draft and target token ids directly, so this assumes the two checkpoints
# share a vocabulary — TODO confirm if you swap in a different model pair.
tokenizer = AutoTokenizer.from_pretrained(DRAFT_NAME, trust_remote_code=True)
prompt = "Hello, how are you?"

def dummy_generate(*args, **kwargs):
    """Stand-in for ``model.generate`` that refuses to run.

    Accepts and ignores any arguments, then always raises, so the exercise
    cannot fall back on the library's built-in decoding loop.
    """
    message = "generate() disabled"
    raise NotImplementedError(message)
# Block the built-in generation API on both models (draft first, then target)
# so only the hand-written decoding loops below can produce text.
draft.generate = target.generate = dummy_generate

In [None]:
from time import time

def cat(a, b, c=None):
    """Concatenate two or three tensors along dim 1.

    For (batch, seq) token-id tensors dim 1 is the sequence axis. ``c`` is
    optional; when it is ``None`` only ``a`` and ``b`` are joined.
    """
    pieces = [a, b] if c is None else [a, b, c]
    return torch.cat(pieces, dim=1)

def generate_vanilla(model, tokenizer, prompt, max_tokens=256, profile=False):
    """Greedy autoregressive decoding: one full forward pass per new token.

    Args:
        model: causal LM; ``model(input_ids).logits`` is expected to have shape
            (batch, seq, vocab), and ``model.device`` says where inputs go.
        tokenizer: tokenizer supporting ``tokenizer(text, return_tensors="pt")``,
            ``eos_token_id`` and ``decode``.
        prompt: a single text prompt (batch size 1).
        max_tokens: maximum number of new tokens to generate (<= 0 generates none).
        profile: if True, print elapsed time and tokens/second.

    Returns:
        The decoded prompt plus continuation, with special tokens (e.g. EOS)
        removed. Generation stops after max_tokens tokens or right after EOS.
    """
    # return_tensors="pt" is required: plain tokenizer(prompt).input_ids is a
    # Python list. Inputs follow the model because device_map="auto" decides
    # placement, so a hard-coded .cuda() may point at the wrong device.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    eos_id = tokenizer.eos_token_id
    n_new = 0
    start = time()
    # Decoding never calls backward(); no_grad skips building the autograd graph.
    with torch.no_grad():
        while n_new < max_tokens:
            logits = model(input_ids).logits
            # Greedy pick from the last position only -> shape (1, 1).
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=1)
            n_new += 1
            if eos_id is not None and next_token.item() == eos_id:
                break
    if profile:
        elapsed = max(time() - start, 1e-9)
        print(f"vanilla: {n_new} tokens in {elapsed:.2f}s ({n_new / elapsed:.1f} tok/s)")
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

def generate_specdec_greedy(draft, target, tokenizer, prompt, num_draft_tokens=7, max_tokens=256, profile=False):
    """Greedy speculative decoding (see https://arxiv.org/abs/2211.17192).

    Each iteration:
      1. Draft: the draft model proposes up to ``num_draft_tokens`` tokens
         autoregressively.
      2. Verify: one target forward over (current sequence + draft tokens)
         yields the target's greedy prediction at every draft position plus
         one position past the end.
      3. Accept/reject: keep the longest draft prefix that matches the
         target's predictions, then append one target token: its correction
         at the first mismatch, or a bonus token if every draft token matched.

    Every kept token equals the target's own greedy choice, so the output
    matches greedy decoding with ``target`` alone (up to numerical
    differences between forward passes). Each iteration costs exactly one
    target forward and yields at least one token.

    Args:
        draft: small causal LM; ``draft(ids).logits`` is (batch, seq, vocab).
        target: large causal LM with the same interface and vocabulary.
        tokenizer: shared tokenizer (``__call__`` with return_tensors="pt",
            ``eos_token_id``, ``decode``).
        prompt: a single text prompt (batch size 1).
        num_draft_tokens: draft tokens proposed per block; <= 0 degenerates
            to plain greedy decoding with the target.
        max_tokens: maximum number of new tokens (<= 0 generates none).
        profile: if True, print timing and draft acceptance statistics.

    Returns:
        The decoded prompt plus continuation with special tokens removed.
        Generation stops after max_tokens tokens or right after EOS.
    """
    device = target.device
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    prompt_len = input_ids.shape[1]
    eos_id = tokenizer.eos_token_id
    n_target_calls = n_drafted = n_accepted = 0
    done = False
    start = time()
    with torch.no_grad():
        while not done and input_ids.shape[1] - prompt_len < max_tokens:
            cur_len = input_ids.shape[1]
            remaining = max_tokens - (cur_len - prompt_len)
            # Each block always adds one target token, so drafting more than
            # remaining - 1 tokens could only be thrown away.
            k = max(0, min(num_draft_tokens, remaining - 1))

            # (1) Draft phase: k cheap autoregressive forwards.
            draft_seq = input_ids.to(draft.device)
            for _ in range(k):
                draft_logits = draft(draft_seq).logits
                next_token = draft_logits[:, -1, :].argmax(dim=-1, keepdim=True)
                draft_seq = torch.cat([draft_seq, next_token], dim=1)
            draft_ids = draft_seq[:, cur_len:].to(device)  # (1, k)

            # (2) Verification: a single target forward over prefix + drafts.
            target_logits = target(torch.cat([input_ids, draft_ids], dim=1)).logits
            n_target_calls += 1
            # Logits at position cur_len - 1 + i predict the token at
            # cur_len + i, giving k + 1 predictions; the last is the bonus.
            target_pred = target_logits[:, cur_len - 1:, :].argmax(dim=-1)  # (1, k+1)

            # (3) Accept/reject: length of the leading run of agreements.
            matches = (draft_ids == target_pred[:, :-1])[0]
            first = int(matches.long().cumprod(dim=0).sum())
            new_tokens = torch.cat([draft_ids[:, :first], target_pred[:, first:first + 1]], dim=1)
            n_drafted += k
            n_accepted += first

            # Stop right after the first EOS, discarding anything past it.
            if eos_id is not None:
                eos_hits = (new_tokens[0] == eos_id).nonzero()
                if eos_hits.numel():
                    new_tokens = new_tokens[:, : int(eos_hits[0]) + 1]
                    done = True
            input_ids = torch.cat([input_ids, new_tokens], dim=1)
    if profile:
        elapsed = max(time() - start, 1e-9)
        n_new = input_ids.shape[1] - prompt_len
        print(
            f"specdec: {n_new} tokens in {elapsed:.2f}s ({n_new / elapsed:.1f} tok/s), "
            f"{n_target_calls} target forwards, {n_accepted}/{n_drafted} draft tokens accepted"
        )
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

# After implementing: print(generate_vanilla(target, tokenizer, prompt, max_tokens=256))
# print(generate_specdec_greedy(draft, target, tokenizer, prompt, max_tokens=256))