In [37]:
#Medusa-lite flow : drafter → verifier → multi-branch prefix-accept

In [38]:
import torch, transformers

In [39]:
# Step 2) Device 선택
import torch

def pick_device():
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"     # 🖥️ 맥북이면 mps
    if torch.cuda.is_available():
        return "cuda"    # 💻 GPU 있으면 cuda
    return "cpu"         # 나머지는 cpu

DEVICE = pick_device()
print("✅ DEVICE =", DEVICE)

✅ DEVICE = mps


In [40]:
assert DEVICE in {"cpu", "cuda", "mps"}
print("OK")

OK


In [41]:
#Seed 고정

In [42]:
import random, torch
def set_seed(seed:int=42):
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)
print('✅ seed set')

✅ seed set


In [43]:
#Config

In [44]:
from dataclasses import dataclass
@dataclass
class Cfg:
    DRAFTER_ID: str = 'distilgpt2'     # 소형 drafter
    VERIFIER_ID: str = 'gpt2'   # 검증용
    TOPK_BRANCH: int = 3
    DRAFT_SPAN: int = 3
    MAX_NEW_TOKENS: int = 80
    TEMPERATURE: float = 0.7
    DEVICE: str = DEVICE

cfg = Cfg()
cfg

Cfg(DRAFTER_ID='distilgpt2', VERIFIER_ID='gpt2', TOPK_BRANCH=3, DRAFT_SPAN=3, MAX_NEW_TOKENS=80, TEMPERATURE=0.7, DEVICE='mps')

In [45]:
#loading Draft model (Tokenizer-> Model)

In [46]:
from transformers import AutoTokenizer, AutoModelForCausalLM

drafter_tok = AutoTokenizer.from_pretrained(cfg.DRAFTER_ID)
drafter     = AutoModelForCausalLM.from_pretrained(cfg.DRAFTER_ID).to(cfg.DEVICE)
drafter.eval()

print('🔸 drafter ready')

🔸 drafter ready


In [47]:
ok = (drafter_tok is not None) and (drafter is not None)
print('drafter loaded? ->', ok)

drafter loaded? -> True


In [48]:
#Verifier load

In [49]:
from transformers import AutoTokenizer, AutoModelForCausalLM

verifier_tok = AutoTokenizer.from_pretrained(cfg.VERIFIER_ID)
verifier     = AutoModelForCausalLM.from_pretrained(cfg.VERIFIER_ID).to(cfg.DEVICE)
verifier.eval()

print('🔸 verifier ready')

🔸 verifier ready


In [50]:
ok = (verifier_tok is not None) and (verifier is not None)
print('verifier loaded? ->', ok)

verifier loaded? -> True


In [51]:
#Prompt & Context preparation# Prompt & Context
prompt = "In a distant future, a small crew of explorers discovers "

# drafter 토크나이저로 인코딩 + DEVICE 올리기
ctx = drafter_tok(prompt, return_tensors="pt").to(cfg.DEVICE)

# input_ids만 따로 꺼내기
input_ids = ctx["input_ids"]

print("context ok?", ctx is not None, "| shape:", input_ids.shape)

context ok? True | shape: torch.Size([1, 12])


In [52]:
#Draft 한 토큰 생성 함수

In [53]:
import torch
@torch.inference_mode()
def draft_one_token(model, ids, temperature:float=1.0):
    # 1) 마지막 토큰 위치의 logits 꺼내기
    logits = model(ids).logits[:, -1, :]

    # 2) softmax → 확률 분포
    probs = torch.softmax(logits / temperature, dim=-1)

    # 3) 확률 분포에서 한 개 토큰 뽑기
    next_id = torch.multinomial(probs, num_samples=1)

    return next_id

In [54]:
 # 실제로 실행해보기
sample_id = draft_one_token(drafter, input_ids, 0.8)
print("샘플링된 토큰 ID:", sample_id)
print("토큰 문자열:", drafter_tok.decode(sample_id[0]))

