In [1]:
#프롬프트 토큰화 (base, ctx)
#drafter 제안 블록 (draft_ids)
#verifier 그리디 시퀀스 (greedy_ids)
#수락된 prefix 길이 (accepted_prefix_len)
#누적 acceptance_rate_%

In [4]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
#device = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
#print("using device:", device)

IndentationError: unexpected indent (186763338.py, line 2)

In [7]:
device = "cpu"
print("device:", device)

device: cpu


In [8]:
#Tokenizer
tok = AutoTokenizer.from_pretrained("distilgpt2")
print("tokenizer loaded")

tokenizer loaded


In [9]:
#Drafter model load
drafter = AutoModelForCausalLM.from_pretrained("distilgpt2").to(device)
print("drafter model loaded")

drafter model loaded


In [10]:
#verifier
verifier = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
print("verifier model loaded")

verifier model loaded


In [11]:
#prompt
prompt = "In a distant future,"
print("prompt:", prompt)

prompt: In a distant future,


In [12]:
#tokenizer PAD.EOS setting
# distilgpt2 / gpt2 계열은 기본적으로 pad_token이 없음 → eos_token으로 맞춰줌
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

print("pad_token:", tok.pad_token)
print("eos_token:", tok.eos_token)


pad_token: <|endoftext|>
eos_token: <|endoftext|>


In [13]:
#Predict–Verify–Accept (P–V–A)

In [14]:
#2-1. Verifier

In [15]:
import torch

@torch.no_grad()
def greedy_next_token(verifier, input_ids: torch.Tensor) -> torch.Tensor:
    out = verifier(input_ids=input_ids)
    next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)  # shape [1,1]
    return next_id

In [16]:
# prompt를 토크나이즈해서 verifier에 넣어보기
inputs = tok(prompt, return_tensors="pt").to(device)

next_id = greedy_next_token(verifier, inputs["input_ids"])
print("예측된 토큰 ID:", next_id.item())
print("예측된 토큰 문자열:", tok.decode([next_id.item()]))

예측된 토큰 ID: 262
예측된 토큰 문자열:  the


In [17]:
#2-2. draft_block func

In [18]:
import torch

@torch.no_grad()
def draft_block(
    drafter,
    input_ids: torch.Tensor,
    block_size: int = 3,
    do_sample: bool = True,
    top_k: int = 50,
    temperature: float = 0.8,
    pad_token_id: int = None,
    # ↓ 추가 옵션들
    top_p: float = 0.95,
    repetition_penalty: float = 1.1,
    no_repeat_ngram_size: int = 3,
    length_penalty: float = 1.0,
) -> torch.Tensor:
    """
    drafter가 block_size만큼 한 번에 제안.
    반환: 새로 제안된 토큰들만 (프롬프트 부분 제외)
    """
    gen = drafter.generate(
        input_ids,
        attention_mask=(input_ids != tok.pad_token_id),  # pad/eos 구분 명시
        max_new_tokens=block_size,
        do_sample=do_sample,
        top_k=top_k if do_sample else None,
        top_p=top_p if do_sample else None,
        temperature=temperature if do_sample else None,
        pad_token_id=pad_token_id,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        length_penalty=length_penalty,
    )
    # 프롬프트 길이 이후 부분만 반환
    return gen[:, input_ids.shape[1]:]  # shape: [1, block_size]


In [19]:
ctx = tok(prompt, return_tensors="pt", padding=True).input_ids.to(device)
block = draft_block(
    drafter, ctx,
    block_size=3,
    do_sample=True, top_k=50, temperature=0.8,
    pad_token_id=tok.eos_token_id,
    top_p=0.95, repetition_penalty=1.1, no_repeat_ngram_size=3
)

print("draft token IDs:", block.tolist()[0])
print("draft tokens (decoded):", tok.decode(block[0], skip_special_tokens=True))

draft token IDs: [262, 1692, 3234]
draft tokens (decoded):  the human race


In [20]:
#Prefix-verify-accpet-once

In [23]:
import torch

