In [28]:
import random, torch
# Choose the compute device once for the whole notebook.
# Preference order: Apple "mps" backend, then "cuda", then plain "cpu".
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("DEVICE =", DEVICE)

def set_seed(seed=7):
    """Seed the random number generators used in this notebook.

    Seeds Python's ``random`` module and PyTorch; when CUDA is available the
    CUDA generators are seeded explicitly as well.

    Args:
        seed: Seed value passed to every generator (default 7).
    """
    random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


# Seed once at start-up so the sampling cells start from a known state.
set_seed(7)

DEVICE = mps


In [29]:
# STEP 1 — Tokenizer and model

In [30]:
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_ID = "distilgpt2"

# --- Tokenizer -------------------------------------------------------------
tok = AutoTokenizer.from_pretrained(MODEL_ID)
if tok.eos_token_id is None:
    # No EOS token defined: fall back to an empty string so one is set at all.
    # NOTE(review): an empty-string EOS may not map to a usable id — confirm
    # before relying on this branch with another checkpoint.
    tok.eos_token = ""
if tok.pad_token_id is None:
    # No PAD token defined: reuse the EOS token for padding.
    tok.pad_token = tok.eos_token
EOS_ID = tok.eos_token_id  # id the generation loops compare against to stop

# --- Model -----------------------------------------------------------------
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
model = model.to(DEVICE)   # move the weights to the selected device
model = model.eval()       # switch to evaluation mode
model.config.use_cache = True

In [31]:
# STEP2 — Utilities: encode/decode/greedy_next

def encode(text: str):
    """Tokenize ``text`` and return its ``input_ids`` tensor on ``DEVICE``.

    Args:
        text: Prompt string to tokenize with the notebook's tokenizer ``tok``.

    Returns:
        The ``"input_ids"`` entry of the PyTorch-tensor encoding, moved to
        ``DEVICE``.
    """
    batch = tok(text, return_tensors="pt")
    on_device = batch.to(DEVICE)
    return on_device["input_ids"]

def decode(ids):
    """Turn the first row of ``ids`` back into text.

    Args:
        ids: Token-id container whose element ``ids[0]`` is the sequence to
            decode.

    Returns:
        The decoded string with special tokens skipped.
    """
    first_sequence = ids[0]
    return tok.decode(first_sequence, skip_special_tokens=True)

@torch.inference_mode()
def greedy_next(ids):
    """Pick the single highest-scoring next token.

    Runs ``model`` on ``ids`` and takes the arg-max over the logits at the
    last position of the first sequence.

    Args:
        ids: Token-id tensor passed straight to ``model``.

    Returns:
        The chosen token id as a Python ``int``.
    """
    last_logits = model(ids).logits[0, -1]   # logits at the final position
    best = last_logits.argmax()
    return int(best.item())

# Confirmation message printed once the STEP2 helpers above have been defined.
print("[STEP2] utils ready")

[STEP2] utils ready


In [33]:
# Greedy decoding demo: extend the prompt one arg-max token at a time.
prompt = "Artificial intelligence is changing the way humans ,"
ids = encode(prompt)
start = ids.shape[1]  # size of the second dimension of the encoded prompt

for _ in range(30):  # generate at most 30 new tokens
    nxt = greedy_next(ids)
    if EOS_ID is None or nxt != EOS_ID:
        ids = torch.cat((ids, torch.tensor([[nxt]], device=ids.device)), dim=1)
    else:
        break  # EOS produced: stop without appending it

print("=== Greedy ===")
print(decode(ids))

=== Greedy ===
Artificial intelligence is changing the way humans , and it’s changing the way we interact with other people.


















In [34]:
# STEP 4 — Sampling (temperature only)

In [35]:
@torch.inference_mode()
def softmax_temp(logits, temperature=1.0):
    """Convert logits to probabilities after temperature scaling.

    Temperatures below 1 sharpen the distribution (more deterministic);
    temperatures above 1 flatten it (more random).

    Args:
        logits: Tensor of unnormalised scores; softmax is taken over the last
            dimension.
        temperature: Scaling factor. Values below ``1e-6`` are raised to
            ``1e-6`` so the division never uses zero or a negative number.

    Returns:
        Tensor of the same shape as ``logits`` holding the probabilities.
    """
    safe_t = float(temperature)
    if safe_t < 1e-6:
        safe_t = 1e-6
    scaled = logits / safe_t
    return scaled.softmax(dim=-1)