샘플링된 토큰 ID: tensor([[2575]], device='mps:0')
토큰 문자열: urch


In [55]:
#멀티-브랜치 Draft 함수

In [56]:
# topk개의 브랜치, 각 브랜치마다 span 길이만큼 draft_one_token 반복
import torch

@torch.inference_mode()
def drafter_propose(ids, topk:int, span:int, temperature:float):
    branches = []
    for _ in range(topk):
        cur = ids.clone()
        branch = []
        for __ in range(span):
            nxt = draft_one_token(drafter, cur, temperature)
            cur = torch.cat([cur, nxt], dim=1)
            branch.append(int(nxt[0,0]))
        branches.append(branch)
    return branches

In [57]:
b = drafter_propose(input_ids, cfg.TOPK_BRANCH, cfg.DRAFT_SPAN, cfg.TEMPERATURE)
print(len(b), 'branches; span len:', len(b[0]) if b else None)
print('예시 브랜치 1:', b[0])

3 branches; span len: 3
예시 브랜치 1: [1849, 286, 262]


In [58]:
#Verifier: 한 토큰 예측(greedy)

In [59]:
@torch.inference_mode()
def verifier_next_token(ids) -> int:
    logits = verifier(ids).logits[:, -1, :]
    pred = int(torch.argmax(logits, dim=-1)[0])
    return pred

In [60]:
vid = verifier_next_token(input_ids)
print('pred id =', vid, '| token =', verifier_tok.decode([vid]))

pred id = 1849 | token =  


In [61]:
#Prefix-Accept (mismatch까지)

In [62]:
from typing import List, Tuple
import torch
@torch.inference_mode()
def accept_until_mismatch(context_ids, branch_tokens:List[int]) -> Tuple[torch.Tensor, List[int], bool]:
    ids = context_ids.clone()
    accepted = []
    mismatched = False
    for tid in branch_tokens:
        pred = verifier_next_token(ids)
        if pred == tid:
            ids = torch.cat([ids, torch.tensor([[tid]], device=ids.device)], dim=1)
            accepted.append(tid)
        else:
            ids = torch.cat([ids, torch.tensor([[pred]], device=ids.device)], dim=1)
            mismatched = True
            break
    return ids, accepted, mismatched

In [63]:
# 방금 만든 브랜치들 중 첫 번째를 검사해보기
new_ids, accepted, mism = accept_until_mismatch(input_ids, b[0])
print('accepted len:', len(accepted), '| mismatched?', mism)
print('새 길이:', new_ids.shape[1], '| 추가된 토큰 수:', new_ids.shape[1] - input_ids.shape[1])

accepted len: 1 | mismatched? True
새 길이: 14 | 추가된 토큰 수: 2


In [64]:
#Orchestrator (medusa_generate)

In [65]:
import math
import torch

In [66]:
#Branch 점수 함수

In [67]:
def score_branch(accepted, mismatched):
    # prefix-accept된 토큰 수 - mismatch 패널티
    return len(accepted) - (1 if mismatched else 0)

# ✔️ 체크
print(score_branch([1,2,3], False))  # 3
print(score_branch([1,2], True))     # 1 (2 - 1)

3
1


In [68]:
#Prompt → 토큰화

In [69]:
@torch.inference_mode()
def encode_prompt(prompt: str):
    ctx = drafter_tok(prompt, return_tensors="pt").to(cfg.DEVICE)
    return ctx["input_ids"]

In [70]:
# ✔️ 체크
ids = encode_prompt("In a distant future, ")
print("ids.shape:", ids.shape)

ids.shape: torch.Size([1, 6])


In [71]:
#한 스텝 수행(multi-branch→검증→최고 점수 채택)

