In [484]:
#Medusa-lite flow : drafter → verifier → multi-branch prefix-accept

In [486]:
from dataclasses import dataclass
import torch

In [487]:
# Step 2) Device 선택

def pick_device():
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"     # 맥북이면 mps
    if torch.cuda.is_available():
        return "cuda"    # GPU 있으면 cuda
    return "cpu"         # 나머지는 cpu

DEVICE = pick_device()
print("✅ DEVICE =", DEVICE)

✅ DEVICE = mps


In [488]:
assert DEVICE in {"cpu", "cuda", "mps"}
print("OK")

OK


In [489]:
#Seed 고정

In [490]:
from dataclasses import dataclass
import torch, random

# Device 자동 선택
DEVICE = ("mps" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
          else ("cuda" if torch.cuda.is_available() else "cpu"))
print("✅ device:", DEVICE)

# 재현성(샘플링 안 쓰니까 큰 영향은 없지만 고정)
random.seed(42); torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)


✅ device: mps


In [491]:
#Config

In [492]:
from dataclasses import dataclass

@dataclass
class Cfg:
    DRAFTER_ID: str = "distilgpt2"
    VERIFIER_ID: str = "gpt2-medium"

    MAX_NEW_TOKENS: int = 30

    TEMPERATURE: float = 0.8   # 🔹 추가
    TOP_P: float = 0.9         # 🔹 추가

    REPETITION_PENALTY: float = 1.3
    NO_REPEAT_NGRAM: int = 5

    TOPK_BRANCH: int = 4
    DRAFT_SPAN: int = 3

    DEVICE: str = DEVICE
    DEBUG: bool = False

cfg = Cfg()

In [493]:
#loading Draft model (Tokenizer-> Model)

In [494]:
from transformers import AutoTokenizer, AutoModelForCausalLM

drafter_tok  = AutoTokenizer.from_pretrained(cfg.DRAFTER_ID)
verifier_tok = AutoTokenizer.from_pretrained(cfg.VERIFIER_ID)

# gpt2 계열은 eos/pad가 비어있는 경우 多 → 보정
if verifier_tok.eos_token_id is None:
    verifier_tok.eos_token = ""
if verifier_tok.pad_token_id is None:
    verifier_tok.pad_token = verifier_tok.eos_token

EOS_ID = verifier_tok.eos_token_id

drafter  = AutoModelForCausalLM.from_pretrained(cfg.DRAFTER_ID).to(cfg.DEVICE).eval()
verifier = AutoModelForCausalLM.from_pretrained(cfg.VERIFIER_ID).to(cfg.DEVICE).eval()

# 캐시 사용 활성화(기본 True 이지만 명시)
drafter.config.use_cache  = True
verifier.config.use_cache = True

print("✅ models ready:", cfg.DRAFTER_ID, "/", cfg.VERIFIER_ID)


✅ models ready: distilgpt2 / gpt2-medium


In [495]:
#Prompt & Context preparation# Prompt & Context
prompt = "In a distant future, a small crew of explorers discovers "

# drafter 토크나이저로 인코딩 + DEVICE 올리기
ctx = drafter_tok(prompt, return_tensors="pt").to(cfg.DEVICE)

# input_ids만 따로 꺼내기
input_ids = ctx["input_ids"]

print("context ok?", ctx is not None, "| shape:", input_ids.shape)

context ok? True | shape: torch.Size([1, 12])


In [496]:
#Draft 한 토큰 생성 함수

In [497]:
@torch.inference_mode()
def drafter_sample_first_tokens_basic(model, ids, k: int, temperature: float = 0.8):
    logits = model(ids).logits[:, -1, :]
    probs  = torch.softmax(logits / max(temperature, 1e-6), dim=-1)[0]
    k = min(k, probs.numel())
    picks = torch.multinomial(probs, num_samples=k, replacement=False)
    return [int(i) for i in picks]

In [498]:
tid = 1849
print("token str (repr):", repr(drafter_tok.decode([tid])))
print("gpt2 piece:", drafter_tok.convert_ids_to_tokens([tid])[0])
print("is space?", drafter_tok.decode([tid]).isspace())

token str (repr): '\xa0'
gpt2 piece: Âł
is space? True