@torch.inference_mode()
def sample_next_temp_only(ids, temperature=0.9):
    """Sample the next token from the temperature-scaled distribution.

    Args:
        ids: Token-id tensor passed straight to ``model``.
        temperature: Temperature forwarded to ``softmax_temp``.

    Returns:
        The sampled token id as a Python ``int``.
    """
    last_logits = model(ids).logits[0, -1]
    dist = softmax_temp(last_logits, temperature)
    drawn = torch.multinomial(dist, num_samples=1)  # draw one id according to dist
    return int(drawn[0].item())

In [36]:
#test

In [37]:
# Temperature-only sampling demo, starting again from the same prompt.
ids2 = encode(prompt)
for _ in range(30):  # generate at most 30 new tokens
    nxt = sample_next_temp_only(ids2, temperature=0.9)
    if EOS_ID is None or nxt != EOS_ID:
        ids2 = torch.cat((ids2, torch.tensor([[nxt]], device=ids2.device)), dim=1)
    else:
        break  # EOS produced: stop without appending it

print("=== Sample (temp only) ===")
print(decode(ids2))

=== Sample (temp only) ===
Artificial intelligence is changing the way humans , and understanding human biology is changing the way humans are prepared for its challenges, a presentation published this week in the journal Current Biology Proceedings of the National Academy


In [38]:
# STEP 5 — Nucleus (top-p) sampling

In [39]:
@torch.inference_mode()
def top_p_indices(probs, top_p=0.95):
    """Return the indices that form the nucleus (top-p) set of ``probs``.

    The nucleus is the smallest set of highest-probability entries whose
    total mass reaches ``top_p``. The entry that crosses the threshold is
    part of the set, so the kept mass is at least ``top_p`` (the previous
    ``csum <= top_p`` test dropped that entry and kept too little mass).

    Args:
        probs: 1-D tensor of probabilities (sorting and the cumulative sum run
            along dimension 0).
        top_p: Target cumulative mass. ``None`` or any value ``>= 1`` disables
            truncation; values ``<= 0`` keep only the most likely entry.

    Returns:
        1-D index tensor into ``probs``. With truncation the indices are
        ordered from most to least probable; without truncation they are
        ``0 .. probs.numel() - 1``. Empty input yields an empty index tensor.
    """
    if top_p is None or top_p >= 1:  # no truncation requested
        return torch.arange(probs.numel(), device=probs.device)
    if probs.numel() == 0:  # nothing to rank; avoids indexing keep[0] below
        return torch.arange(0, device=probs.device)
    sorted_p, sorted_ix = torch.sort(probs, descending=True)
    csum = torch.cumsum(sorted_p, dim=0)
    # Mass accumulated *before* each entry (exclusive cumulative sum). An entry
    # belongs to the nucleus while the mass ahead of it is still short of top_p.
    mass_before = torch.cat([csum.new_zeros(1), csum[:-1]])
    keep = mass_before < top_p
    keep[0] = True  # always keep the most likely entry, even for top_p <= 0
    return sorted_ix[keep]

@torch.inference_mode()
def sample_next(ids, temperature=0.9, top_p=0.95):
    """Sample the next token using temperature scaling plus top-p filtering.

    Args:
        ids: Token-id tensor passed straight to ``model``.
        temperature: Temperature forwarded to ``softmax_temp``.
        top_p: Threshold forwarded to ``top_p_indices``.

    Returns:
        The sampled token id as a Python ``int``.
    """
    last_logits = model(ids).logits[0, -1]
    dist = softmax_temp(last_logits, temperature)
    candidates = top_p_indices(dist, top_p)
    weights = dist[candidates]
    weights = weights / weights.sum()  # renormalise so the candidates sum to 1
    slot = torch.multinomial(weights, num_samples=1)[0].item()
    # `slot` indexes the candidate pool; map it back to a vocabulary id.
    return int(candidates[slot].item())


In [40]:
# STEP 6 — Drafter (propose a short branch of sampled tokens)

In [41]:
 @torch.inference_mode()
def propose_branch(ids, span=3, temperature=0.9, top_p=0.95):
    """Draft ``span`` tokens by sampling one after another.

    Works on a clone of ``ids``, so the caller's tensor is not modified.

    Args:
        ids: Token-id tensor holding the current context.
        span: Number of tokens to draft.
        temperature: Temperature forwarded to ``sample_next``.
        top_p: Top-p threshold forwarded to ``sample_next``.

    Returns:
        list[int]: The drafted token ids, in generation order.
    """
    context = ids.clone()
    drafted = []
    for _step in range(span):
        token_id = sample_next(context, temperature, top_p)
        drafted.append(token_id)
        # Feed the sampled token back in so the next draw is conditioned on it.
        context = torch.cat((context, torch.tensor([[token_id]], device=context.device)), dim=1)
    return drafted

