In [59]:
# ============================================================
# 0) Minimal setup: imports, device, and seed (modular)
# ============================================================
import math
import random
from typing import List, Optional, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [60]:
# ---------- 0.1 Device ----------
def pick_device() -> str:
    """Pick best available device: MPS (Apple) → CUDA → CPU."""
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


DEVICE = pick_device()
print("✅ device:", DEVICE)

✅ device: mps


In [61]:
# ---------- 0.2 Seed (optional, for reproducibility) ----------
def set_seed(seed: Optional[int] = 42) -> None:
    """Set random seeds; pass None to skip."""
    if seed is None:
        return
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(42)  # change to None if you want non-deterministic sampling

In [62]:
# ============================================================
# 1) Tokenizer & Model loaders (GPT-2 family friendly)
# ============================================================
def load_tokenizer(model_id: str) -> AutoTokenizer:
    """
    Load tokenizer and patch EOS/PAD for GPT-2-like models
    (they often lack explicit eos/pad tokens).
    """
    tok = AutoTokenizer.from_pretrained(model_id)
    if tok.eos_token_id is None:
        tok.eos_token = ""         # set an EOS token text
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    return tok


def load_model(model_id: str, device: str) -> AutoModelForCausalLM:
    """Load causal LM on device with KV-cache enabled."""
    model = AutoModelForCausalLM.from_pretrained(model_id).to(device).eval()
    model.config.use_cache = True
    return model


MODEL_ID = "distilgpt2"  # small & fast for demos
tok = load_tokenizer(MODEL_ID)
model = load_model(MODEL_ID, DEVICE)

In [63]:
# ---------- 1.1 Special IDs ----------
def get_special_ids(tokenizer: AutoTokenizer) -> Tuple[Optional[int], Optional[int]]:
    """Return (EOS_ID, PAD_ID) from tokenizer."""
    return tokenizer.eos_token_id, tokenizer.pad_token_id


EOS_ID, PAD_ID = get_special_ids(tok)

In [64]:
# ============================================================
# 1-Extra) Ban-list builders (fight NBSP/whitespace loops)
# ============================================================
def get_nbsp_token_ids(tokenizer: AutoTokenizer) -> List[int]:
    """
    Return ALL token IDs that represent NBSP (U+00A0).
    In byte-level BPE it can be 1+ IDs.
    """
    return tokenizer.encode("\u00A0", add_special_tokens=False)


def build_whitespace_banlist(tokenizer: AutoTokenizer) -> List[int]:
    """
    Build a list of token IDs that decode to whitespace-only strings,
    EXCLUDING a normal space ' ' (keep regular spaces).
    """
    ban = []
    for tid in range(tokenizer.vocab_size):
        s = tokenizer.decode([tid], skip_special_tokens=False, clean_up_tokenization_spaces=False)
        if s and s.isspace() and s != " ":
            ban.append(tid)
    return ban


# Choose your strategy:
#   A) Only NBSP            → safer/targeted
#   B) All whitespace-only  → stronger (recommended)
BAN_IDS = torch.tensor(build_whitespace_banlist(tok), device=DEVICE, dtype=torch.long)
print(f"EOS_ID={EOS_ID}, PAD_ID={PAD_ID}, #BAN_IDS={BAN_IDS.numel()}")

EOS_ID=50256, PAD_ID=50256, #BAN_IDS=22


In [65]:
# ============================================================
# 2) Text I/O helpers (encode/decode/append + normalization)
# ============================================================
def normalize_nbsp_to_space(text: str) -> str:
    """Replace NBSP (U+00A0) with a regular space for cleaner display."""
    return text.replace("\u00A0", " ")


def encode(prompt: str) -> torch.Tensor:
    """Text → token IDs on DEVICE. Shape: [1, T]."""
    return tok(prompt, return_tensors="pt").to(DEVICE)["input_ids"]


def decode(ids: torch.Tensor) -> str:
    """Token IDs → text (with NBSP normalized)."""
    txt = tok.decode(ids[0], skip_special_tokens=True)
    return normalize_nbsp_to_space(txt)