In [72]:
@torch.inference_mode()
def medusa_step(ids, topk_branch:int, draft_span:int, temperature:float):
    branches = drafter_propose(ids, topk_branch, draft_span, temperature)
    best_score = -10**9
    best_ids = None
    for br in branches:
        new_ids, accepted, mism = accept_until_mismatch(ids, br)
        s = score_branch(accepted, mism)
        if s > best_score:
            best_score, best_ids = s, new_ids
    return best_ids

In [73]:
ids2 = medusa_step(ids, cfg.TOPK_BRANCH, cfg.DRAFT_SPAN, cfg.TEMPERATURE)
print("before:", ids.shape[1], "→ after:", ids2.shape[1])

before: 6 → after: 7


In [74]:
#Orchestrator

In [75]:
@torch.inference_mode()
def medusa_generate(prompt:str,
                    max_new_tokens:int=None,
                    topk_branch:int=None,
                    draft_span:int=None,
                    temperature:float=None) -> str:
    if max_new_tokens is None: max_new_tokens = cfg.MAX_NEW_TOKENS
    if topk_branch   is None: topk_branch   = cfg.TOPK_BRANCH
    if draft_span    is None: draft_span    = cfg.DRAFT_SPAN
    if temperature   is None: temperature   = cfg.TEMPERATURE

    ids = encode_prompt(prompt)
    start_len = ids.shape[1]

    steps = math.ceil(max_new_tokens / draft_span)
    for _ in range(steps):
        ids = medusa_step(ids, topk_branch, draft_span, temperature)
        if ids.shape[1] - start_len >= max_new_tokens:
            break

    return drafter_tok.decode(ids[0], skip_special_tokens=True)


In [76]:
out = medusa_generate("In a distant future, ", 40)
print(out)

In a distant future,  the world would be a better place.
The world would be a better place.
The world would be a better place.
The world would be a better place.
The


In [77]:
# Step 13) Greedy Baseline

In [78]:
@torch.inference_mode()
def greedy_generate(prompt: str, max_new_tokens: int = None) -> str:
    if max_new_tokens is None:
        max_new_tokens = cfg.MAX_NEW_TOKENS

    # verifier 기준으로 생성 (비교군)
    ctx = verifier_tok(prompt, return_tensors="pt").to(cfg.DEVICE)
    ids = ctx["input_ids"]

    for _ in range(max_new_tokens):
        logits = verifier(ids).logits[:, -1, :]
        nxt = int(torch.argmax(logits, dim=-1)[0])  # greedy
        ids = torch.cat([ids, torch.tensor([[nxt]], device=ids.device)], dim=1)

    return verifier_tok.decode(ids[0], skip_special_tokens=True)

In [79]:
txt = greedy_generate("In a distant future, ", 40)
print(txt)

In a distant future,  the world would be a better place.
The world would be a better place.
The world would be a better place.
The world would be a better place.
The world would


In [80]:
#1) A/B 속도·텍스트 비교 셀

In [81]:
import time

def time_it(fn, *args, **kwargs):
    t0 = time.perf_counter()
    out = fn(*args, **kwargs)
    return out, time.perf_counter() - t0

g_txt, g_t = time_it(greedy_generate, "In a distant future, ", 80)
m_txt, m_t = time_it(medusa_generate, "In a distant future, ", 80)

print("⏱ greedy:", round(g_t, 3), "s")
print("⏱ medusa:", round(m_t, 3), "s")
print("\n--- greedy ---\n", g_txt[:400])
print("\n--- medusa ---\n", m_txt[:400])


⏱ greedy: 5.022 s
⏱ medusa: 7.378 s

--- greedy ---
 In a distant future,  the world would be a better place.
The world would be a better place.
The world would be a better place.
The world would be a better place.
The world would be a better place.
The world would be a better place.
The world would be a better place.
The world would be a better place.
The world would be a better place

--- medusa ---
 In a distant future,  the world would be a better place.
The world would be a better place.
The world would be a better place.
The world would be a better place.
The world would be a better place.
The world would be a better place.
The world would be a better place.
The world would be a better place.
The world would be
