In [63]:
"""
Flow

"""

'\nFlow\n\n'

In [64]:
# ============================================================
# 0) Minimal setup: imports, device, and seed (modular)
# ============================================================
import math
import random
from typing import List, Optional, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [65]:
# ---------- 0.1 Device ----------
def pick_device() -> str:
    """Pick best available device: MPS (Apple) → CUDA → CPU."""
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"

DEVICE = pick_device()
print("✅ device:", DEVICE)

✅ device: mps


In [66]:
 """
    WHY DO WE SET A SEED?
    - Our code has two decoding paths:
        (1) Greedy decoding (deterministic): picks argmax every step -> NO randomness involved.
            Seeding is NOT required for pure greedy baselines.
        (2) Medusa-like path (sampling): uses torch.multinomial(...) to sample tokens
            with temperature/top-p -> this IS stochastic.
            Seeding makes the sampled outputs reproducible across runs.

    FLOW / HOW TO USE:
    - Keep seeding OFF when you run greedy-only baselines to show they are deterministic by nature.
    - Turn seeding ON right before you call any function that samples tokens
      (e.g., propose_branch / medusa_tiny) to get repeatable experiments.
    - Optionally turn it OFF again after the experiment if you want variability in later runs.

    EXAMPLES:
    - Deterministic greedy baseline:
        set_seed(None)                  # turn off seeding
        out_g = greedy_generate(...)

    - Reproducible sampling run:
        set_seed(42)                    # fix RNG state
        out_m = medusa_tiny(...)        # sampling -> same output every time with the same seed
        set_seed(None)                  # (optional) restore randomness for later calls
"""

'\n   WHY DO WE SET A SEED?\n   - Our code has two decoding paths:\n       (1) Greedy decoding (deterministic): picks argmax every step -> NO randomness involved.\n           Seeding is NOT required for pure greedy baselines.\n       (2) Medusa-like path (sampling): uses torch.multinomial(...) to sample tokens\n           with temperature/top-p -> this IS stochastic.\n           Seeding makes the sampled outputs reproducible across runs.\n\n   FLOW / HOW TO USE:\n   - Keep seeding OFF when you run greedy-only baselines to show they are deterministic by nature.\n   - Turn seeding ON right before you call any function that samples tokens\n     (e.g., propose_branch / medusa_tiny) to get repeatable experiments.\n   - Optionally turn it OFF again after the experiment if you want variability in later runs.\n\n   EXAMPLES:\n   - Deterministic greedy baseline:\n       set_seed(None)                  # turn off seeding\n       out_g = greedy_generate(...)\n\n   - Reproducible sampling run:\n  

In [67]:
def set_seed(seed: Optional[int] = 42) -> None:
    if seed is None:
        return  # no seeding -> each sampling run can differ
    import random
    random.seed(seed)            # Python RNG
    torch.manual_seed(seed)      # PyTorch CPU RNG
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # PyTorch CUDA RNG for all GPUs

In [68]:
set_seed(42)
out1 = medusa_tiny("Long ago, ", max_new_tokens=30)
set_seed(42)
out2 = medusa_tiny("Long ago, ", max_new_tokens=30)
print(out1 == out2)  # True

True


In [69]:
#load_tokenizer: wraps Hugging Face’s built-in AutoTokenizer.from_pretrained(...) and then patches special tokens for GPT-2–family models (which often ship without explicit EOS/PAD).

In [70]:
def load_tokenizer(model_id: str) -> AutoTokenizer:
    #uses Hugging Face's official API to loead the correct tokenizer efor model_id/distilgpt2
    """
    Load tokenizer and patch EOS/PAD for GPT-2-like models/ GPT2 uses BPE tokenizer
    (they often lack explicit eos/pad tokens).
    """
    tok = AutoTokenizer.from_pretrained(model_id)
    """
    In GPT-2 models, it’s common for eos_token_id to be unset.
We declare the special token that corresponds to the empty string ("")—typically ID 50256 in classic GPT-2—as the EOS to populate eos_token_id.
Why? Because eos_token_id is needed to terminate generation and to mask EOS in the early steps (to prevent immediate early termination), among other controls.
    """
    if tok.eos_token_id is None:
        tok.eos_token = ""         # End-Of-Sequence token text (common for GPT-2)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    return tok

