In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# 디바이스 자동 선택 (M1/M2 → mps, GPU → cuda, 없으면 cpu)
DEVICE = ("mps" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
          else ("cuda" if torch.cuda.is_available() else "cpu"))
print("✅ device:", DEVICE)

✅ device: mps


In [2]:
# drafter = 초안기 (빠르고 단순)
drafter_id = "distilgpt2"
drafter_tok = AutoTokenizer.from_pretrained(drafter_id)
drafter = AutoModelForCausalLM.from_pretrained(drafter_id).to(DEVICE).eval()

# verifier = 검증기 (품질↑)
verifier_id = "gpt2-medium"
verifier_tok = AutoTokenizer.from_pretrained(verifier_id)
verifier = AutoModelForCausalLM.from_pretrained(verifier_id).to(DEVICE).eval()

# eos/pad 보정 (gpt2 계열은 기본이 없음)
if verifier_tok.eos_token_id is None:
    verifier_tok.eos_token = ""
if verifier_tok.pad_token_id is None:
    verifier_tok.pad_token = verifier_tok.eos_token
EOS_ID = verifier_tok.eos_token_id

print("✅ models ready:", drafter_id, "/", verifier_id)


✅ models ready: distilgpt2 / gpt2-medium


In [3]:
@torch.inference_mode()
def draft_one_token(model, ids: torch.Tensor, temperature: float = 1.0) -> int:
    """
    drafter에서 다음 토큰 1개 샘플링.
    ids: shape [1, T]
    """
    logits = model(ids).logits[:, -1, :]        # 마지막 위치 로짓
    probs  = torch.softmax(logits / max(temperature, 1e-6), dim=-1)  # 확률 분포
    next_id = torch.multinomial(probs[0], num_samples=1)             # 샘플링
    return int(next_id.item())


In [4]:
ids = drafter_tok("In the distant future,", return_tensors="pt").to(DEVICE)["input_ids"]
sample_id = draft_one_token(drafter, ids, temperature=0.8)
print("샘플링된 토큰 ID:", sample_id)
print("토큰 문자열:", drafter_tok.decode([sample_id]))


샘플링된 토큰 ID: 262
토큰 문자열:  the


In [5]:
@torch.inference_mode()
def drafter_sample_first_tokens_basic(model, ids, k: int, temperature: float = 0.8):
    logits = model(ids).logits[:, -1, :]
    probs  = torch.softmax(logits / max(temperature, 1e-6), dim=-1)[0]
    k = min(k, probs.numel())
    picks = torch.multinomial(probs, num_samples=k, replacement=False)
    return [int(i) for i in picks]

@torch.inference_mode()
def drafter_rollout_basic(ids, first_tok: int, span: int):
    cur = torch.cat([ids, torch.tensor([[first_tok]], device=ids.device)], dim=1)
    seq = [first_tok]
    for _ in range(span - 1):
        logits = drafter(cur).logits[:, -1, :]
        nxt = int(torch.argmax(logits, dim=-1)[0])
        seq.append(nxt)
        cur = torch.cat([cur, torch.tensor([[nxt]], device=cur.device)], dim=1)
    return seq

@torch.inference_mode()
def drafter_propose_basic(ids, k: int, span: int, temperature: float = 0.8):
    firsts = drafter_sample_first_tokens_basic(drafter, ids, k, temperature)
    return [drafter_rollout_basic(ids, t, span) for t in firsts]


In [6]:
branches = drafter_propose_basic(ids, k=3, span=3, temperature=0.8)
print(len(branches), "branches; 예시:", branches[0])

3 branches; 예시: [356, 389, 287]


In [7]:
#verifier