def append_token(ids: torch.Tensor, tid: int) -> torch.Tensor:
    """Append single token ID to [1, T] sequence (preserve dtype/device)."""
    t = torch.tensor([[tid]], dtype=ids.dtype, device=ids.device)
    return torch.cat([ids, t], dim=1)

In [66]:
# ============================================================
# 3) Logits utilities (last-step + masking)
# ============================================================
@torch.inference_mode()
def last_step_logits(ids: torch.Tensor) -> torch.Tensor:
    """Return logits for the last position. Shape: [V]."""
    return model(ids).logits[:, -1, :][0]


def mask_eos_(logits: torch.Tensor, ban: bool) -> None:
    """In-place: forbid EOS by setting its logit to -inf."""
    if ban and EOS_ID is not None:
        logits[EOS_ID] = float("-inf")


def mask_bad_tokens_(logits: torch.Tensor, ban_ids: Optional[torch.Tensor]) -> None:
    """In-place: forbid any token in ban_ids by setting logits to -inf."""
    if ban_ids is not None and ban_ids.numel() > 0:
        logits[ban_ids] = float("-inf")


def mask_all_(logits: torch.Tensor, *, ban_eos: bool = False, ban_ids: Optional[torch.Tensor] = None) -> None:
    """In-place: apply both EOS mask and ban-list mask."""
    mask_eos_(logits, ban=ban_eos)
    mask_bad_tokens_(logits, ban_ids=ban_ids)

In [67]:
# ============================================================
# 4) Greedy decoding (baseline, EOS-safe + ban-list)
# ============================================================
@torch.inference_mode()
def greedy_next(ids: torch.Tensor,
                *,
                ban_eos: bool = False,
                ban_ids: Optional[torch.Tensor] = BAN_IDS,
                repetition_penalty: float = 1.0,
                no_repeat_ngram_size: int = 0) -> int:
    """
    Argmax with masks + repetition controls.
    """
    logits = last_step_logits(ids).clone()
    # 1) 기본 마스크(EOS/공백 전용 토큰 등)
    mask_all_(logits, ban_eos=ban_eos, ban_ids=ban_ids)
    # 2) 반복 억제
    apply_repetition_penalty_(logits, ids, penalty=repetition_penalty)
    # 3) n-그램 반복 금지
    apply_no_repeat_ngram_(logits, ids, n=no_repeat_ngram_size)
    # 4) 최종 선택
    return int(torch.argmax(logits).item())

    
@torch.inference_mode()
def greedy_generate(prompt: str, max_new_tokens: int = 40, *, ban_eos_steps: int = 8,
                    ban_ids: Optional[torch.Tensor] = BAN_IDS) -> str:
    """
    Plain greedy decoding with early EOS ban + ban-list.
    - For the first `ban_eos_steps`, EOS is forbidden to avoid immediate stop.
    """
    ids = encode(prompt)
    for t in range(max_new_tokens):
        nxt = greedy_next(ids, ban_eos=(t < ban_eos_steps), ban_ids=ban_ids)
        # If EOS appears after ban window, stop
        if EOS_ID is not None and nxt == EOS_ID and t >= ban_eos_steps:
            break
        # If EOS still slips in during the ban window, skip it (rare)
        if EOS_ID is not None and nxt == EOS_ID:
            continue
        ids = append_token(ids, nxt)
    return decode(ids)