In [71]:
def load_model(model_id: str, device: str) -> AutoModelForCausalLM:
    """Load causal LM on device with KV-cache enabled."""
    model = AutoModelForCausalLM.from_pretrained(model_id).to(device).eval()
    model.config.use_cache = True
    return model


MODEL_ID = "distilgpt2"  # small & fast for demos
tok = load_tokenizer(MODEL_ID)
model = load_model(MODEL_ID, DEVICE)

In [72]:
# ---------- 1.1 Special IDs ----------
def get_special_ids(tokenizer: AutoTokenizer) -> Tuple[Optional[int], Optional[int]]:
    """Return (EOS_ID, PAD_ID) from tokenizer."""
    return tokenizer.eos_token_id, tokenizer.pad_token_id

EOS_ID, PAD_ID = get_special_ids(tok)

In [73]:
# ============================================================
# 1-Extra) Ban-list builders (fight NBSP/whitespace loops)
# ============================================================
def get_nbsp_token_ids(tokenizer: AutoTokenizer) -> List[int]:
    """
    Return ALL token IDs that represent NBSP (U+00A0).
    In byte-level BPE it can be 1+ IDs.
    """
    return tokenizer.encode("\u00A0", add_special_tokens=False)
 
def build_whitespace_banlist(tokenizer: AutoTokenizer) -> List[int]:
    """
    Build a list of token IDs that decode to whitespace-only strings,
    EXCLUDING a normal space ' ' (keep regular spaces).
    """
    ban = []
    for tid in range(tokenizer.vocab_size):
        s = tokenizer.decode([tid], skip_special_tokens=False, clean_up_tokenization_spaces=False)
        if s and s.isspace() and s != " ":
            ban.append(tid)
    return ban

# Choose your strategy:
#   A) Only NBSP            → safer/targeted
#   B) All whitespace-only  → stronger (recommended)
BAN_IDS = torch.tensor(build_whitespace_banlist(tok), device=DEVICE, dtype=torch.long)
print(f"EOS_ID={EOS_ID}, PAD_ID={PAD_ID}, #BAN_IDS={BAN_IDS.numel()}")

EOS_ID=50256, PAD_ID=50256, #BAN_IDS=22


In [74]:
print("#BAN_IDS:", BAN_IDS.numel())
# 상위 몇 개 확인
print("Sample banned IDs:", BAN_IDS[:10].tolist())

# 실제로 디코딩해 확인
for tid in BAN_IDS[:5].tolist():
    s = tok.decode([tid], skip_special_tokens=False, clean_up_tokenization_spaces=False)
    print(tid, repr(s), "isspace?", s.isspace())


#BAN_IDS: 22
Sample banned IDs: [197, 198, 199, 200, 201, 216, 217, 218, 219, 628]
197 '\t' isspace? True
198 '\n' isspace? True
199 '\x0b' isspace? True
200 '\x0c' isspace? True
201 '\r' isspace? True


In [75]:
# ============================================================
# 2) Text I/O helpers (encode/decode/append + normalization)
# ============================================================
def normalize_nbsp_to_space(text: str) -> str:
    """Replace NBSP (U+00A0) with a regular space for cleaner display."""
    return text.replace("\u00A0", " ")

In [76]:
def encode(prompt: str) -> torch.Tensor:
    """Text → token IDs on DEVICE. Shape: [1, T]."""
    return tok(prompt, return_tensors="pt").to(DEVICE)["input_ids"]

In [77]:
def decode(ids: torch.Tensor) -> str:
    """Token IDs → text (with NBSP normalized)."""
    txt = tok.decode(ids[0], skip_special_tokens=True)
    return normalize_nbsp_to_space(txt)

In [78]:
def append_token(ids: torch.Tensor, tid: int) -> torch.Tensor:
    """Append single token ID to [1, T] sequence (preserve dtype/device)."""
    t = torch.tensor([[tid]], dtype=ids.dtype, device=ids.device)
    return torch.cat([ids, t], dim=1)