In [42]:
# Drafter demo: propose a 5-token branch for a fresh prompt and inspect it.
prompt = "Artificial intelligence is changing the way humans "
ids = encode(prompt)

branch = propose_branch(ids, span=5, temperature=0.9, top_p=0.95)

# Raw list of token ids.
print("Proposed token IDs:", branch)
# Each id decoded to text on its own.
print("Decoded tokens:", list(map(lambda token_id: tok.decode([token_id]), branch)))
# All ids decoded together in a single call.
print("Joined as text:", tok.decode(branch))

Proposed token IDs: [1849, 10996, 351, 257, 3644]
Decoded tokens: ['\xa0', ' communicate', ' with', ' a', ' computer']
Joined as text:   communicate with a computer


In [44]:
#STEP 7 — prefix-accept

In [46]:
@torch.inference_mode()
def prefix_accept_once(ids, branch):
    """Verify a drafted branch against greedy decoding, keeping the matching prefix.

    Walks through ``branch`` in order. While the drafted token equals
    ``greedy_next`` for the sequence so far, it is appended and counted. At
    the first mismatch the greedy token is appended instead and the walk stops.

    Args:
        ids: Token-id tensor holding the current context; it is cloned, not
            modified.
        branch: Iterable of drafted token ids.

    Returns:
        tuple: ``(extended_ids, accepted)`` — the extended token tensor and
        the number of drafted tokens that were accepted.
    """
    seq = ids.clone()
    n_accepted = 0
    for proposed in branch:
        # Greedy choice for the sequence built so far.
        best = greedy_next(seq)
        agrees = best == proposed
        # Accept the draft when it agrees; otherwise substitute the greedy token.
        chosen = proposed if agrees else best
        seq = torch.cat((seq, torch.tensor([[chosen]], device=seq.device)), dim=1)
        if not agrees:
            break  # first mismatch ends verification of this branch
        n_accepted += 1
    return seq, n_accepted

In [47]:
#STEP 8 — Medusa-tiny

In [50]:
@torch.inference_mode()
def medusa_tiny(prompt, max_new_tokens=30, span=3, temperature=0.9, top_p=0.95):
    """Generate text by repeatedly drafting a short branch and verifying it.

    Each round drafts tokens with ``propose_branch`` and hands them to
    ``prefix_accept_once``, which appends exactly one token per drafted token
    it inspects (the draft if it equals the greedy choice, otherwise the
    greedy token, after which it stops).

    Compared with the earlier version this one:
      * never returns more than ``max_new_tokens`` new tokens (the draft
        length is capped to the remaining budget, so a round cannot overshoot);
      * stops at ``EOS_ID``, like the greedy and sampling demo loops;
      * exits at once when a round makes no progress (e.g. ``span <= 0``)
        instead of spinning through no-op iterations.

    Args:
        prompt: Text to continue.
        max_new_tokens: Upper bound on the number of generated tokens.
        span: Maximum number of tokens drafted per round.
        temperature: Sampling temperature used by the drafter.
        top_p: Top-p threshold used by the drafter.

    Returns:
        The decoded text (prompt plus continuation) as produced by ``decode``.
    """
    ids = encode(prompt)
    start = ids.shape[1]
    while True:
        remaining = max_new_tokens - (ids.shape[1] - start)
        if remaining <= 0:
            break
        before = ids.shape[1]
        # A round adds at most len(branch) tokens, so capping the draft length
        # at the remaining budget guarantees max_new_tokens is honoured.
        branch = propose_branch(ids, span=min(span, remaining),
                                temperature=temperature, top_p=top_p)
        ids, _ = prefix_accept_once(ids, branch)
        added = ids[0, before:].tolist()
        if not added:
            break  # empty draft: nothing can ever be appended, so stop
        if EOS_ID is not None and EOS_ID in added:
            # Drop EOS and anything after it, then finish.
            ids = ids[:, : before + added.index(EOS_ID)]
            break
    return decode(ids)

In [49]:
# Medusa-tiny demo: draft-and-verify generation from a short prompt.
print("=== Medusa-tiny ===")
print(
    medusa_tiny(
        prompt="In a distant future, ",
        max_new_tokens=30,
        span=3,
        temperature=0.9,
        top_p=0.95,
    )
)

=== Medusa-tiny ===
In a distant future,   the world is a place where the world is a place where the world is a place where the world is a place where the world is a place
