In [111]:
#Step 1. Device & Seed

In [112]:
import torch, random, os

# Pick device (Apple Silicon → mps, else cuda if available, else cpu)
DEVICE = ("mps" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
          else ("cuda" if torch.cuda.is_available() else "cpu"))
print("✅ device:", DEVICE)

# Reproducibility for sampling (샘플링 재현성)
SEED = 42
random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

✅ device: mps


In [113]:
#Step 2. Config

In [114]:
from dataclasses import dataclass

@dataclass
class Cfg:
    # Models
    DRAFTER_ID: str = "distilgpt2"  # small, fast
    VERIFIER_ID: str = "gpt2"       # baseline verifier

    # Decoding defaults (Sampling)
    MAX_NEW_TOKENS: int = 30        # 짧고 굳게
    TEMPERATURE: float = 0.8        # 자연스러움↑
    TOP_P: float = 0.9              # 자연스러움↑
    REPETITION_PENALTY: float = 1.2 # 반복 억제 강화
    NO_REPEAT_NGRAM: int = 4        # 반복 억제 강화 (simple check in medusa-lite)

    # Medusa-lite branching
    TOPK_BRANCH: int = 3            # K branches
    DRAFT_SPAN: int = 3             # span length per branch

    # Misc
    DEVICE: str = DEVICE
    DEBUG: bool = False             # True면 브랜치/프리픽스 로그 출력

cfg = Cfg()
cfg


Cfg(DRAFTER_ID='distilgpt2', VERIFIER_ID='gpt2', MAX_NEW_TOKENS=30, TEMPERATURE=0.8, TOP_P=0.9, REPETITION_PENALTY=1.2, NO_REPEAT_NGRAM=4, TOPK_BRANCH=3, DRAFT_SPAN=3, DEVICE='mps', DEBUG=False)

In [115]:
#Models & Tokenizers

In [116]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load tokenizers
drafter_tok  = AutoTokenizer.from_pretrained(cfg.DRAFTER_ID)
verifier_tok = AutoTokenizer.from_pretrained(cfg.VERIFIER_ID)

# Ensure EOS/PAD (gpt2는 pad가 없는 경우 多)
def ensure_eos_pad(tokenizer):
    if tokenizer.eos_token_id is None:
        tokenizer.eos_token = ""
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer.eos_token_id, tokenizer.pad_token_id

EOS_ID, PAD_ID = ensure_eos_pad(verifier_tok)

# Load models
drafter  = AutoModelForCausalLM.from_pretrained(cfg.DRAFTER_ID).to(cfg.DEVICE).eval()
verifier = AutoModelForCausalLM.from_pretrained(cfg.VERIFIER_ID).to(cfg.DEVICE).eval()

print("✅ models ready:", cfg.DRAFTER_ID, "/", cfg.VERIFIER_ID, "| EOS:", EOS_ID, "PAD:", PAD_ID)


✅ models ready: distilgpt2 / gpt2 | EOS: 50256 PAD: 50256


In [117]:
#. Baseline: Greedy Generator

In [118]:
@torch.inference_mode()
def greedy_generate(prompt: str, max_new_tokens: int | None = None) -> str:
    if max_new_tokens is None:
        max_new_tokens = cfg.MAX_NEW_TOKENS

    ctx = verifier_tok(prompt, return_tensors="pt").to(cfg.DEVICE)
    ids = ctx["input_ids"]

    for _ in range(max_new_tokens):
        logits = verifier(ids).logits[:, -1, :]
        nxt = int(torch.argmax(logits, dim=-1)[0])
        if nxt == EOS_ID:
            break
        ids = torch.cat([ids, torch.tensor([[nxt]], device=ids.device)], dim=1)

        # simple early stop on punctuation (늘어짐 방지)
        tail = verifier_tok.decode(ids[0][-12:], skip_special_tokens=True).strip()
        if tail.endswith((".", "!", "?")):
            break

    return verifier_tok.decode(ids[0], skip_special_tokens=True)


In [119]:
#Baseline: Sampling Generator