In [79]:
# ============================================================
# 3) Logits utilities (last-step + masking)
# ============================================================
@torch.inference_mode()
def last_step_logits(ids: torch.Tensor) -> torch.Tensor:
    """Return logits for the last position. Shape: [V]."""
    return model(ids).logits[:, -1, :][0]
    """
    1) model(ids)
    Runs a forward pass of the language model on the token IDs ids (shape typically [B, T]).
    2) .logits
    Returns the raw scores (logits) for each position and each vocabulary token.
    Shape: [B, T, V]
    B: batch size
    T: sequence length
    V: vocabulary size (number of next-token candidates)
    3) [:, -1, :]
    Selects:
    all batch items (:),
    the last time step (-1),
    all vocabulary dimensions (:).
    Result shape: [B, V] — per-batch “next-token” logit vectors.
    4) [0]
Takes the first batch item only.
Final shape: [V] — the next-token logits vector for batch item 0.
    """

In [53]:
# 1) Encode a prompt → ids: [1, T]
prompt = "In a distant future, "
inps = tok(prompt, return_tensors="pt").to(model.device)
ids = inps["input_ids"]  # shape: [1, T]

# 2) Take last-step logits for the first (and only) batch item
vec = model(**inps).logits[:, -1, :][0]   # == model(ids).logits[:, -1, :][0]
print("ids.shape =", tuple(ids.shape))     # expect: (1, T)
print("vec.shape =", tuple(vec.shape))     # expect: (V,)  i.e., [V]


ids.shape = (1, 6)
vec.shape = (50257,)


In [58]:
"""In-place: forbid EOS by setting its logit to -inf.
Make probability zero: If you set the EOS logit to −∞ before softmax, the EOS probability becomes exactly 0 after softmax, so both sampling and greedy can never select it.
Prevent early termination: GPT-2 models often pick EOS in the very first steps and stop immediately. By banning EOS only for the first N steps (e.g., ban_eos_steps=8), you force the model to produce at least some tokens.
Controlled behavior: Using a ban flag lets you apply this only when desired; later you can lift the ban to allow normal termination via EOS.
"""

'In-place: forbid EOS by setting its logit to -inf.\nMake probability zero: If you set the EOS logit to −∞ before softmax, the EOS probability becomes exactly 0 after softmax, so both sampling and greedy can never select it.\nPrevent early termination: GPT-2 models often pick EOS in the very first steps and stop immediately. By banning EOS only for the first N steps (e.g., ban_eos_steps=8), you force the model to produce at least some tokens.\nControlled behavior: Using a ban flag lets you apply this only when desired; later you can lift the ban to allow normal termination via EOS.\n'

In [61]:
def mask_eos_(logits: torch.Tensor, ban: bool) -> None:
    """In-place: forbid EOS by setting its logit to -inf."""
    if ban and EOS_ID is not None:
        logits[EOS_ID] = float("-inf")

In [80]:
def mask_bad_tokens_(logits: torch.Tensor, ban_ids: Optional[torch.Tensor]) -> None:
    """In-place: forbid any token in ban_ids by setting logits to -inf."""
    if ban_ids is not None and ban_ids.numel() > 0:
        logits[ban_ids] = float("-inf")

def mask_all_(logits: torch.Tensor, *, ban_eos: bool = False, ban_ids: Optional[torch.Tensor] = None) -> None:
    """In-place: apply both EOS mask and ban-list mask."""
    mask_eos_(logits, ban=ban_eos)
    mask_bad_tokens_(logits, ban_ids=ban_ids)

In [85]:
#Hugging Face’s repetition_penalty is a technique that lowers the scores of tokens that have already appeared in the context, reducing repeated words/phrases

In [86]:
# ============================================================
# 4) Repetition controls (logits post-processing)
# ============================================================
def apply_repetition_penalty_(logits: torch.Tensor,
                              ids: torch.Tensor,
                              penalty: float = 1.0) -> None:
    """
    In-place repetition penalty (Hugging Face style):
      - if a token appeared in the context:
           logit > 0 → logit /= penalty
           logit < 0 → logit *= penalty
    """
    if penalty is None or float(penalty) == 1.0:
        return
    seen = torch.unique(ids)  # [N_seen]
    for tid in seen.tolist():
        t = logits[tid]
        logits[tid] = t / penalty if t > 0 else t * penalty