In [499]:
#멀티-브랜치 Draft 함수

In [500]:
@torch.inference_mode()
def drafter_sample_first_tokens(model, ids, k: int, temperature=0.8, top_p=0.9) -> list[int]:
    logits = model(ids).logits[:, -1, :]
    probs  = torch.softmax(logits / max(temperature, 1e-6), dim=-1)[0]

    # nucleus(top-p) 샘플링
    if top_p is not None:
        sorted_p, sorted_ix = torch.sort(probs, descending=True)
        cumsum = torch.cumsum(sorted_p, dim=0)
        keep = cumsum <= top_p
        keep[0] = True
        pool_ix = sorted_ix[keep]
        pool_p  = probs[pool_ix] / probs[pool_ix].sum()
        picks   = torch.multinomial(pool_p, num_samples=min(k, pool_ix.numel()), replacement=False)
        return [int(pool_ix[i]) for i in picks]
    else:
        picks = torch.multinomial(probs, num_samples=k, replacement=False)
        return [int(i) for i in picks]

@torch.inference_mode()
def drafter_rollout(ids, first_tok: int, span: int) -> list[int]:
    cur = torch.cat([ids, torch.tensor([[first_tok]], device=ids.device)], dim=1)
    seq = [first_tok]
    for _ in range(span - 1):
        logits = drafter(cur).logits[:, -1, :]
        nxt = int(torch.argmax(logits, dim=-1)[0])
        seq.append(nxt)
        cur = torch.cat([cur, torch.tensor([[nxt]], device=cur.device)], dim=1)
    return seq

@torch.inference_mode()
def drafter_propose(ids, k: int, span: int, temperature=0.8, top_p=0.9) -> list[list[int]]:
    firsts = drafter_sample_first_tokens(drafter, ids, k, temperature, top_p)
    branches = [drafter_rollout(ids, f, span) for f in firsts]
    return branches
@torch.inference_mode()
def drafter_sample_first_tokens(model, ids, k: int, temperature=0.8, top_p=0.9) -> list[int]:
    logits = model(ids).logits[:, -1, :]
    probs  = torch.softmax(logits / max(temperature, 1e-6), dim=-1)[0]

    # nucleus(top-p) 샘플링
    if top_p is not None:
        sorted_p, sorted_ix = torch.sort(probs, descending=True)
        cumsum = torch.cumsum(sorted_p, dim=0)
        keep = cumsum <= top_p
        keep[0] = True
        pool_ix = sorted_ix[keep]
        pool_p  = probs[pool_ix] / probs[pool_ix].sum()
        picks   = torch.multinomial(pool_p, num_samples=min(k, pool_ix.numel()), replacement=False)
        return [int(pool_ix[i]) for i in picks]
    else:
        picks = torch.multinomial(probs, num_samples=k, replacement=False)
        return [int(i) for i in picks]

@torch.inference_mode()
def drafter_rollout(ids, first_tok: int, span: int) -> list[int]:
    cur = torch.cat([ids, torch.tensor([[first_tok]], device=ids.device)], dim=1)
    seq = [first_tok]
    for _ in range(span - 1):
        logits = drafter(cur).logits[:, -1, :]
        nxt = int(torch.argmax(logits, dim=-1)[0])
        seq.append(nxt)
        cur = torch.cat([cur, torch.tensor([[nxt]], device=cur.device)], dim=1)
    return seq

@torch.inference_mode()
def drafter_propose(ids, k: int, span: int, temperature=0.8, top_p=0.9) -> list[list[int]]:
    firsts = drafter_sample_first_tokens(drafter, ids, k, temperature, top_p)
    branches = [drafter_rollout(ids, f, span) for f in firsts]
    return branches


In [501]:
#Verifier: 한 토큰 예측(greedy)

In [502]:
@torch.inference_mode()
def verifier_next_token(ids) -> int:
    """verifier로 다음 토큰 id 하나 예측 (greedy)"""
    logits = verifier(ids).logits[:, -1, :]
    return int(torch.argmax(logits, dim=-1)[0])