In [120]:
@torch.inference_mode()
def sample_generate(
    prompt: str,
    max_new_tokens: int | None = None,
    *,
    temperature: float | None = None,
    top_p: float | None = None,
    repetition_penalty: float | None = None,
    no_repeat_ngram_size: int | None = None,
) -> str:
    if max_new_tokens is None:
        max_new_tokens = cfg.MAX_NEW_TOKENS

    temperature = cfg.TEMPERATURE if temperature is None else temperature
    top_p = cfg.TOP_P if top_p is None else top_p
    repetition_penalty = cfg.REPETITION_PENALTY if repetition_penalty is None else repetition_penalty
    no_repeat_ngram_size = cfg.NO_REPEAT_NGRAM if no_repeat_ngram_size is None else no_repeat_ngram_size

    inputs = verifier_tok(prompt, return_tensors="pt").to(cfg.DEVICE)
    out_ids = verifier.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        eos_token_id=EOS_ID,
        pad_token_id=PAD_ID,
    )[0]

    return verifier_tok.decode(out_ids, skip_special_tokens=True)


In [121]:
#Drafter/Verifier Helper 함수

In [122]:
@torch.inference_mode()
def verifier_next_argmax(ids: torch.Tensor) -> int:
    """Get verifier's next-token argmax."""
    logits = verifier(ids).logits[:, -1, :]
    return int(torch.argmax(logits, dim=-1)[0])

@torch.inference_mode()
def drafter_topk_first_tokens(ids: torch.Tensor, k: int) -> list[int]:
    """Get top-k next tokens from drafter."""
    logits = drafter(ids).logits[:, -1, :]
    topk = torch.topk(logits, k=k, dim=-1).indices[0].tolist()
    return topk

@torch.inference_mode()
def drafter_rollout(ids: torch.Tensor, first_token: int, span: int) -> list[int]:
    """
    Fix first_token, then greedy rollout to get total 'span' tokens from the drafter.
    """
    cur = torch.cat([ids, torch.tensor([[first_token]], device=ids.device)], dim=1)
    drafted = [first_token]
    for _ in range(span - 1):
        logits = drafter(cur).logits[:, -1, :]
        nxt = int(torch.argmax(logits, dim=-1)[0])
        drafted.append(nxt)
        cur = torch.cat([cur, torch.tensor([[nxt]], device=cur.device)], dim=1)
    return drafted


In [123]:
#Medusa-lite Core (drafter → verifier → multi-branch prefix-accept)

In [124]:
@torch.inference_mode()
def medusa_lite_generate(prompt: str, max_new_tokens: int | None = None) -> str:
    """
    Loop:
      1) drafter: propose K branches (each of span M)
      2) verifier: prefix-accept; on first mismatch, correct with verifier token
      3) commit the best (longest accepted prefix; full if perfect match, else prefix+1 correction)
      4) stop by EOS/punctuation/max tokens
    """
    if max_new_tokens is None:
        max_new_tokens = cfg.MAX_NEW_TOKENS

    ctx = verifier_tok(prompt, return_tensors="pt").to(cfg.DEVICE)
    ids = ctx["input_ids"]
    committed = 0

    while committed < max_new_tokens:
        # 1) branches
        topk = drafter_topk_first_tokens(ids, cfg.TOPK_BRANCH)
        branches = [drafter_rollout(ids, t0, cfg.DRAFT_SPAN) for t0 in topk]

        # 2) evaluate by verifier (prefix-accept)
        best_prefix_len = -1
        best_branch = None
        best_mismatch_vtok = None

        for b_idx, cand in enumerate(branches):
            cur_ids = ids
            prefix_len = 0
            mismatch_vtok = None

            for j, tok in enumerate(cand):
                v_next = verifier_next_argmax(cur_ids)

                if v_next == EOS_ID:
                    mismatch_vtok = EOS_ID
                    break

                if v_next == tok:
                    prefix_len += 1
                    cur_ids = torch.cat([cur_ids, torch.tensor([[tok]], device=cur_ids.device)], dim=1)
                else:
                    mismatch_vtok = v_next
                    break

            if cfg.DEBUG:
                print(f"[branch {b_idx}] cand={cand[:6]}... prefix_len={prefix_len} mismatch={mismatch_vtok}")

            if prefix_len > best_prefix_len:
                best_prefix_len = prefix_len
                best_branch = cand
                best_mismatch_vtok = mismatch_vtok

            if prefix_len == len(cand):  # perfect match early-exit
                break

        # 3) commit
        if best_prefix_len <= 0:
            # nothing matched from the start → commit verifier's token
            vtok = verifier_next_argmax(ids)
            if cfg.NO_REPEAT_NGRAM and _forms_repeated_ngram(ids, vtok, cfg.NO_REPEAT_NGRAM):
                # if repetition would happen, just skip this token (rare case)
                if cfg.DEBUG:
                    print("[skip] repetition would form; skipping one step")
            else:
                ids = torch.cat([ids, torch.tensor([[vtok]], device=ids.device)], dim=1)
                committed += 1
                if vtok == EOS_ID:
                    break
        else:
            # perfect match → commit all; partial match → commit prefix + correction token
            commit_seq = (
                best_branch
                if best_prefix_len == len(best_branch)
                else best_branch[:best_prefix_len] + ([best_mismatch_vtok] if best_mismatch_vtok is not None else [])
            )

            # optional: n-gram repetition guard (simple)
            filtered = []
            for t in commit_seq:
                if cfg.NO_REPEAT_NGRAM and _forms_repeated_ngram(ids, t, cfg.NO_REPEAT_NGRAM):
                    if cfg.DEBUG:
                        print(f"[filter] dropping tok {t} due to n-gram repeat")
                    continue
                filtered.append(t)
            commit_seq = filtered or commit_seq[:1]  # ensure progress

            ids = torch.cat([ids, torch.tensor([commit_seq], device=ids.device)], dim=1)
            committed += len(commit_seq)

            if commit_seq and commit_seq[-1] == EOS_ID:
                break

        # punctuation-based early stop
        tail = verifier_tok.decode(ids[0][-12:], skip_special_tokens=True).strip()
        if tail.endswith((".", "!", "?")):
            break

    return verifier_tok.decode(ids[0], skip_special_tokens=True)


