In [41]:
from __future__ import annotations


import random
from dataclasses import dataclass
from typing import List, Tuple


import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


# =============================
# RUN SWITCHES (원하는 단계만 on)
# =============================
흐름을 다시 정리하면 이렇게 돼요:

STEP0 — 디바이스 선택
(DEVICE 정하고, seed 설정)

STEP1 — 토크나이저 / 모델 준비
(tok, model, EOS_ID 준비)

STEP2 — 유틸 함수
(encode, decode, append_token, last_logits, greedy_next)

➡️ STEP3 — 샘플링 함수
(softmax_temp, top_p_indices, sample_next) ← 지금 말씀하신 부분

STEP4 — 드래프터 (propose_branch)

STEP5 — 프리픽스-어셉트 (prefix_accept_once)

STEP6 — 루프 실행 (medusa_tiny / run_greedy)

STEP7 — 데모 (Greedy vs Medusa 결과 출력)

In [42]:
# =============================
# 0) Config
# =============================
@dataclass
class Cfg:
    MODEL_ID: str = "distilgpt2"
    TEMPERATURE: float = 0.9
    TOP_P: float = 0.95
    SPAN: int = 3
    MAX_NEW_TOKENS: int = 30
    BAN_EOS_FIRST_N: int = 0
    REP_BAN_N: int = 0
    SEED: int | None = 7
    DEBUG: bool = False
cfg = Cfg()

In [43]:
if RUN_STEP_0:
    DEVICE = (
        "mps" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        else ("cuda" if torch.cuda.is_available() else "cpu")
    )

    def set_seed(seed: int | None) -> None:
        if seed is None:
            return
        random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    set_seed(cfg.SEED)
    print(f"[STEP0] DEVICE = {DEVICE}")


[STEP0] DEVICE = mps


In [44]:
# =============================
# STEP 1 — 토크나이저/모델 준비 (eos/pad 보정)
# =============================
if RUN_STEP_1:
    def load_tokenizer(model_id: str):
        tok = AutoTokenizer.from_pretrained(model_id)
        if tok.eos_token_id is None:
            tok.eos_token = ""
        if tok.pad_token_id is None:
            tok.pad_token = tok.eos_token
        return tok

    def load_model(model_id: str, device: str):
        model = AutoModelForCausalLM.from_pretrained(model_id).to(device).eval()
        model.config.use_cache = True
        return model

    tok = load_tokenizer(cfg.MODEL_ID)
    model = load_model(cfg.MODEL_ID, DEVICE)
    EOS_ID = tok.eos_token_id

    print(f"[STEP1] MODEL = {cfg.MODEL_ID}, EOS_ID={EOS_ID}")


[STEP1] MODEL = distilgpt2, EOS_ID=50256


In [45]:
# =============================
# STEP 2 — 유틸 함수 (encode/decode/append/last_logits/greedy_next)
# =============================
if RUN_STEP_2:
    def encode(text: str) -> torch.Tensor:
        """문자열 -> 토큰 IDs [1, T] (DEVICE로 이동)"""
        return tok(text, return_tensors="pt").to(DEVICE)["input_ids"]

    def decode(ids: torch.Tensor) -> str:
        """토큰 IDs -> 문자열 (스페셜 토큰 스킵)"""
        return tok.decode(ids[0], skip_special_tokens=True)

    def append_token(ids: torch.Tensor, token_id: int) -> torch.Tensor:
        """마지막에 단일 토큰 붙이기"""
        t = torch.tensor([[token_id]], device=ids.device)
        return torch.cat([ids, t], dim=1)

    @torch.inference_mode()
    def last_logits(ids: torch.Tensor) -> torch.Tensor:
        """마지막 위치의 로짓 벡터 [V]"""
        return model(ids).logits[0, -1, :]

    @torch.inference_mode()
    def greedy_next(ids: torch.Tensor) -> int:
        """그리디 argmax 토큰 ID"""
        return int(torch.argmax(last_logits(ids)).item())

    print("[STEP2] utils ready: encode/decode/append/last_logits/greedy_next")