def pva_step(ctx, block_size=3):
    """
    현재 컨텍스트(ctx)에서 drafter가 block 제안 → verifier로 앞에서부터 비교 →
    일치하는 prefix 수락, 첫 불일치에서 greedy 대체 후 종료.
    returns: new_ctx, accepted_len (이번 스텝에서 수락된 토큰 수)
    """
    model_device = next(drafter.parameters()).device
    ctx = ctx.to(model_device)

    block = draft_block(
        drafter, ctx,
        block_size=block_size,
        do_sample=True, top_k=50, temperature=0.8,
        pad_token_id=tok.eos_token_id,
        top_p=0.95, repetition_penalty=1.1, no_repeat_ngram_size=3
    )

    greedy_seq = []
    gctx = ctx.clone().to(next(verifier.parameters()).device)
    with torch.no_grad():
        for _ in range(block.shape[1]):
            attn = (gctx != tok.pad_token_id).to(gctx.device)
            out = verifier(input_ids=gctx, attention_mask=attn)
            g = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
            greedy_seq.append(int(g.item()))
            gctx = torch.cat([gctx, g], dim=1)

    accepted = 0
    cur = ctx.clone().to(gctx.device)
    for i in range(block.shape[1]):
        d_id = int(block[0, i].item())
        g_id = greedy_seq[i]
        if d_id == g_id:
            cur = torch.cat([cur, block[:, i:i+1]], dim=1)
            accepted += 1
        else:
            g = torch.tensor([[g_id]], device=cur.device)
            cur = torch.cat([cur, g], dim=1)
            break

    return cur, accepted


In [24]:
#context

In [25]:
ctx = tok(prompt, return_tensors="pt", padding=True).input_ids.to(device)
new_ctx, accepted = pva_step(ctx, block_size=3)
print("accepted in this step:", accepted)
print(tok.decode(new_ctx[0], skip_special_tokens=True)[:200])

accepted in this step: 0
In a distant future, the


In [26]:
#2-3-2. drafter가 블록 제안만 먼저 보기

In [27]:
block_size = 3
block = draft_block(
    drafter, ctx,
    block_size=block_size,
    do_sample=True, top_k=50, temperature=0.8,
    pad_token_id=tok.eos_token_id
)
print("draft ids:", block.tolist()[0])
print("draft text:", tok.decode(block[0], skip_special_tokens=True))


draft ids: [465, 2802, 550]
draft text:  his mother had


In [28]:
#2-3-3. 같은 길이만큼 verifier의 greedy 시퀀스만 계산해서 보기

In [29]:
greedy_seq = []
gctx = ctx.clone()
with torch.no_grad():
    for _ in range(block.shape[1]):
        out = verifier(input_ids=gctx)
        g = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        greedy_seq.append(int(g.item()))
        gctx = torch.cat([gctx, g], dim=1)

print("greedy ids:", greedy_seq)
print("greedy text (참고):", tok.decode(greedy_seq, skip_special_tokens=True))

greedy ids: [262, 995, 561]
greedy text (참고):  the world would


In [30]:
#2-3-4. 앞에서부터 한 토큰씩 비교해서 accept할지 결정 (루프를 눈으로 확인)

In [31]:
accepted_ids = []
cur = ctx.clone()

for i in range(block.shape[1]):
    d_id = int(block[0, i].item())
    g_id = greedy_seq[i]
    print(f"[i={i}] draft={d_id} vs greedy={g_id} ({tok.decode([d_id])} vs {tok.decode([g_id])})")

    if d_id == g_id:
        # 일치 → draft 토큰 수락
        accepted_ids.append(d_id)
        cur = torch.cat([cur, block[:, i:i+1]], dim=1)
        print("  → accept draft")
    else:
        # 불일치 → greedy 토큰으로 대체하고 중단
        g = torch.tensor([[g_id]], device=cur.device)
        cur = torch.cat([cur, g], dim=1)
        print("  → mismatch, take greedy and stop")
        break

print("accepted prefix:", len(accepted_ids), "/", block.shape[1])


[i=0] draft=465 vs greedy=262 ( his vs  the)
  → mismatch, take greedy and stop
accepted prefix: 0 / 3


In [32]:
#2-3-5. 한 스텝 끝난 뒤의 텍스트 확인

In [33]:
print("new text after one step:\n", tok.decode(cur[0], skip_special_tokens=True))

new text after one step:
 In a distant future, the


In [34]:
#2-4-1. 한 스텝만 수행하는 작은 함수(pva_step) 만들기

In [35]:
import torch