In [503]:
def pretty_token(tokenizer, tid: int):
    """토큰 id를 여러 방식으로 표현"""
    s_decode = tokenizer.decode([tid], skip_special_tokens=False, clean_up_tokenization_spaces=False)
    s_token  = tokenizer.convert_ids_to_tokens([tid])[0]
    try:
        s_fixed = s_token.encode("latin1").decode("utf-8")
    except Exception:
        s_fixed = s_token
    return {
        "id": tid,
        "decode_repr": repr(s_decode),   # 사람이 안 보이는 공백도 확인
        "token_repr": repr(s_token),     # BPE 토큰 스트링
        "token_fixed": repr(s_fixed),    # 모지바케 보정 시도
        "codepoints": [hex(ord(c)) for c in s_decode],
        "bytes": list(s_decode.encode("utf-8")),
    }


In [504]:
@torch.inference_mode()
def next_human_token(ids, tokenizer, tries=10):
    """보이는 토큰이 나올 때까지 최대 tries번 예측"""
    cur = ids.clone()
    for _ in range(tries):
        tid = verifier_next_token(cur)
        s = tokenizer.decode([tid], skip_special_tokens=False, clean_up_tokenization_spaces=False)
        if any(ch.isprintable() and not ch.isspace() for ch in s):
            return tid, s
        cur = torch.cat([cur, torch.tensor([[tid]], device=ids.device)], dim=1)
    return tid, s

In [505]:
vid = verifier_next_token(input_ids)
info = pretty_token(verifier_tok, vid)

print("예측 토큰 정보:", info)

t2, s2 = next_human_token(input_ids, verifier_tok)
print("다음 보이는 토큰:", t2, repr(s2))


예측 토큰 정보: {'id': 488, 'decode_repr': "'ich'", 'token_repr': "'ich'", 'token_fixed': "'ich'", 'codepoints': ['0x69', '0x63', '0x68'], 'bytes': [105, 99, 104]}
다음 보이는 토큰: 488 'ich'


In [506]:
#Prefix-Accept (mismatch까지)

In [507]:
from typing import List, Tuple
import torch
@torch.inference_mode()
def accept_until_mismatch(context_ids, branch_tokens:List[int]) -> Tuple[torch.Tensor, List[int], bool]:
    ids = context_ids.clone()
    accepted = []
    mismatched = False
    for tid in branch_tokens:
        pred = verifier_next_token(ids)
        if pred == tid:
            ids = torch.cat([ids, torch.tensor([[tid]], device=ids.device)], dim=1)
            accepted.append(tid)
        else:
            ids = torch.cat([ids, torch.tensor([[pred]], device=ids.device)], dim=1)
            mismatched = True
            break
    return ids, accepted, mismatched

In [508]:
# 방금 만든 브랜치들 중 첫 번째를 검사해보기
new_ids, accepted, mism = accept_until_mismatch(input_ids, b[0])
print('accepted len:', len(accepted), '| mismatched?', mism)
print('새 길이:', new_ids.shape[1], '| 추가된 토큰 수:', new_ids.shape[1] - input_ids.shape[1])

accepted len: 0 | mismatched? True
새 길이: 13 | 추가된 토큰 수: 1


In [509]:
#Orchestrator (medusa_generate)

In [510]:
import math
import torch

In [511]:
#Branch 점수 함수

In [512]:
def score_branch(accepted, mismatched):
    # prefix-accept된 토큰 수 - mismatch 패널티
    return len(accepted) - (1 if mismatched else 0)

# ✔️ 체크
print(score_branch([1,2,3], False))  # 3
print(score_branch([1,2], True))     # 1 (2 - 1)

3
1


In [513]:
#Prompt → 토큰화

In [514]:
@torch.inference_mode()
def encode_prompt(prompt: str):
    ctx = drafter_tok(prompt, return_tensors="pt").to(cfg.DEVICE)
    return ctx["input_ids"]

In [515]:
# ✔️ 체크
ids = encode_prompt("In a distant future, ")
print("ids.shape:", ids.shape)

ids.shape: torch.Size([1, 6])


In [516]:
#한 스텝 수행(multi-branch→검증→최고 점수 채택)

