In [67]:
# ============================
# Medusa-lite flow (beginner)
# drafter → verifier → multi-branch prefix-accept
# ============================

In [68]:
# ---------- Step 1) Imports ----------
from dataclasses import dataclass
import time, math, random
import torch
from typing import List, Tuple
from transformers import AutoTokenizer, AutoModelForCausalLM

In [69]:
# ---------- Step 2) Pick device ----------
def pick_device():
    """Pick the best available device (Apple Silicon -> mps, CUDA GPU -> cuda, else cpu)."""
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"

DEVICE = pick_device()
print("✅ DEVICE =", DEVICE)
assert DEVICE in {"cpu", "cuda", "mps"}

✅ DEVICE = mps


In [70]:
# ---------- Step 3) Fix seeds (for reproducibility) ----------
print("✅ seeding...")
random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

✅ seeding...


In [71]:
# ---------- Step 4) Config ----------
@dataclass
class Cfg:
    # Models
    DRAFTER_ID: str = "distilgpt2"     # small, fast drafter
    VERIFIER_ID: str = "gpt2-medium"   # stronger verifier

    # Generation limits
    MAX_NEW_TOKENS: int = 30

    # Drafter sampling (diversity)
    TEMPERATURE: float = 0.8           # mild temperature
    TOP_P: float = 0.9                 # nucleus sampling

    # Repetition control for verifier
    REPETITION_PENALTY: float = 1.3    # penalize repeated tokens
    NO_REPEAT_NGRAM: int = 5           # block repeated n-grams

    # Medusa-lite branching
    TOPK_BRANCH: int = 3               # number of branches (K)
    DRAFT_SPAN: int = 3                # tokens per branch (M)

    # Misc
    DEVICE: str = DEVICE
    DEBUG: bool = False

cfg = Cfg()

In [72]:
# ---------- Step 5) Load models & tokenizers ----------
drafter_tok  = AutoTokenizer.from_pretrained(cfg.DRAFTER_ID)
verifier_tok = AutoTokenizer.from_pretrained(cfg.VERIFIER_ID)

# GPT-2 family often lacks eos/pad by default -> set them
if verifier_tok.eos_token_id is None:
    verifier_tok.eos_token = ""
if verifier_tok.pad_token_id is None:
    verifier_tok.pad_token = verifier_tok.eos_token
EOS_ID = verifier_tok.eos_token_id

drafter  = AutoModelForCausalLM.from_pretrained(cfg.DRAFTER_ID).to(cfg.DEVICE).eval()
verifier = AutoModelForCausalLM.from_pretrained(cfg.VERIFIER_ID).to(cfg.DEVICE).eval()
drafter.config.use_cache  = True
verifier.config.use_cache = True

print("✅ models ready:", cfg.DRAFTER_ID, "/", cfg.VERIFIER_ID)

✅ models ready: distilgpt2 / gpt2-medium


In [73]:
# ---------- Step 6) Prepare prompt & context ----------
prompt = "In a distant future, a small crew of explorers discovers "
ctx = drafter_tok(prompt, return_tensors="pt").to(cfg.DEVICE)
input_ids = ctx["input_ids"]
print("context ok?", ctx is not None, "| shape:", input_ids.shape)

context ok? True | shape: torch.Size([1, 12])


In [74]:
# ---------- Step 7) Drafter: sample K first tokens (basic) ----------
@torch.inference_mode()
def drafter_sample_first_tokens(model, ids, k: int, temperature=0.8, top_p=0.9) -> list[int]:
    logits = model(ids).logits[:, -1, :]
    probs  = torch.softmax(logits / max(temperature, 1e-6), dim=-1)[0]

    if top_p is not None:
        sorted_p, sorted_ix = torch.sort(probs, descending=True)
        cumsum = torch.cumsum(sorted_p, dim=0)
        keep = cumsum <= top_p; keep[0] = True
        pool_ix = sorted_ix[keep]
        pool_p  = probs[pool_ix] / probs[pool_ix].sum()
        num = min(k, pool_ix.numel())
        picks = torch.multinomial(pool_p, num_samples=num, replacement=False)
        firsts = [int(pool_ix[i]) for i in picks]
    else:
        num = min(k, probs.numel())
        picks = torch.multinomial(probs, num_samples=num, replacement=False)
        firsts = [int(i) for i in picks]

    # ✅ filter out pure whitespace tokens
    firsts = [t for t in firsts if not drafter_tok.decode([t]).isspace()]
    if not firsts:
        firsts = [int(torch.argmax(probs).item())]  # fallback
    return firsts[:k]