In [87]:
def _banned_tokens_for_no_repeat(ids: torch.Tensor, n: int) -> List[int]:
    """Return tokens that would create a repeated n-gram if chosen next."""
    if not n or n <= 0:
        return []
    seq = ids[0].tolist()
    if len(seq) < n:
        return []
    table = {}
    for i in range(len(seq) - n + 1):
        prefix = tuple(seq[i:i + n - 1])
        nx = seq[i + n - 1]
        table.setdefault(prefix, set()).add(nx)
    cur_prefix = tuple(seq[-(n - 1):])
    return list(table.get(cur_prefix, set()))

In [84]:
def apply_no_repeat_ngram_(logits: torch.Tensor, ids: torch.Tensor, n: int = 0) -> None:
    """In-place: ban tokens that would form a repeated n-gram."""
    if not n or n <= 0:
        return
    banned = _banned_tokens_for_no_repeat(ids, n)
    if banned:
        logits[torch.tensor(banned, device=logits.device, dtype=torch.long)] = float("-inf")

In [88]:
#greedy decoding =greedy decoding picks and appends the highest-probability token (i.e., the argmax) at each step.

In [90]:
"""
    Pick the next token greedily (argmax) *after* applying safety masks and repetition controls.

    Args:
        ids: Tensor of shape [1, T]. The current token sequence (same device/dtype as the model).
    Keyword Args:
        ban_eos: If True, forbid EOS for this step (helps prevent early termination in the first N steps).
        ban_ids: Tensor of token IDs to always forbid (e.g., NBSP/newline-only tokens). Can be None.
        repetition_penalty: Soft penalty (>1.0) applied to tokens that already appeared in `ids`.
                            Typical range: 1.1–1.3. 1.0 disables the effect.
        no_repeat_ngram_size: Hard constraint that prevents repeating n-grams (e.g., 4 or 5).

    Returns:
        int: The ID of the selected next token (greedy argmax after all adjustments).

    Notes:
        - This function assumes helper utilities exist:
            last_step_logits(ids) -> [V]
            mask_all_(logits, ban_eos, ban_ids)   # in-place: sets logits for banned tokens to -inf
            apply_repetition_penalty_(logits, ids, penalty)  # in-place soft downweight
            apply_no_repeat_ngram_(logits, ids, n)           # in-place hard ban for repeated n-grams
        - We clone the logits because we modify them in-place before argmax.
"""

'\n    Pick the next token greedily (argmax) *after* applying safety masks and repetition controls.\n\n    Args:\n        ids: Tensor of shape [1, T]. The current token sequence (same device/dtype as the model).\n    Keyword Args:\n        ban_eos: If True, forbid EOS for this step (helps prevent early termination in the first N steps).\n        ban_ids: Tensor of token IDs to always forbid (e.g., NBSP/newline-only tokens). Can be None.\n        repetition_penalty: Soft penalty (>1.0) applied to tokens that already appeared in `ids`.\n                            Typical range: 1.1–1.3. 1.0 disables the effect.\n        no_repeat_ngram_size: Hard constraint that prevents repeating n-grams (e.g., 4 or 5).\n\n    Returns:\n        int: The ID of the selected next token (greedy argmax after all adjustments).\n\n    Notes:\n        - This function assumes helper utilities exist:\n            last_step_logits(ids) -> [V]\n            mask_all_(logits, ban_eos, ban_ids)   # in-place: sets log