In [None]:
prompt = "In the distant future, "

print("=== sampling ===")
print(sample_generate(prompt))

print("\n=== greedy (baseline) ===")
print(greedy_generate(prompt))

print("\n=== medusa-lite ===")
print(medusa_lite_generate(prompt))


=== sampling ===


In [69]:
#Medusa-lite flow : drafter → verifier → multi-branch prefix-accept

In [70]:
#Step 1 : 설치 & 버전 확인
#Step 2 : Device 선택 (mps / cuda / cpu 중 하나)
#Step 3 : Seed 고정
#Step 4 : Config 작성
#Step 5~ : Drafter / Verifier 로드 → Prompt 준비 → Draft/Verify 함수들 …

In [71]:
import torch, transformers

In [72]:
# Step 2) Device 선택
import torch

def pick_device():
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"     # 🖥️ 맥북이면 mps
    if torch.cuda.is_available():
        return "cuda"    # 💻 GPU 있으면 cuda
    return "cpu"         # 나머지는 cpu

DEVICE = pick_device()
print("✅ DEVICE =", DEVICE)

✅ DEVICE = mps


In [73]:
assert DEVICE in {"cpu", "cuda", "mps"}
print("OK")

OK


In [74]:
#Seed 고정

In [75]:
import random, torch
def set_seed(seed:int=42):
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)
print('✅ seed set')

✅ seed set


In [76]:
#Config

In [77]:
from dataclasses import dataclass

@dataclass
class Cfg:
    DRAFTER_ID: str = "distilgpt2"
    VERIFIER_ID: str = "gpt2"
    TOPK_BRANCH: int = 3
    DRAFT_SPAN: int = 3
    MAX_NEW_TOKENS: int = 30     # 40 → 30으로
    TEMPERATURE: float = 0.8
    TOP_P: float = 0.9
    REPETITION_PENALTY: float = 1.2
    NO_REPEAT_NGRAM: int = 4
    DEVICE: str = DEVICE

cfg = Cfg()

In [78]:
from transformers import AutoTokenizer, AutoModelForCausalLM

drafter_tok = AutoTokenizer.from_pretrained(cfg.DRAFTER_ID)
drafter     = AutoModelForCausalLM.from_pretrained(cfg.DRAFTER_ID).to(cfg.DEVICE)
drafter.eval()

print('🔸 drafter ready')

🔸 drafter ready