In [75]:
# ---------- Step 8) Drafter: greedy rollout of a branch ----------
@torch.inference_mode()
def drafter_rollout(ids, first_tok: int, span: int) -> list[int]:
    """
    Starting from 'first_tok', greedily extend with the drafter for (span-1) tokens.
    Returns a list[int] of length = span (including first_tok).
    """
    cur = torch.cat([ids, torch.tensor([[first_tok]], device=ids.device)], dim=1)
    seq = [first_tok]
    for _ in range(span - 1):
        logits = drafter(cur).logits[:, -1, :]
        nxt = int(torch.argmax(logits, dim=-1)[0])
        seq.append(nxt)
        cur = torch.cat([cur, torch.tensor([[nxt]], device=cur.device)], dim=1)
    return seq

In [76]:
# ---------- Step 9) Drafter: propose K branches (each length = span) ----------
@torch.inference_mode()
def drafter_propose(ids, k: int, span: int, temperature=0.8, top_p=0.9) -> list[list[int]]:
    """
    Propose K branches:
      1) sample K first tokens
      2) for each, greedy-rollout to length 'span'
    Returns: list of branches (list[list[int]]).
    """
    firsts = drafter_sample_first_tokens(drafter, ids, k, temperature, top_p)
    branches = [drafter_rollout(ids, f, span) for f in firsts]
    return branches

In [77]:
# ---------- (Optional) Inspect a whitespace token example ----------
tid = 1849
print("token str (repr):", repr(drafter_tok.decode([tid])))
print("gpt2 piece:", drafter_tok.convert_ids_to_tokens([tid])[0])
print("is space?", drafter_tok.decode([tid]).isspace())

token str (repr): '\xa0'
gpt2 piece: Âł
is space? True


In [78]:
# ---------- Step 10) Verifier: predict next token (greedy) ----------
@torch.inference_mode()
def verifier_next_token(ids, repetition_penalty: float = 1.35, no_repeat_ngram: int = 5) -> int:
    """
    Greedy next token with simple repetition controls:
    - repetition_penalty: down-weight tokens that already appeared
    - no_repeat_ngram: block the n-th token if last (n-1) pattern already occurred
    """
    logits = verifier(ids).logits[:, -1, :].clone()
    V = logits.size(-1)

    # 1) repetition penalty
    if repetition_penalty and repetition_penalty != 1.0:
        seen = torch.bincount(ids[0].to(torch.int64), minlength=V).bool()
        logits[:, seen] = logits[:, seen] / repetition_penalty

    # 2) no-repeat n-gram
    n = int(no_repeat_ngram or 0)
    if n > 1 and ids.size(1) >= n - 1:
        tail = ids[0].tolist()
        prefix = tail[-(n-1):]
        blocked = set()
        for i in range(len(tail) - n + 1):
            if tail[i:i+n-1] == prefix:
                blocked.add(tail[i+n-1])
        if blocked:
            logits[:, list(blocked)] = -1e9

    return int(torch.argmax(logits, dim=-1)[0])


In [79]:
@torch.inference_mode()
def next_human_token(ids, tokenizer, tries=10):
    """
    Run greedy next-token repeatedly until we see a printable, non-whitespace char,
    or we hit 'tries' attempts. Useful when the immediate next token is whitespace.
    """
    cur = ids.clone()
    for _ in range(tries):
        tid = verifier_next_token(cur)
        s = tokenizer.decode([tid], skip_special_tokens=False, clean_up_tokenization_spaces=False)
        if any(ch.isprintable() and not ch.isspace() for ch in s):
            return tid, s
        cur = torch.cat([cur, torch.tensor([[tid]], device=ids.device)], dim=1)
    return tid, s