In [91]:
# ============================================================
# 5) Greedy decoding (baseline, masks + repetition controls)
# ============================================================
@torch.inference_mode()
def greedy_next(ids: torch.Tensor,
                *,
                ban_eos: bool = False,
                ban_ids: Optional[torch.Tensor] = BAN_IDS,
                repetition_penalty: float = 1.2,
                no_repeat_ngram_size: int = 4) -> int:
    # 1) Get logits for the last position. Shape: [V].
    #    Clone so we can safely mutate (mask/penalize) without affecting upstream tensors.
    logits = last_step_logits(ids).clone()

    # 2) Apply safety masks:
    #    - ban_eos: optionally forbid EOS this step (logit = -inf)
    #    - ban_ids: forbid whitespace-only or any custom disallowed tokens
    mask_all_(logits, ban_eos=ban_eos, ban_ids=ban_ids)

    # 3) Softly reduce the likelihood of tokens that already appeared in the context.
    apply_repetition_penalty_(logits, ids, penalty=repetition_penalty)

    # 4) Hard-prevent forming a repeated n-gram by banning the offending next tokens (logit = -inf).
    apply_no_repeat_ngram_(logits, ids, n=no_repeat_ngram_size)

    # 5) Greedy selection: pick the token with the highest (adjusted) logit.
    return int(torch.argmax(logits).item())


In [96]:
@torch.inference_mode()
def greedy_generate(prompt: str,
                    max_new_tokens: int = 40,
                    *,
                    ban_eos_steps: int = 8,
                    ban_ids: Optional[torch.Tensor] = BAN_IDS,
                    repetition_penalty: float = 1.2,
                    no_repeat_ngram_size: int = 4) -> str:
    """
    Greedy decoding with early EOS ban + repetition controls.
    - For the first `ban_eos_steps`, EOS is forbidden to avoid immediate stop.
    """
 # 1) Encode the prompt to token IDs on the correct device.
    ids = encode(prompt)

    # 2) Iteratively append up to `max_new_tokens`.
    for t in range(max_new_tokens):
        # 2a) Pick next token greedily after applying masks + repetition controls.
        nxt = greedy_next(ids,
                          ban_eos=(t < ban_eos_steps),
                          ban_ids=ban_ids,
                          repetition_penalty=repetition_penalty,
                          no_repeat_ngram_size=no_repeat_ngram_size)

        # 2b) If EOS is selected *after* the ban window, stop generation.
        if EOS_ID is not None and nxt == EOS_ID and t >= ban_eos_steps:
            break

        # 2c) If EOS is selected *during* the ban window, skip it and continue.
        if EOS_ID is not None and nxt == EOS_ID:
            continue

        # 2d) Otherwise, append the chosen token to the sequence.
        ids = append_token(ids, nxt)

    # 3) Decode token IDs back to text (handles NBSP normalization / skips specials).
    return decode(ids)