In [79]:
ok = (drafter_tok is not None) and (drafter is not None)
print('drafter loaded? ->', ok)

drafter loaded? -> True


In [80]:
from transformers import AutoTokenizer, AutoModelForCausalLM

verifier_tok = AutoTokenizer.from_pretrained(cfg.VERIFIER_ID)
verifier     = AutoModelForCausalLM.from_pretrained(cfg.VERIFIER_ID).to(cfg.DEVICE)
verifier.eval()

print('🔸 verifier ready')

🔸 verifier ready


In [81]:
ok = (verifier_tok is not None) and (verifier is not None)
print('verifier loaded? ->', ok)

verifier loaded? -> True


In [82]:
#Prompt & Context preparation# Prompt & Context
prompt = "In a distant future, a small crew of explorers discovers "

# drafter 토크나이저로 인코딩 + DEVICE 올리기
ctx = drafter_tok(prompt, return_tensors="pt").to(cfg.DEVICE)

# input_ids만 따로 꺼내기
input_ids = ctx["input_ids"]

print("context ok?", ctx is not None, "| shape:", input_ids.shape)

context ok? True | shape: torch.Size([1, 12])


In [83]:
#Draft 한 토큰 생성 함수

In [84]:
import torch
@torch.inference_mode()
def draft_one_token(model, ids, temperature:float=1.0):
    # 1) 마지막 토큰 위치의 logits 꺼내기
    logits = model(ids).logits[:, -1, :]

    # 2) softmax → 확률 분포
    probs = torch.softmax(logits / temperature, dim=-1)

    # 3) 확률 분포에서 한 개 토큰 뽑기
    next_id = torch.multinomial(probs, num_samples=1)

    return next_id

In [85]:
 # 실제로 실행해보기
sample_id = draft_one_token(drafter, input_ids, 0.8)
print("샘플링된 토큰 ID:", sample_id)
print("토큰 문자열:", drafter_tok.decode(sample_id[0]))

샘플링된 토큰 ID: tensor([[2575]], device='mps:0')
토큰 문자열: urch


In [86]:
#멀티-브랜치 Draft 함수

In [87]:
import torch

@torch.inference_mode()
def drafter_propose(ids, topk:int, span:int, temperature:float):
    branches = []
    for _ in range(topk):
        cur = ids.clone()
        branch = []
        for __ in range(span):
            nxt = draft_one_token(drafter, cur, temperature)
            cur = torch.cat([cur, nxt], dim=1)
            branch.append(int(nxt[0,0]))
        branches.append(branch)
    return branches

In [88]:
b = drafter_propose(input_ids, cfg.TOPK_BRANCH, cfg.DRAFT_SPAN, cfg.TEMPERATURE)
print(len(b), 'branches; span len:', len(b[0]) if b else None)
print('예시 브랜치 1:', b[0])

3 branches; span len: 3
예시 브랜치 1: [1849, 286, 262]


In [89]:
#Verifier: 한 토큰 예측(greedy)

In [90]:
@torch.inference_mode()
def verifier_next_token(ids) -> int:
    logits = verifier(ids).logits[:, -1, :]
    pred = int(torch.argmax(logits, dim=-1)[0])
    return pred

In [91]:
vid = verifier_next_token(input_ids)
print('pred id =', vid, '| token =', verifier_tok.decode([vid]))

pred id = 1849 | token =  


In [92]:
#Prefix-Accept (mismatch까지)

In [93]:
from typing import List, Tuple
import torch
@torch.inference_mode()
def accept_until_mismatch(context_ids, branch_tokens:List[int]) -> Tuple[torch.Tensor, List[int], bool]:
    ids = context_ids.clone()
    accepted = []
    mismatched = False
    for tid in branch_tokens:
        pred = verifier_next_token(ids)
        if pred == tid:
            ids = torch.cat([ids, torch.tensor([[tid]], device=ids.device)], dim=1)
            accepted.append(tid)
        else:
            ids = torch.cat([ids, torch.tensor([[pred]], device=ids.device)], dim=1)
            mismatched = True
            break
    return ids, accepted, mismatched