[STEP2] utils ready: encode/decode/append/last_logits/greedy_next


In [46]:
# =============================
# STEP 3 — Sampling Function (softmax_temp/top_p_indices/sample_next)
# =============================

In [47]:
# =============================
# STEP 3a — softmax_temp
# =============================
if RUN_STEP_3:
    @torch.inference_mode()
    def softmax_temp(logits: torch.Tensor, temperature: float) -> torch.Tensor:
        t = max(float(temperature), 1e-6)
        return torch.softmax(logits / t, dim=-1)

    print("[STEP3a] ready: softmax_temp")


[STEP3a] ready: softmax_temp


In [48]:
# =============================
# STEP 3b — top_p_indices
# =============================
if RUN_STEP_3:
    @torch.inference_mode()
    def top_p_indices(probs: torch.Tensor, top_p: float) -> torch.Tensor:
        V = probs.numel()
        if top_p is None or top_p >= 1:
            return torch.arange(V, device=probs.device)
        sp, sx = torch.sort(probs, descending=True)
        csum = torch.cumsum(sp, dim=0)
        keep = csum <= top_p
        keep[0] = True
        return sx[keep]

    print("[STEP3b] ready: top_p_indices")

[STEP3b] ready: top_p_indices


In [49]:
# =============================
# STEP 3c — sample_next
# =============================
if RUN_STEP_3:
    @torch.inference_mode()
    def sample_next(ids: torch.Tensor, temperature: float, top_p: float,
                    ban_eos: bool = False, rep_ban_n: int = 0) -> int:
        logits = last_logits(ids)
        probs = softmax_temp(logits, temperature)

        # 반복 억제 (최근 N)
        if rep_ban_n > 0:
            tail = ids[0, -rep_ban_n:].tolist()
            probs[tail] = 0
            s = probs.sum()
            probs = probs if s <= 0 else probs / s
            if s <= 0:
                probs = softmax_temp(logits, temperature)  # fallback

        pool_ix = top_p_indices(probs, top_p)
        pool = probs[pool_ix]
        pool = pool / pool.sum()
        pick_local = int(torch.multinomial(pool, 1)[0].item())
        picked = int(pool_ix[pick_local].item())

        if ban_eos and EOS_ID is not None and picked == EOS_ID:
            # EOS 제외 재샘플 (1회)
            mask = pool_ix != EOS_ID
            if mask.any():
                pool_ix2 = pool_ix[mask]
                pool2 = pool[mask]
                pool2 = pool2 / pool2.sum()
                pick_local = int(torch.multinomial(pool2, 1)[0].item())
                picked = int(pool_ix2[pick_local].item())
        return picked

    print("[STEP3c] ready: sample_next")


[STEP3c] ready: sample_next


In [50]:
# =============================
# STEP 4 — 드래프터 (초초 쪼갬)
# =============================
if RUN_STEP_4:
    @torch.inference_mode()
    def propose_one(cur_ids: torch.Tensor, temperature: float, top_p: float,
                    accepted_so_far: int) -> int:
        """드래프팅 1스텝: ban_eos 여부 결정 → sample_next 한 번 호출"""
        ban_eos = accepted_so_far < cfg.BAN_EOS_FIRST_N
        return sample_next(cur_ids, temperature, top_p,
                           ban_eos=ban_eos, rep_ban_n=cfg.REP_BAN_N)

    @torch.inference_mode()
    def propose_branch(ids: torch.Tensor, span: int, temperature: float, top_p: float,
                       accepted_so_far: int = 0) -> List[int]:
        """span 회 만큼 propose_one 반복"""
        cur = ids.clone()
        out: List[int] = []
        for i in range(span):
            t = propose_one(cur, temperature, top_p, accepted_so_far)
            out.append(t)
            cur = append_token(cur, t)
            if cfg.DEBUG:
                print(f"  [draft {i}] pick={t} ({tok.decode([t])!r})")
        return out

    print("[STEP4] drafter ready: propose_one/propose_branch")


[STEP4] drafter ready: propose_one/propose_branch