def pva_step(ctx, block_size=3):
    """
    현재 컨텍스트(ctx)에서 drafter가 block 제안 → verifier로 앞에서부터 비교 →
    일치하는 prefix 수락, 첫 불일치에서 greedy 대체 후 종료.
    returns: new_ctx, accepted_len (이번 스텝에서 수락된 토큰 수)
    """
    # 0) 디바이스 정렬
    model_device = next(drafter.parameters()).device
    ctx = ctx.to(model_device)

    # 1) draft (반복 억제 옵션 포함)
    block = draft_block(
        drafter, ctx,
        block_size=block_size,
        do_sample=True, top_k=50, temperature=0.8,
        pad_token_id=tok.eos_token_id,
        top_p=0.95, repetition_penalty=1.1, no_repeat_ngram_size=3
    )

    # 2) 같은 길이만큼 verifier의 greedy 시퀀스 (attention_mask 명시)
    greedy_seq = []
    gctx = ctx.clone().to(next(verifier.parameters()).device)
    with torch.no_grad():
        for _ in range(block.shape[1]):
            attn = (gctx != tok.pad_token_id).to(gctx.device)
            out = verifier(input_ids=gctx, attention_mask=attn)
            g = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
            greedy_seq.append(int(g.item()))
            gctx = torch.cat([gctx, g], dim=1)

    # 3) 앞에서부터 비교(accept or stop)
    accepted = 0
    cur = ctx.clone().to(gctx.device)
    for i in range(block.shape[1]):
        d_id = int(block[0, i].item())
        g_id = greedy_seq[i]
        if d_id == g_id:
            cur = torch.cat([cur, block[:, i:i+1]], dim=1)
            accepted += 1
        else:
            g = torch.tensor([[g_id]], device=cur.device)
            cur = torch.cat([cur, g], dim=1)
            break

    return cur, accepted

In [36]:
ctx = tok(prompt, return_tensors="pt", padding=True).input_ids.to(device)
new_ctx, accepted = pva_step(ctx, block_size=3)
print("accepted in this step:", accepted)
print(tok.decode(new_ctx[0], skip_special_tokens=True)[:200])

accepted in this step: 1
In a distant future, the world


In [39]:
import time, torch

def run_pva_loop(prompt, target_new_tokens=60, block_size=3, max_steps=500):
    # 1) padding=True로 만들어 attention_mask 일관성 확보
    base = tok(prompt, return_tensors="pt", padding=True).input_ids.to(device)
    ctx  = base.clone()

    proposed_total = 0
    accepted_total = 0
    steps = 0

    t0 = time.perf_counter()
    while (ctx.shape[1] - base.shape[1]) < target_new_tokens and steps < max_steps:
        steps += 1
        proposed_total += block_size

        try:
            ctx, accepted = pva_step(ctx, block_size=block_size)
        except RuntimeError as e:
            # 드문 OOM/디바이스 에러 대응: 캐시 비우고 block_size 줄여 재시도
            if "out of memory" in str(e).lower():
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                ctx, accepted = pva_step(ctx, block_size=max(1, block_size // 2))
            else:
                raise e

        accepted_total += accepted

        # EOS 조기 종료
        if ctx[0, -1].item() == tok.eos_token_id:
            break

        # 2) 간단 반복 가드: 마지막 두 줄이 동일하면 중단
        text_now = tok.decode(ctx[0], skip_special_tokens=True)
        lines = [s.strip() for s in text_now.splitlines() if s.strip()]
        if len(lines) >= 2 and lines[-1] == lines[-2]:
            break

    dt = time.perf_counter() - t0
    new_tokens = ctx.shape[1] - base.shape[1]
    text = tok.decode(ctx[0], skip_special_tokens=True)
    metrics = {
        "new_tokens": new_tokens,
        "time_sec": round(dt, 3),
        "tokens_per_sec": round(new_tokens / max(dt, 1e-9), 2),
        "proposed": proposed_total,
        "accepted": accepted_total,
        "acceptance_rate_%": round(100 * accepted_total / max(proposed_total, 1), 1),
        "steps": steps,
        "block_size": block_size,
    }
    return text, metrics


In [40]:
gen_text, m = run_pva_loop("In a distant future,", target_new_tokens=60, block_size=3)
print(gen_text[:400])
print(m)

In a distant future, the world would be a better place.

The world would be a better place.

The world would be a better place.
{'new_tokens': 28, 'time_sec': 4.181, 'tokens_per_sec': 6.7, 'proposed': 81, 'accepted': 1, 'acceptance_rate_%': 1.2, 'steps': 27, 'block_size': 3}