In [94]:
# 방금 만든 브랜치들 중 첫 번째를 검사해보기
new_ids, accepted, mism = accept_until_mismatch(input_ids, b[0])
print('accepted len:', len(accepted), '| mismatched?', mism)
print('새 길이:', new_ids.shape[1], '| 추가된 토큰 수:', new_ids.shape[1] - input_ids.shape[1])

accepted len: 1 | mismatched? True
새 길이: 14 | 추가된 토큰 수: 2


In [95]:
#Orchestrator (medusa_generate)

In [96]:
import math
import torch

In [97]:
#Branch 점수 함수

In [98]:
def score_branch(accepted, mismatched):
    # prefix-accept된 토큰 수 - mismatch 패널티
    return len(accepted) - (1 if mismatched else 0)

# ✔️ 체크
print(score_branch([1,2,3], False))  # 3
print(score_branch([1,2], True))     # 1 (2 - 1)

3
1


In [99]:
#Prompt → 토큰화

In [100]:
@torch.inference_mode()
def encode_prompt(prompt: str):
    ctx = drafter_tok(prompt, return_tensors="pt").to(cfg.DEVICE)
    return ctx["input_ids"]

In [101]:
# ✔️ 체크
ids = encode_prompt("In a distant future, ")
print("ids.shape:", ids.shape)

ids.shape: torch.Size([1, 6])


In [102]:
#한 스텝 수행(multi-branch→검증→최고 점수 채택)

In [103]:
@torch.inference_mode()
def medusa_step(ids, topk_branch:int, draft_span:int, temperature:float):
    branches = drafter_propose(ids, topk_branch, draft_span, temperature)
    best_score = -10**9
    best_ids = None
    for br in branches:
        new_ids, accepted, mism = accept_until_mismatch(ids, br)
        s = score_branch(accepted, mism)
        if s > best_score:
            best_score, best_ids = s, new_ids
    return best_ids

In [104]:
ids2 = medusa_step(ids, cfg.TOPK_BRANCH, cfg.DRAFT_SPAN, cfg.TEMPERATURE)
print("before:", ids.shape[1], "→ after:", ids2.shape[1])

before: 6 → after: 7


In [105]:
#Orchestrator

In [106]:
@torch.inference_mode()
def medusa_generate(prompt:str,
                    max_new_tokens:int=None,
                    topk_branch:int=None,
                    draft_span:int=None,
                    temperature:float=None) -> str:
    if max_new_tokens is None: max_new_tokens = cfg.MAX_NEW_TOKENS
    if topk_branch   is None: topk_branch   = cfg.TOPK_BRANCH
    if draft_span    is None: draft_span    = cfg.DRAFT_SPAN
    if temperature   is None: temperature   = cfg.TEMPERATURE

    ids = encode_prompt(prompt)
    start_len = ids.shape[1]

    steps = math.ceil(max_new_tokens / draft_span)
    for _ in range(steps):
        ids = medusa_step(ids, topk_branch, draft_span, temperature)
        if ids.shape[1] - start_len >= max_new_tokens:
            break

    return drafter_tok.decode(ids[0], skip_special_tokens=True)


In [107]:
out = medusa_generate("In a distant future, ", 40)
print(out)

In a distant future,  the world would be a better place.
The world would be a better place.
The world would be a better place.
The world would be a better place


In [108]:
# Step 13) Greedy Baseline

In [109]:
@torch.inference_mode()
def greedy_generate(prompt: str, max_new_tokens: int = None) -> str:
    if max_new_tokens is None:
        max_new_tokens = cfg.MAX_NEW_TOKENS

    # verifier 기준으로 생성 (비교군)
    ctx = verifier_tok(prompt, return_tensors="pt").to(cfg.DEVICE)
    ids = ctx["input_ids"]

    for _ in range(max_new_tokens):
        logits = verifier(ids).logits[:, -1, :]
        nxt = int(torch.argmax(logits, dim=-1)[0])  # greedy
        ids = torch.cat([ids, torch.tensor([[nxt]], device=ids.device)], dim=1)

    return verifier_tok.decode(ids[0], skip_special_tokens=True)

In [110]:
txt = greedy_generate("In a distant future, ", 40)
print(txt)

In a distant future,  the world would be a better place.
The world would be a better place.
The world would be a better place.
The world would be a better place.
The world would


In [None]:
#1) A/B 속도·텍스트 비교 셀