In [97]:
# ============================================================
# 6) Sampling building blocks (temperature + top-p nucleus)
# ============================================================
def softmax_with_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Apply temperature scaling + softmax."""
    t = max(float(temperature), 1e-6)  #Convert temperature to a float and clamp it to a minimum of 1e-6.
    return torch.softmax(logits / t, dim=-1)
    #Temperature scaling: scale the logits by 1/t and apply softmax, converting them into a probability distribution (sums to 1).

In [98]:
def top_p_filter_indices(probs: torch.Tensor, top_p: Optional[float]) -> torch.Tensor:
    """
    Return indices inside the nucleus (top-p) set.
    Always keep at least the top-1 token.
    """
    V = probs.numel() # vocabulary size
    
    # No filtering: keep all indices [0..V-1]
    if top_p is None or top_p >= 1:
        return torch.arange(V, device=probs.device)
    # Degenerate case: keep only the single most probable token
    if top_p <= 0:
        # Sort tokens by probability (descending)
        sorted_p, sorted_ix = torch.sort(probs, descending=True)
        return sorted_ix[:1]

    # Compute cumulative sum to find the smallest prefix whose mass ≤ top_p 
    sorted_p, sorted_ix = torch.sort(probs, descending=True)
    csum = torch.cumsum(sorted_p, dim=0)
    
    # Keep all tokens up to (and including) the point where cumprob ≤ top_p
    keep = csum <= top_p
    # Safety: always keep the top-1 token, even if top_p is very small
    keep[0] = True
    return sorted_ix[keep]

In [104]:
@torch.inference_mode()
def sample_next_token(ids: torch.Tensor,
                      *,
                      temperature: float = 0.9,
                      top_p: float = 0.95,
                      ban_eos: bool = False,
                      ban_ids: Optional[torch.Tensor] = BAN_IDS,
                      repetition_penalty: float = 1.2,
                      no_repeat_ngram_size: int = 4) -> int:
    """
    Sample ONE next token after applying masks and repetition controls.

    Args:
        ids: Tensor of shape [1, T]. Current token sequence on the model's device.

    Keyword Args:
        temperature: Temperature for softmax smoothing. Lower = sharper; higher = flatter.
        top_p: Nucleus (cumulative) probability threshold. Keep only the smallest prefix
               of tokens whose cumulative prob ≤ top_p (always keep at least top-1).
        ban_eos: If True, forbid EOS this step (prevents premature stopping).
        ban_ids: Tensor of token IDs to always forbid (e.g., whitespace-only tokens). Can be None.
        repetition_penalty: Soft penalty (>1.0) for tokens that already appeared in `ids`.
        no_repeat_ngram_size: Hard ban on forming a repeated n-gram of this size.

    Returns:
        int: The sampled next token ID.
    """
    # 1) Get last-step logits [V] and clone for in-place edits.
    logits = last_step_logits(ids).clone()

    # 2) Safety masks: EOS (optional) + custom banned tokens (e.g., NBSP/newline-only).
    mask_all_(logits, ban_eos=ban_eos, ban_ids=ban_ids)

    # 3) Repetition controls: soft downweight for seen tokens; hard n-gram ban.
    apply_repetition_penalty_(logits, ids, penalty=repetition_penalty)
    apply_no_repeat_ngram_(logits, ids, n=no_repeat_ngram_size)

    # 4) Convert to probabilities with temperature scaling.
    probs = softmax_with_temperature(logits, temperature)

    # 5) Nucleus (top-p) filtering: keep only high-mass prefix; always keep top-1.
    pool_ix = top_p_filter_indices(probs, top_p)
    pool = probs[pool_ix]

    # 6) Renormalize (guard against numerical underflow → fallback to top-1).
    mass = float(pool.sum().item())
    if mass <= 0.0:
        # Degenerate case: fall back to the most probable token in the nucleus.
        pick_local = 0
    else:
        pool = pool / mass
        # 7) Stochastic choice from the filtered, renormalized distribution.
        pick_local = int(torch.multinomial(pool, num_samples=1, replacement=False)[0].item())

    # 8) Map local index back to the original vocab index and return.
    return int(pool_ix[pick_local].item())


In [100]:
# ============================================================
# 7) Propose a short branch (span tokens) via sampling
# ============================================================
@torch.inference_mode()
def propose_branch(ids: torch.Tensor,
                   *,
                   span: int = 3,
                   temperature: float = 0.9,
                   top_p: float = 0.95,
                   ban_ids: Optional[torch.Tensor] = BAN_IDS,
                   ban_eos_first_n: int = 2,
                   repetition_penalty: float = 1.2,
                   no_repeat_ngram_size: int = 4) -> List[int]:
    """
    Propose a short branch of length `span` by iterative sampling.
    - Ban EOS for the first `ban_eos_first_n` samples to avoid immediate stop.
    """
    cur = ids.clone()
    branch: List[int] = []
    for i in range(span):
        ban = (i < ban_eos_first_n)
        t = sample_next_token(cur,
                              temperature=temperature, top_p=top_p,
                              ban_eos=ban, ban_ids=ban_ids,
                              repetition_penalty=repetition_penalty,
                              no_repeat_ngram_size=no_repeat_ngram_size)
        branch.append(t)
        cur = append_token(cur, t)
    return branch

In [30]:
# ============================================================
# 8) Prefix-accept-once (compare sampled vs greedy)
# ============================================================
@torch.inference_mode()
def prefix_accept_once(ids: torch.Tensor,
                       branch: List[int],
                       *,
                       ban_ids: Optional[torch.Tensor] = BAN_IDS,
                       ban_eos_during_accept: bool = False,
                       repetition_penalty: float = 1.2,
                       no_repeat_ngram_size: int = 4) -> Tuple[torch.Tensor, int]:
    """
    Accept sampled tokens while they match the greedy prediction.
    On first mismatch, accept greedy token and STOP.
    Returns: (updated_ids, accepted_count)
    """
    cur = ids.clone()
    accepted = 0
    for i, t in enumerate(branch):
        g = greedy_next(cur,
                        ban_eos=(ban_eos_during_accept and i == 0),
                        ban_ids=ban_ids,
                        repetition_penalty=repetition_penalty,
                        no_repeat_ngram_size=no_repeat_ngram_size)
        if g == t:
            cur = append_token(cur, t)
            accepted += 1
        else:
            cur = append_token(cur, g)
            break
    return cur, accepted

In [31]:
# ============================================================
# 9) Ultra-simple Medusa-like loop (one branch per step)
# ============================================================
@torch.inference_mode()
def medusa_tiny(prompt: str,
                *,
                max_new_tokens: int = 30,
                span: int = 3,
                temperature: float = 0.9,
                top_p: float = 0.95,
                ban_eos_steps: int = 8,
                ban_ids: Optional[torch.Tensor] = BAN_IDS,
                repetition_penalty: float = 1.2,
                no_repeat_ngram_size: int = 4) -> str:
    """
    Minimal Medusa-like generation:
      loop:
        (a) propose a small branch of `span` tokens (sampling)
        (b) prefix-accept against greedy (1st mismatch → greedy, stop)
      stop when new tokens reach `max_new_tokens`
    """
    ids = encode(prompt)
    start_len = ids.shape[1]
    steps = math.ceil(max_new_tokens / max(1, span))
    remaining_ban = max(0, ban_eos_steps)

    for step in range(steps):
        branch = propose_branch(ids,
                                span=span, temperature=temperature, top_p=top_p,
                                ban_ids=ban_ids, ban_eos_first_n=min(span, remaining_ban),
                                repetition_penalty=repetition_penalty,
                                no_repeat_ngram_size=no_repeat_ngram_size)
        ids, _ = prefix_accept_once(ids, branch,
                                    ban_ids=ban_ids,
                                    ban_eos_during_accept=(step == 0 and remaining_ban > 0),
                                    repetition_penalty=repetition_penalty,
                                    no_repeat_ngram_size=no_repeat_ngram_size)
        if ids.shape[1] - start_len >= max_new_tokens:
            break
        remaining_ban = max(0, remaining_ban - span)

    return decode(ids)


In [32]:
# ============================================================
# 10) Quick smoke tests (repr shows spaces clearly)
# ============================================================
if __name__ == "__main__":
    prompts = [
        "In a distant future, ",
        "Long ago, ",
    ]

    print("\n=== Greedy (baseline) ===")
    for p in prompts:
        print("PROMPT:", repr(p))
        print(repr(greedy_generate(p,
                                   max_new_tokens=30,
                                   ban_eos_steps=8,
                                   repetition_penalty=1.2,
                                   no_repeat_ngram_size=4)))
        print("-" * 60)

    print("\n=== Medusa-tiny (span=3) ===")
    for p in prompts:
        print("PROMPT:", repr(p))
        print(repr(medusa_tiny(p,
                               max_new_tokens=30,
                               span=3,
                               temperature=0.9,
                               top_p=0.95,
                               ban_eos_steps=8,
                               repetition_penalty=1.2,
                               no_repeat_ngram_size=4)))
        print("-" * 60)


=== Greedy (baseline) ===
PROMPT: 'In a distant future, '
'In a distant future, vernacular is the most common language in English. It has been used for centuries to describe many of our everyday lives and we are often referred to as'
------------------------------------------------------------
PROMPT: 'Long ago, '
'Long ago, iced tea was a popular drink in the United States. It is now widely used as an alternative to traditional Chinese medicine and has been shown to be effective'
------------------------------------------------------------

=== Medusa-tiny (span=3) ===
PROMPT: 'In a distant future, '
'In a distant future, vernacular is the most common language in English. It has been used for centuries'
------------------------------------------------------------
PROMPT: 'Long ago, '
'Long ago, iced tea was a popular drink in the United States. It is'
------------------------------------------------------------