In [8]:
def verifier_next_basic(ids: torch.Tensor, repetition_penalty: float = 1.3, no_repeat_ngram: int = 5) -> int:
    logits = verifier(ids).logits[:, -1, :].clone()
    V = logits.size(-1)

    # repetition penalty: 이미 나온 토큰 로짓 낮추기
    if repetition_penalty and repetition_penalty != 1.0:
        seen = torch.bincount(ids[0].to(torch.int64), minlength=V).bool()
        logits[:, seen] = logits[:, seen] / repetition_penalty

    # no-repeat n-gram: 마지막 n-1 패턴의 n번째 토큰 차단
    n = int(no_repeat_ngram or 0)
    if n > 1 and ids.size(1) >= n - 1:
        tail = ids[0].tolist()
        prefix = tail[-(n-1):]
        blocked = set()
        for i in range(len(tail) - n + 1):
            if tail[i:i+n-1] == prefix:
                blocked.add(tail[i+n-1])
        if blocked:
            logits[:, list(blocked)] = -1e9

    return int(torch.argmax(logits, dim=-1)[0])


In [9]:
@torch.inference_mode()
def medusa_lite_generate_basic(
    prompt: str,
    *,
    max_new_tokens: int = 30,
    k_branches: int = 4,
    span: int = 3,
    temperature: float = 0.8,
    repetition_penalty: float = 1.3,
    no_repeat_ngram: int = 5,
) -> str:
    # 시작 컨텍스트 (verifier 토크나이저 기준)
    ctx = verifier_tok(prompt, return_tensors="pt").to(DEVICE)
    ids = ctx["input_ids"]
    committed = 0

    while committed < max_new_tokens:
        # 1) drafter가 K개 분기 제안 (각각 길이 span)
        branches = drafter_propose_basic(ids, k=k_branches, span=span, temperature=temperature)
        if not branches:  # (희박) 분기 못 만들면 verifier 1토큰만 커밋
            vtok = verifier_next_basic(ids, repetition_penalty, no_repeat_ngram)
            ids = torch.cat([ids, torch.tensor([[vtok]], device=ids.device)], dim=1)
            committed += 1
            break

        # 2) 각 분기에 대해 prefix-accept 길이 측정
        best_len, best_branch, best_mismatch = -1, None, None
        for cand in branches:
            cur = ids
            pref = 0
            mismatch = None
            for t in cand:
                v_next = verifier_next_basic(cur, repetition_penalty, no_repeat_ngram)
                if v_next == EOS_ID:
                    mismatch = EOS_ID; break
                if v_next == t:
                    pref += 1
                    cur = torch.cat([cur, torch.tensor([[t]], device=cur.device)], dim=1)
                else:
                    mismatch = v_next; break
            if pref > best_len:
                best_len, best_branch, best_mismatch = pref, cand, mismatch
            if pref == len(cand):  # 완전 일치면 즉시 채택
                break

        # 3) 커밋
        if best_len <= 0:
            vtok = verifier_next_basic(ids, repetition_penalty, no_repeat_ngram)
            ids = torch.cat([ids, torch.tensor([[vtok]], device=ids.device)], dim=1)
            committed += 1
            if vtok == EOS_ID:
                break
        else:
            commit_seq = (
                best_branch if best_len == len(best_branch)
                else best_branch[:best_len] + ([best_mismatch] if best_mismatch is not None else [])
            )
            ids = torch.cat([ids, torch.tensor([commit_seq], device=ids.device)], dim=1)
            committed += len(commit_seq)
            if commit_seq and commit_seq[-1] == EOS_ID:
                break

        # 4) 문장부호면 조기 종료(늘어짐 방지)
        tail = verifier_tok.decode(ids[0][-40:], skip_special_tokens=True).strip()
        if tail.endswith((".", "!", "?")):
            break

    return verifier_tok.decode(ids[0], skip_special_tokens=True)


In [10]:
print("=== Medusa-lite ===")
print(medusa_lite_generate_basic(
    "In the distant future, ",
    max_new_tokens=30,
    k_branches=4,
    span=3,
    temperature=0.8,         # drafter 다양성
    repetition_penalty=1.3,  # 반복 억제
    no_repeat_ngram=5
))


=== Medusa-lite ===
In the distant future,      is the   ?
