In [1]:
# medusa_super_tiny.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "distilgpt2"
device = ("mps" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
          else ("cuda" if torch.cuda.is_available() else "cpu"))

tok = AutoTokenizer.from_pretrained(MODEL_ID)
if tok.eos_token_id is None: tok.eos_token = ""
if tok.pad_token_id is None: tok.pad_token = tok.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(device).eval()
model.config.use_cache = True
EOS_ID = tok.eos_token_id

def encode(s): return tok(s, return_tensors="pt").to(device)["input_ids"]
def decode(ids): return tok.decode(ids[0], skip_special_tokens=True)

@torch.inference_mode()
def last_logits(ids):  # [V]
    return model(ids).logits[0, -1, :]

@torch.inference_mode()
def greedy_next(ids):  # argmax만
    return int(torch.argmax(last_logits(ids)).item())

@torch.inference_mode()
def softmax_temp(logits, t=0.9):
    t = max(float(t), 1e-6)
    return torch.softmax(logits / t, dim=-1)

@torch.inference_mode()
def top_p_indices(probs, top_p=0.95):
    V = probs.numel()
    if top_p is None or top_p >= 1: return torch.arange(V, device=probs.device)
    sp, sx = torch.sort(probs, descending=True)
    csum = torch.cumsum(sp, dim=0)
    keep = csum <= top_p
    keep[0] = True
    return sx[keep]

@torch.inference_mode()
def sample_next(ids, temperature=0.9, top_p=0.95):
    probs = softmax_temp(last_logits(ids), temperature)
    pool_ix = top_p_indices(probs, top_p)
    pool = probs[pool_ix]
    pool = pool / pool.sum()
    pick_local = torch.multinomial(pool, 1)[0].item()
    return int(pool_ix[pick_local].item())

@torch.inference_mode()
def propose_branch(ids, span=3, temperature=0.9, top_p=0.95):
    cur = ids.clone()
    out = []
    for _ in range(span):
        t = sample_next(cur, temperature, top_p)
        out.append(t)
        cur = torch.cat([cur, torch.tensor([[t]], device=cur.device)], dim=1)
    return out  # list[int]

@torch.inference_mode()
def prefix_accept_once(ids, branch):
    cur = ids.clone()
    accepted = 0
    for t in branch:
        g = greedy_next(cur)
        if g == t:  # 일치 → 수락
            cur = torch.cat([cur, torch.tensor([[t]], device=cur.device)], dim=1)
            accepted += 1
        else:       # 불일치 → 그리디 채택 후 중단
            cur = torch.cat([cur, torch.tensor([[g]], device=cur.device)], dim=1)
            break
    return cur, accepted

@torch.inference_mode()
def medusa_tiny(prompt, max_new_tokens=30, span=3, temperature=0.9, top_p=0.95):
    ids = encode(prompt)
    start = ids.shape[1]
    # 목표 길이 채울 때까지 반복 (진짜 최소형)
    while ids.shape[1] - start < max_new_tokens:
        branch = propose_branch(ids, span=span, temperature=temperature, top_p=top_p)
        ids, _ = prefix_accept_once(ids, branch)
    return decode(ids)

if __name__ == "__main__":
    prompt = "In a distant future, "
    # 1) Greedy만
    ids = encode(prompt)
    for _ in range(30):
        nxt = greedy_next(ids)
        if EOS_ID is not None and nxt == EOS_ID: break
        ids = torch.cat([ids, torch.tensor([[nxt]], device=ids.device)], dim=1)
    print("=== Greedy ===")
    print(decode(ids), "\n")

    # 2) Medusa-tiny (초간단)
    print("=== Medusa-tiny ===")
    print(medusa_tiny(prompt, max_new_tokens=30, span=3, temperature=0.9, top_p=0.95))


=== Greedy ===
In a distant future,   the world is a place where the world is a place where the world is a place where the world is a place where the world is a place 

=== Medusa-tiny ===
In a distant future,   the world is a place where the world is a place where the world is a place where the world is a place where the world is a place where the