In [51]:
# =============================
# STEP 5 — 프리픽스-어셉트 (초초 쪼갬)
# =============================
if RUN_STEP_5:
    @torch.inference_mode()
    def accept_one(cur: torch.Tensor, token: int) -> Tuple[torch.Tensor, bool]:
        """그리디 vs 브랜치 토큰 1개 비교 후 (일치=accept, 불일치=stop)"""
        g = greedy_next(cur)
        if cfg.DEBUG:
            print(f"    compare {token}({tok.decode([token])!r}) vs greedy={g}({tok.decode([g])!r})")
        if g == token:
            return append_token(cur, token), True   # 계속
        else:
            return append_token(cur, g), False      # 불일치 → 종료

    @torch.inference_mode()
    def prefix_accept_once(ids: torch.Tensor, branch: List[int]) -> Tuple[torch.Tensor, int]:
        cur = ids.clone()
        accepted = 0
        for t in branch:
            cur, ok = accept_one(cur, t)
            if ok:
                accepted += 1
            else:
                break
        return cur, accepted

    print("[STEP5] prefix-accept ready: accept_one/prefix_accept_once")


[STEP5] prefix-accept ready: accept_one/prefix_accept_once


In [52]:
# =============================
# STEP 6 — 루프 실행 (초초 쪼갬)
# =============================
if RUN_STEP_6:
    @torch.inference_mode()
    def medusa_loop_step(ids: torch.Tensor, span: int, temperature: float,
                         top_p: float, accepted_total: int, loop_idx: int):
        """한 번의 루프: 브랜치 제안 -> prefix-accept 적용"""
        if cfg.DEBUG:
            print(f"[loop {loop_idx}] cur_len={ids.shape[1]}")
        branch = propose_branch(ids, span, temperature, top_p, accepted_total)
        ids, acc = prefix_accept_once(ids, branch)
        return ids, acc

    @torch.inference_mode()
    def medusa_tiny(prompt: str, max_new_tokens: int, span: int,
                    temperature: float, top_p: float) -> str:
        ids = encode(prompt)
        start = ids.shape[1]
        accepted_total = 0
        loop_idx = 0
        while ids.shape[1] - start < max_new_tokens:
            ids, acc = medusa_loop_step(ids, span, temperature, top_p,
                                        accepted_total, loop_idx)
            accepted_total += acc
            loop_idx += 1
        return decode(ids)

    @torch.inference_mode()
    def run_greedy(prompt: str, max_new_tokens: int) -> str:
        ids = encode(prompt)
        start = ids.shape[1]
        while ids.shape[1] - start < max_new_tokens:
            nxt = greedy_next(ids)
            if EOS_ID is not None and nxt == EOS_ID:
                break
            ids = append_token(ids, nxt)
        return decode(ids)

    print("[STEP6] loops ready: medusa_loop_step/medusa_tiny/run_greedy")


[STEP6] loops ready: medusa_loop_step/medusa_tiny/run_greedy


In [53]:
# =============================
# STEP 7 — 데모 (Greedy vs Medusa 결과)
# =============================
if RUN_STEP_7:
    prompt = "In a distant future, "
    print("\n=== Greedy ===")
    print(run_greedy(prompt, cfg.MAX_NEW_TOKENS), "\n")

    print("=== Medusa-tiny ===")
    print(medusa_tiny(prompt, cfg.MAX_NEW_TOKENS,
                      cfg.SPAN, cfg.TEMPERATURE, cfg.TOP_P))

    # 디버그 보고 싶으면 아래 두 줄로 토글 후 재실행
    # cfg.DEBUG = True
    # print(medusa_tiny(prompt, cfg.MAX_NEW_TOKENS,
    #                   cfg.SPAN, cfg.TEMPERATURE, cfg.TOP_P))



=== Greedy ===
In a distant future,   the world is a place where the world is a place where the world is a place where the world is a place where the world is a place 

=== Medusa-tiny ===
In a distant future,   the world is a place where the world is a place where the world is a place where the world is a place where the world is a place where the