In [68]:
# ============================================================
# 5) Sampling building blocks (temperature + top-p nucleus)
# ============================================================
def softmax_with_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Apply temperature scaling + softmax."""
    t = max(float(temperature), 1e-6)
    return torch.softmax(logits / t, dim=-1)


def top_p_filter_indices(probs: torch.Tensor, top_p: Optional[float]) -> torch.Tensor:
    """
    Return indices inside the nucleus (top-p) set.
    Always keep at least the top-1 token.
    """
    V = probs.numel()
    if top_p is None or top_p >= 1:
        return torch.arange(V, device=probs.device)
    if top_p <= 0:
        sorted_p, sorted_ix = torch.sort(probs, descending=True)
        return sorted_ix[:1]

    sorted_p, sorted_ix = torch.sort(probs, descending=True)
    csum = torch.cumsum(sorted_p, dim=0)
    keep = csum <= top_p
    keep[0] = True
    return sorted_ix[keep]


@torch.inference_mode()
def sample_next_token(ids: torch.Tensor,
                      *,
                      temperature: float = 0.9,
                      top_p: float = 0.95,
                      ban_eos: bool = False,
                      ban_ids: Optional[torch.Tensor] = BAN_IDS) -> int:
    """
    Sample ONE token via temperature + top-p with optional masks.
    """
    logits = last_step_logits(ids).clone()
    mask_all_(logits, ban_eos=ban_eos, ban_ids=ban_ids)

    probs = softmax_with_temperature(logits, temperature)
    pool_ix = top_p_filter_indices(probs, top_p)
    pool_p = probs[pool_ix] / pool_p.sum() if (pool_p := probs[pool_ix]).sum() > 0 else probs[pool_ix]

    pick_local = torch.multinomial(pool_p, num_samples=1, replacement=False)[0]
    return int(pool_ix[pick_local].item())

In [69]:
# ============================================================
# 5.1) Repetition controls (logits post-processing)
# ============================================================

def apply_repetition_penalty_(logits: torch.Tensor,
                              ids: torch.Tensor,
                              penalty: float = 1.0) -> None:
    """
    In-place repetition penalty (HuggingFace 방식):
      - If a token appeared in the context:
          logit > 0 → logit /= penalty
          logit < 0 → logit *= penalty
    """
    if penalty is None or float(penalty) == 1.0:
        return
    seen = torch.unique(ids)  # context에 등장한 토큰들
    for tid in seen:
        t = logits[int(tid)]
        if t > 0:
            logits[int(tid)] = t / penalty
        else:
            logits[int(tid)] = t * penalty


def _get_banned_tokens_no_repeat_ngram(ids: torch.Tensor,
                                       n: int) -> list[int]:
    """
    Return tokens that would create a repeated n-gram if chosen next.
    (batch size 1 가정)
    """
    if n is None or n <= 0:
        return []
    seq = ids[0].tolist()
    if len(seq) < n:
        return []
    # build { (n-1)-gram prefix → set(next_token) } from history
    prefix2next = {}
    for i in range(len(seq) - n + 1):
        prefix = tuple(seq[i:i + n - 1])
        nxt = seq[i + n - 1]
        prefix2next.setdefault(prefix, set()).add(nxt)
    # current (n-1)-gram
    cur_prefix = tuple(seq[-(n - 1):])
    return list(prefix2next.get(cur_prefix, set()))


def apply_no_repeat_ngram_(logits: torch.Tensor,
                           ids: torch.Tensor,
                           n: int = 0) -> None:
    """
    In-place: ban tokens that would form a repeated n-gram.
    """
    if n is None or n <= 0:
        return
    banned = _get_banned_tokens_no_repeat_ngram(ids, n)
    if banned:
        idx = torch.tensor(banned, device=logits.device, dtype=torch.long)
        logits[idx] = float("-inf")


In [70]:
# ============================================================
# 6) Propose a short branch (span tokens) via sampling
# ============================================================
@torch.inference_mode()
def propose_branch(ids: torch.Tensor,
                   *,
                   span: int = 3,
                   temperature: float = 0.9,
                   top_p: float = 0.95,
                   ban_ids: Optional[torch.Tensor] = BAN_IDS,
                   ban_eos_first_n: int = 2) -> List[int]:
    """
    Propose a short branch of length `span` by iterative sampling.
    - Ban EOS for the first `ban_eos_first_n` samples to avoid immediate stop.
    """
    cur = ids.clone()
    branch: List[int] = []
    for i in range(span):
        ban = (i < ban_eos_first_n)
        t = sample_next_token(cur, temperature=temperature, top_p=top_p, ban_eos=ban, ban_ids=ban_ids)
        branch.append(t)
        cur = append_token(cur, t)
    return branch


In [71]:
# ============================================================
# 7) Prefix-accept-once (compare sampled vs greedy)
# ============================================================
@torch.inference_mode()
def prefix_accept_once(ids: torch.Tensor,
                       branch: List[int],
                       *,
                       ban_ids: Optional[torch.Tensor] = BAN_IDS,
                       ban_eos_during_accept: bool = False) -> Tuple[torch.Tensor, int]:
    """
    Accept sampled tokens while they match the greedy prediction.
    On first mismatch, accept greedy token and STOP.
    Returns: (updated_ids, accepted_count)
    """
    cur = ids.clone()
    accepted = 0
    for i, t in enumerate(branch):
        g = greedy_next(cur, ban_eos=(ban_eos_during_accept and i == 0), ban_ids=ban_ids)
        if g == t:                         # match → accept sampled token
            cur = append_token(cur, t)
            accepted += 1
        else:                               # mismatch → accept greedy token and stop
            cur = append_token(cur, g)
            break
    return cur, accepted

In [72]:
# ============================================================
# 8) Ultra-simple Medusa-like loop (one branch per step)
# ============================================================
@torch.inference_mode()
def medusa_tiny(prompt: str,
                *,
                max_new_tokens: int = 30,
                span: int = 3,
                temperature: float = 0.9,
                top_p: float = 0.95,
                ban_eos_steps: int = 8,
                ban_ids: Optional[torch.Tensor] = BAN_IDS) -> str:
    """
    Minimal Medusa-like generation:
      loop:
        (a) propose a small branch of `span` tokens (sampling)
        (b) prefix-accept against greedy (1st mismatch → greedy, stop)
      stop when new tokens reach `max_new_tokens`
    """
    ids = encode(prompt)
    start_len = ids.shape[1]
    steps = math.ceil(max_new_tokens / max(1, span))
    remaining_ban = max(0, ban_eos_steps)

    for step in range(steps):
        branch = propose_branch(ids, span=span, temperature=temperature, top_p=top_p,
                                ban_ids=ban_ids, ban_eos_first_n=min(span, remaining_ban))
        ids, _ = prefix_accept_once(ids, branch, ban_ids=ban_ids,
                                    ban_eos_during_accept=(step == 0 and remaining_ban > 0))
        if ids.shape[1] - start_len >= max_new_tokens:
            break
        remaining_ban = max(0, remaining_ban - span)

    return decode(ids)



In [73]:
# ============================================================
# 9) Quick smoke tests (repr shows spaces clearly)
# ============================================================
if __name__ == "__main__":
    prompts = [
        "In a distant future, ",
        "Long ago, ",
    ]

    print("\n=== Greedy (baseline) ===")
    for p in prompts:
        print("PROMPT:", repr(p))
        print(repr(greedy_generate(p, max_new_tokens=30, ban_eos_steps=8)))
        print("-" * 60)

    print("\n=== Medusa-tiny (span=3) ===")
    for p in prompts:
        print("PROMPT:", repr(p))
        print(repr(medusa_tiny(p, max_new_tokens=30, span=3, temperature=0.9, top_p=0.95)))
        print("-" * 60)


=== Greedy (baseline) ===
PROMPT: 'In a distant future, '
'In a distant future, vernacular is a form of the word “vernacular”. It is a form of the word “vernacular”. It'
------------------------------------------------------------
PROMPT: 'Long ago, '
'Long ago, iced tea was a popular drink in the United States. It was also popular in the United States. It was also popular in the United States. It'
------------------------------------------------------------

=== Medusa-tiny (span=3) ===
PROMPT: 'In a distant future, '
'In a distant future, vernacular is a form of the word “vernacular”. It is'
------------------------------------------------------------
PROMPT: 'Long ago, '
'Long ago, iced tea was a popular drink in the United States. It was also popular'
------------------------------------------------------------