vid = verifier_next_token(input_ids)
info = pretty_token(verifier_tok, vid)
print("predicted token info:", info)
t2, s2 = next_human_token(input_ids, verifier_tok)
print("next visible token:", t2, repr(s2))

predicted token info: {'id': 220, 'decode_repr': "' '", 'token_repr': "'Ġ'", 'token_fixed': "'Ġ'", 'codepoints': ['0x20'], 'bytes': [32]}
next visible token: 257 ' a'


In [80]:
# ---------- Step 11) Prefix-accept until first mismatch ----------
@torch.inference_mode()
def accept_until_mismatch(context_ids, branch_tokens: List[int]) -> Tuple[torch.Tensor, List[int], bool]:
    """
    Simulate prefix-accept with the verifier:
      - If verifier's next == branch token -> accept (append branch token)
      - On first mismatch -> append verifier token instead and stop
    Returns: (new_ids, accepted_tokens_list, mismatched_flag)
    """
    ids = context_ids.clone()
    accepted = []
    mismatched = False
    for tid in branch_tokens:
        pred = verifier_next_token(ids)
        if pred == tid:
            ids = torch.cat([ids, torch.tensor([[tid]], device=ids.device)], dim=1)
            accepted.append(tid)
        else:
            ids = torch.cat([ids, torch.tensor([[pred]], device=ids.device)], dim=1)
            mismatched = True
            break
    return ids, accepted, mismatched

# Example: build branches then test the first one
b = drafter_propose(input_ids, cfg.TOPK_BRANCH, cfg.DRAFT_SPAN, cfg.TEMPERATURE, cfg.TOP_P)
new_ids, accepted, mism = accept_until_mismatch(input_ids, b[0])
print('accepted len:', len(accepted), '| mismatched?', mism)
print('new length:', new_ids.shape[1], '| appended tokens:', new_ids.shape[1] - input_ids.shape[1])

accepted len: 0 | mismatched? True
new length: 13 | appended tokens: 1


In [81]:
# ---------- Step 12) Branch scoring ----------
def score_branch(accepted, mismatched):
    """
    Simple score: number of accepted tokens minus a small penalty if mismatch happened.
    Larger is better.
    """
    return len(accepted) - (1 if mismatched else 0)

print("score sanity:", score_branch([1,2,3], False), score_branch([1,2], True))  # 3, 1


score sanity: 3 1


In [82]:
# ---------- Step 13) Utilities for encoding ----------
@torch.inference_mode()
def encode_prompt(prompt: str):
    # use verifier tokenizer for consistency
    return verifier_tok(prompt, return_tensors="pt").to(cfg.DEVICE)["input_ids"]

import re
_NBSP = "\u00A0"
_RE_PUNCT = re.compile(r"\s+([,.;:!?])")
_RE_SPACES = re.compile(r"\s+")

def normalize_text(s: str) -> str:
    s = s.replace(_NBSP, " ")     # NBSP → normal space
    s = _RE_PUNCT.sub(r"\1", s)   # trim space before punctuation
    s = _RE_SPACES.sub(" ", s)    # collapse multiple spaces
    return s.strip()


In [83]:
# ---------- Step 14) One Medusa step ----------
@torch.inference_mode()
def medusa_step(ids, topk_branch: int, draft_span: int, temperature: float):
    """
    One iteration:
      - Draft K branches of length M
      - For each, run prefix-accept until mismatch
      - Pick the best by 'score_branch'
      - Return the updated token ids
    """
    branches = drafter_propose(ids, topk_branch, draft_span, temperature, cfg.TOP_P)
    best_score = -10**9
    best_ids = None
    for br in branches:
        new_ids, accepted, mism = accept_until_mismatch(ids, br)
        s = score_branch(accepted, mism)
        if s > best_score:
            best_score, best_ids = s, new_ids
    return best_ids

ids2 = medusa_step(ids, cfg.TOPK_BRANCH, cfg.DRAFT_SPAN, cfg.TEMPERATURE)
print("before:", ids.shape[1], "→ after:", ids2.shape[1])

before: 6 → after: 7


In [84]:
# ---------- Step 15) Orchestrator: Medusa-lite generator (with sentence stop) ----------
import re