In [517]:
@torch.inference_mode()
def medusa_step(ids, topk_branch: int, draft_span: int, temperature: float):
    # ↑ more diverse branches: stronger sampling
    branches = drafter_propose(
        ids,
        topk_branch,
        draft_span,
        temperature=max(0.9, float(temperature)),  # ensure ≥ 0.9
        top_p=0.95
    )

    best_score = -10**9
    best_ids = None
    for br in branches:
        new_ids, accepted, mism = accept_until_mismatch(ids, br)
        s = score_branch(accepted, mism)
        if s > best_score:
            best_score, best_ids = s, new_ids

    return best_ids


In [518]:
ids2 = medusa_step(ids, cfg.TOPK_BRANCH, cfg.DRAFT_SPAN, cfg.TEMPERATURE)
print("before:", ids.shape[1], "→ after:", ids2.shape[1])

before: 6 → after: 8


In [519]:
#Orchestrator

In [520]:
@torch.inference_mode()
def medusa_generate(prompt:str,
                    max_new_tokens:int=None,
                    topk_branch:int=None,
                    draft_span:int=None,
                    temperature:float=None) -> str:
    if max_new_tokens is None: max_new_tokens = cfg.MAX_NEW_TOKENS
    if topk_branch   is None: topk_branch   = cfg.TOPK_BRANCH
    if draft_span    is None: draft_span    = cfg.DRAFT_SPAN
    if temperature   is None: temperature   = cfg.TEMPERATURE

    ids = encode_prompt(prompt)
    start_len = ids.shape[1]

    steps = math.ceil(max_new_tokens / draft_span)
    for _ in range(steps):
        ids = medusa_step(ids, topk_branch, draft_span, temperature)
        if ids.shape[1] - start_len >= max_new_tokens:
            break

    return drafter_tok.decode(ids[0], skip_special_tokens=True)


In [521]:
out = medusa_generate("In a distant future, ", 40)
print(out)

In a distant future,  the world is ruled by a dictator who is obsessed with the idea of controlling the world's resources.  He


In [522]:
# Step 13) Greedy Baseline

In [523]:
@torch.inference_mode()
def greedy_generate(prompt: str, max_new_tokens: int = None) -> str:
    if max_new_tokens is None:
        max_new_tokens = cfg.MAX_NEW_TOKENS

    # verifier 기준으로 생성 (비교군)
    ctx = verifier_tok(prompt, return_tensors="pt").to(cfg.DEVICE)
    ids = ctx["input_ids"]

    for _ in range(max_new_tokens):
        logits = verifier(ids).logits[:, -1, :]
        nxt = int(torch.argmax(logits, dim=-1)[0])  # greedy
        ids = torch.cat([ids, torch.tensor([[nxt]], device=ids.device)], dim=1)

    return verifier_tok.decode(ids[0], skip_special_tokens=True)

In [524]:
txt = greedy_generate("In a distant future, ", 40)
print(txt)

In a distant future,  the world is ruled by a dictator who is obsessed with the idea of controlling the world's resources.  He wants to control the world's resources so that he can rule the world. 


In [525]:
#1) A/B 속도·텍스트 비교 셀

In [526]:
import time

def time_it(fn, *args, **kwargs):
    t0 = time.perf_counter()
    out = fn(*args, **kwargs)
    return out, time.perf_counter() - t0

g_txt, g_t = time_it(greedy_generate, "In a distant future, ", 80)
m_txt, m_t = time_it(medusa_generate, "In a distant future, ", 80)

print("⏱ greedy:", round(g_t, 3), "s")
print("⏱ medusa:", round(m_t, 3), "s")
print("\n--- greedy ---\n", g_txt[:400])
print("\n--- medusa ---\n", m_txt[:400])


⏱ greedy: 4.241 s
⏱ medusa: 7.441 s

--- greedy ---
 In a distant future,  the world is ruled by a dictator who is obsessed with the idea of controlling the world's resources.  He wants to control the world's resources so that he can rule the world.  He wants to control the world's resources so that he can rule the world.  He wants to control the world's resources so that he can rule the world.  He wants to

--- medusa ---
 In a distant future,  the world is ruled by a dictator who is obsessed with the idea of controlling the world's resources.  He wants to control the world's resources so that he can rule the world.  He wants to control the world's resources so that he can rule the world.  He wants to control