@torch.inference_mode()
def medusa_generate(
    prompt: str,
    max_new_tokens: int | None = None,
    topk_branch: int | None = None,
    draft_span: int | None = None,
    temperature: float | None = None,
    *,
    stop_at_sentence: bool = True,   # stop when sentence ends
    tail_chars: int = 80             # how many chars to inspect for sentence end
) -> str:
    """
    Run multiple Medusa steps until we reach 'max_new_tokens'.
    Adds:
      - sentence-level early stop (., !, ? with optional closing quotes/brackets)
      - safety: break if a step makes no progress
    """
    if max_new_tokens is None: max_new_tokens = cfg.MAX_NEW_TOKENS
    if topk_branch   is None: topk_branch   = cfg.TOPK_BRANCH
    if draft_span    is None: draft_span    = cfg.DRAFT_SPAN
    if temperature   is None: temperature   = cfg.TEMPERATURE

    # compile once (local to keep function drop-in)
    SENT_END = re.compile(r'[\.!\?]["\')\]]?\s*$')

    ids = encode_prompt(prompt)
    start_len = ids.shape[1]

    steps = math.ceil(max_new_tokens / draft_span)
    for _ in range(steps):
        before = ids.shape[1]
        ids = medusa_step(ids, topk_branch, draft_span, temperature)

        # safety: no progress → break
        if ids.shape[1] == before:
            break

        # reached target token budget → break
        if ids.shape[1] - start_len >= max_new_tokens:
            break

        # sentence-level early stop
        if stop_at_sentence:
            tail = verifier_tok.decode(ids[0][-min(tail_chars, ids.shape[1]):], skip_special_tokens=True)
            if SENT_END.search(tail):
                break

    # decode final text (GPT-2 family tokenizers are compatible)
    return drafter_tok.decode(ids[0], skip_special_tokens=True)


In [85]:
# Try once
out = medusa_generate("In a distant future, ", 40)
print(out)

In a distant future,      is a   ?


In [86]:
# ---------- Step 16) Greedy baseline (verifier-only) ----------
@torch.inference_mode()
def greedy_generate(prompt: str, max_new_tokens: int = None) -> str:
    """
    Plain greedy decoding with the verifier (for comparison).
    """
    if max_new_tokens is None:
        max_new_tokens = cfg.MAX_NEW_TOKENS

    ctx = verifier_tok(prompt, return_tensors="pt").to(cfg.DEVICE)
    ids = ctx["input_ids"]

    for _ in range(max_new_tokens):
        logits = verifier(ids).logits[:, -1, :]
        nxt = int(torch.argmax(logits, dim=-1)[0])
        ids = torch.cat([ids, torch.tensor([[nxt]], device=ids.device)], dim=1)

    return verifier_tok.decode(ids[0], skip_special_tokens=True)

txt = greedy_generate("In a distant future, ", 40)
print(txt)

In a distant future,  the world is ruled by a dictator who is obsessed with the idea of controlling the world's resources.  He wants to control the world's resources so that he can rule the world. 


In [87]:
# ---------- Step 17) A/B timing: greedy vs medusa ----------
def time_it(fn, *args, **kwargs):
    t0 = time.perf_counter()
    out = fn(*args, **kwargs)
    return out, time.perf_counter() - t0

g_txt, g_t = time_it(greedy_generate, "In a distant future, ", 80)
m_txt, m_t = time_it(medusa_generate, "In a distant future, ", 80)

print("⏱ greedy:", round(g_t, 3), "s")
print("⏱ medusa:", round(m_t, 3), "s")
print("\n--- greedy ---\n", g_txt[:400])
print("\n--- medusa ---\n", m_txt[:400])

⏱ greedy: 3.478 s
⏱ medusa: 0.786 s

--- greedy ---
 In a distant future,  the world is ruled by a dictator who is obsessed with the idea of controlling the world's resources.  He wants to control the world's resources so that he can rule the world.  He wants to control the world's resources so that he can rule the world.  He wants to control the world's resources so that he can rule the world.  He wants to

--- medusa ---
 In a distant future,      is a   ?
