In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run everything on CPU. Seed torch's global RNG so the sampled drafting cells
# below (torch.multinomial in the guided drafter) reproduce on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

device = "cpu"
print("device:", device)

device: cpu


In [3]:
#토크나이저와 모델 로드

In [4]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Tokenizer plus a reference causal LM for the distilgpt2 checkpoint; the model is placed on `device`.
checkpoint = "distilgpt2"
tok = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)
model = model.to(device)

In [5]:
# Prepare the prompt

In [6]:
# Seed text that the drafter/verifier pair will continue.
prompt = "In a distant future,"
print("prompt:", prompt, sep=" ")

prompt: In a distant future,


In [7]:
#PAD/EOS 설정 점검

In [8]:
# distilgpt2 has no pad token by default, so reuse EOS as the pad token.
needs_pad = tok.pad_token is None
if needs_pad:
    tok.pad_token = tok.eos_token

for label, value in (("pad_token:", tok.pad_token), ("eos_token:", tok.eos_token)):
    print(label, value)

pad_token: <|endoftext|>
eos_token: <|endoftext|>


In [9]:
#입력 컨텍스트 토큰화

In [10]:
# Tokenize the prompt into an id tensor on `device` (one row for this single prompt).
encoded = tok(prompt, return_tensors="pt")
ctx = encoded.input_ids.to(device)
first_row = ctx[0]
print("ctx shape:", ctx.shape)
print("ctx tokens:", first_row.tolist())
print("ctx decoded:", tok.decode(first_row))

ctx shape: torch.Size([1, 5])
ctx tokens: [818, 257, 12899, 2003, 11]
ctx decoded: In a distant future,


In [11]:
#draft_block (점검용: drafter도 그리디로만 1블록 생성)

In [12]:
@torch.no_grad()
def draft_block_greedy(drafter, input_ids, block_size=3, pad_token_id=None, attention_mask=None):
    """Greedily propose up to ``block_size`` continuation tokens with ``drafter``.

    Sanity-check drafter: sampling is disabled, so the proposal is deterministic.

    Args:
        drafter: causal LM exposing a Hugging Face-style ``generate`` method.
        input_ids: context token ids for a single sequence (batch size 1 assumed,
            no padding).
        block_size: maximum number of new tokens to generate.
        pad_token_id: forwarded to ``generate`` (this notebook passes
            ``tok.eos_token_id``).
        attention_mask: optional mask for ``input_ids``; defaults to all ones.

    Returns:
        Only the newly generated token ids, on the same device as ``input_ids``.
        The width is at most ``block_size``; it can be shorter if ``generate``
        stops early.
    """
    caller_device = input_ids.device
    # Run where the drafter lives (HF models expose ``.device``); fall back to CPU,
    # which was the previous hard-coded choice.
    run_device = getattr(drafter, "device", "cpu")
    ids = input_ids.to(run_device)
    if attention_mask is None:
        # A single unpadded sequence attends to every position. Deriving the mask
        # from ``ids != pad_token_id`` would also zero out genuine EOS tokens in the
        # context, because pad == eos in this notebook's distilgpt2 setup.
        attention_mask = torch.ones_like(ids)
    else:
        attention_mask = attention_mask.to(run_device)

    out = drafter.generate(
        ids,
        attention_mask=attention_mask,
        max_new_tokens=block_size,
        do_sample=False,          # sampling off: deterministic sanity-check drafts
        pad_token_id=pad_token_id,
    )
    # Keep only the tokens produced after the prompt.
    return out[:, ids.shape[1]:].to(caller_device)


In [13]:
#drafter 모델 로드

In [14]:
from transformers import AutoModelForCausalLM

# Drafter: the small proposal model (distilgpt2), placed on CPU.
drafter_name = "distilgpt2"
drafter = AutoModelForCausalLM.from_pretrained(drafter_name)
drafter = drafter.to("cpu")

print("drafter 준비 완료 ✅")

drafter 준비 완료 ✅


In [15]:
#첫 블록 제안 테스트

In [16]:
# `ctx` holds the prompt token ids built in an earlier cell; propose one greedy block from it.
block = draft_block_greedy(drafter, ctx, block_size=3, pad_token_id=tok.eos_token_id)

draft_row = block[0]
print("draft ids:", draft_row.tolist())
print("draft decoded:", tok.decode(draft_row, skip_special_tokens=True))

draft ids: [262, 995, 318]
draft decoded:  the world is


In [17]:
#verifier 로드 (점검 모드: drafter와 같은 모델)

In [18]:
from transformers import AutoModelForCausalLM

# Verifier for sanity mode: deliberately the same checkpoint as the drafter.
verifier_name = "distilgpt2"
verifier = AutoModelForCausalLM.from_pretrained(verifier_name).to("cpu")
print("verifier 준비 완료")


verifier 준비 완료


In [19]:
#Verifier 그리디 시퀀스 생성

In [20]:
# Reference: the verifier's own greedy continuation, one token at a time,
# for as many positions as the draft block has.
n_steps = block.shape[1]
gctx = ctx.clone()
greedy_seq = []

with torch.no_grad():
    step = 0
    while step < n_steps:
        next_logits = verifier(input_ids=gctx).logits[:, -1, :]
        g = next_logits.argmax(dim=-1, keepdim=True)
        greedy_seq.append(int(g.item()))
        gctx = torch.cat([gctx, g], dim=1)
        step += 1

print("greedy ids:", greedy_seq)
print("greedy decoded:", tok.decode(greedy_seq, skip_special_tokens=True))


greedy ids: [262, 995, 318]
greedy decoded:  the world is


In [21]:
#드래프트 vs 그리디를 앞에서부터 비교하며 수락

In [22]:
accepted_ids = []
cur = ctx.clone()
draft_ids = block[0].tolist()

for i, d_id in enumerate(draft_ids):
    g_id = greedy_seq[i]
    print(f"[{i}] draft={tok.decode([d_id])!r} vs greedy={tok.decode([g_id])!r} ->", end=" ")

    if d_id != g_id:
        # Mismatch: substitute the verifier's greedy token and stop scanning.
        g = torch.tensor([[g_id]], device=cur.device)
        cur = torch.cat([cur, g], dim=1)
        print("MISMATCH -> take greedy and STOP")
        break

    # Match: keep the draft token.
    accepted_ids.append(d_id)
    cur = torch.cat([cur, block[:, i:i+1]], dim=1)
    print("ACCEPT")

print("accepted count:", len(accepted_ids), "/", block.shape[1])


[0] draft=' the' vs greedy=' the' -> ACCEPT
[1] draft=' world' vs greedy=' world' -> ACCEPT
[2] draft=' is' vs greedy=' is' -> ACCEPT
accepted count: 3 / 3


In [23]:
#한 스텝 후 텍스트 확인

In [24]:
decoded_text = tok.decode(cur[0], skip_special_tokens=True)
print("new text:\n", decoded_text)

new text:
 In a distant future, the world is


In [25]:
#한 스텝 함수(pva_step_once) 정의

In [26]:
@torch.no_grad()
def pva_step_once(ctx, block_size=3):
    """One propose-verify-accept (PVA) step in sanity mode (greedy drafter).

    The drafter greedily proposes up to ``block_size`` tokens. The verifier then
    scores ``ctx + draft`` in a single forward pass. Under causal attention, the
    logits at position ``len(ctx) - 1 + i`` are its next-token prediction given
    ``ctx + draft[:i]``. That is exactly the context in which draft token ``i`` is
    judged, because it is only reached when all earlier draft tokens matched.
    Draft tokens are accepted left to right while they equal the verifier's
    argmax. At the first mismatch, the verifier's token is appended instead and
    the step ends.

    This replaces running the verifier once per draft token; results can differ
    only through floating-point noise between the two computations.

    Args:
        ctx: context token ids for a single sequence (batch size 1, unpadded).
        block_size: maximum number of draft tokens to propose.

    Returns:
        (new_ctx, accepted_count)
    """
    # 1) draft (sanity mode: greedy, no sampling)
    block = draft_block_greedy(drafter, ctx, block_size=block_size, pad_token_id=tok.eos_token_id)
    block = block.to(ctx.device)
    n_draft = block.shape[1]
    if n_draft == 0:
        return ctx.clone(), 0

    # 2) one verifier pass over prompt + draft. Attend everywhere: a mask built
    #    from the pad id would also hide genuine EOS tokens, since pad == eos here.
    seq = torch.cat([ctx, block], dim=1)
    logits = verifier(input_ids=seq, attention_mask=torch.ones_like(seq)).logits
    start = ctx.shape[1] - 1  # position whose logits predict the first draft token
    greedy = logits[0, start:start + n_draft, :].argmax(dim=-1).to(ctx.device)

    # 3) accept the longest matching prefix
    accepted = 0
    while accepted < n_draft and int(block[0, accepted]) == int(greedy[accepted]):
        accepted += 1

    new_ctx = torch.cat([ctx, block[:, :accepted]], dim=1)
    if accepted < n_draft:
        # first mismatch: take the verifier's token and stop
        new_ctx = torch.cat([new_ctx, greedy[accepted].view(1, 1)], dim=1)
    return new_ctx, accepted


In [27]:
#한 스텝 더 실행해보기

In [28]:
cur2, acc2 = pva_step_once(cur, block_size=3)
text2 = tok.decode(cur2[0], skip_special_tokens=True)
print("accepted this step:", acc2)
print(text2)


accepted this step: 3
In a distant future, the world is in a state


In [29]:
#짧은 루프 (5스텝만) 실행

In [30]:
ctx_loop = ctx.clone()
for step_no in range(1, 6):
    ctx_loop, acc = pva_step_once(ctx_loop, block_size=3)
    text = tok.decode(ctx_loop[0], skip_special_tokens=True)
    print(f"[step {step_no}] accepted={acc}, text='{text}'")


[step 1] accepted=3, text='In a distant future, the world is'
[step 2] accepted=3, text='In a distant future, the world is in a state'
[step 3] accepted=3, text='In a distant future, the world is in a state of flux.'
[step 4] accepted=3, text='In a distant future, the world is in a state of flux. The world is'
[step 5] accepted=3, text='In a distant future, the world is in a state of flux. The world is in a state'


In [31]:
#샘플링 버전 draft 함수 추가

In [32]:
@torch.no_grad()
def draft_block_sampled(
    drafter, input_ids, block_size=3,
    top_k=20, top_p=0.9, temperature=0.7,
    pad_token_id=None, repetition_penalty=1.05, no_repeat_ngram_size=3,
    attention_mask=None,
):
    """Propose up to ``block_size`` tokens by sampling from ``drafter.generate``.

    Args:
        drafter: causal LM exposing a Hugging Face-style ``generate`` method.
        input_ids: context token ids for a single sequence (batch size 1 assumed,
            no padding).
        block_size: maximum number of new tokens.
        top_k, top_p, temperature, repetition_penalty, no_repeat_ngram_size:
            forwarded unchanged to ``generate``.
        pad_token_id: forwarded to ``generate``.
        attention_mask: optional mask for ``input_ids``; defaults to all ones.

    Returns:
        Only the newly generated token ids (width at most ``block_size``), on the
        same device as ``input_ids``.
    """
    caller_device = input_ids.device
    run_device = getattr(drafter, "device", "cpu")
    ids = input_ids.to(run_device)
    if attention_mask is None:
        # Unpadded single sequence: attend everywhere. ``ids != pad_token_id``
        # would also mask genuine EOS tokens, because pad == eos in this notebook.
        attention_mask = torch.ones_like(ids)
    else:
        attention_mask = attention_mask.to(run_device)

    out = drafter.generate(
        ids,
        attention_mask=attention_mask,
        max_new_tokens=block_size,
        do_sample=True,
        top_k=top_k, top_p=top_p, temperature=temperature,
        pad_token_id=pad_token_id,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
    )
    return out[:, ids.shape[1]:].to(caller_device)


In [33]:
# 현실 모드 한 스텝 함수 정의 (pva_step_real)

In [34]:
@torch.no_grad()
def pva_step_real(ctx, block_size=3, guide_topk=20, sample_topk=40, top_p=0.95,
                  temperature=0.7, repetition_penalty=1.05, no_repeat_ngram_size=3):
    """One propose-verify-accept step in "real" mode (verifier-guided sampled drafts).

    The drafter samples up to ``block_size`` tokens with ``draft_block_guided``,
    restricted to the verifier's top-``guide_topk`` candidates. The verifier then
    scores ``ctx + draft`` in one forward pass. With causal attention, the logits
    at position ``len(ctx) - 1 + i`` are its next-token prediction given
    ``ctx + draft[:i]``. Draft tokens are accepted left to right while they equal
    the verifier's argmax. At the first mismatch, the verifier's token is appended
    instead and the step ends.

    The sampling keyword arguments default to the values previously hard-coded
    here and are forwarded to ``draft_block_guided``. Note that guided drafting
    itself queries the verifier once per draft token, so this mode is for studying
    acceptance rates rather than saving verifier calls.

    Args:
        ctx: context token ids for a single sequence (batch size 1, unpadded).
        block_size: number of draft tokens to propose.

    Returns:
        (new_ctx, accepted_count)
    """
    block = draft_block_guided(
        drafter, verifier, ctx, block_size=block_size,
        guide_topk=guide_topk,        # try values between 10 and 30
        sample_topk=sample_topk, top_p=top_p, temperature=temperature,
        pad_token_id=tok.eos_token_id,
        repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size,
    ).to(ctx.device)
    n_draft = block.shape[1]
    if n_draft == 0:
        return ctx.clone(), 0

    # Single verifier pass. Attend everywhere: a mask built from pad_token_id
    # would also hide genuine EOS tokens, since pad == eos for this tokenizer.
    seq = torch.cat([ctx, block], dim=1)
    logits = verifier(input_ids=seq, attention_mask=torch.ones_like(seq)).logits
    start = ctx.shape[1] - 1  # position whose logits predict the first draft token
    greedy = logits[0, start:start + n_draft, :].argmax(dim=-1).to(ctx.device)

    accepted = 0
    while accepted < n_draft and int(block[0, accepted]) == int(greedy[accepted]):
        accepted += 1

    new_ctx = torch.cat([ctx, block[:, :accepted]], dim=1)
    if accepted < n_draft:
        # first mismatch: take the verifier's token and stop
        new_ctx = torch.cat([new_ctx, greedy[accepted].view(1, 1)], dim=1)
    return new_ctx, accepted


In [41]:
# draft_block_guided 정의 (pva_step_real이 호출하므로 pva_step_real을 실행하기 전에 먼저 실행해야 함)

In [42]:
import torch

def _guided_repetition_penalty(logits, context_ids, penalty):
    """Apply a CTRL-style repetition penalty (the HF ``repetition_penalty`` rule) to 1-D ``logits``."""
    if penalty == 1.0 or context_ids.numel() == 0:
        return logits
    seen = torch.unique(context_ids).to(logits.device)
    picked = logits[seen]
    out = logits.clone()
    # positive logits shrink, negative logits grow more negative
    out[seen] = torch.where(picked > 0, picked / penalty, picked * penalty)
    return out


def _guided_banned_ngram_tokens(context_ids, ngram_size):
    """Token ids whose choice would repeat an ``ngram_size``-gram already in ``context_ids``."""
    n = int(ngram_size)
    seq = context_ids.tolist()
    if n <= 0 or len(seq) < n:
        return []
    prefix = seq[len(seq) - n + 1:]  # the last n-1 tokens (empty when n == 1)
    return sorted({seq[i + n - 1] for i in range(len(seq) - n + 1) if seq[i:i + n - 1] == prefix})


def _guided_top_p(probs, top_p):
    """Nucleus filter on descending-sorted ``probs``: keep the smallest prefix reaching ``top_p``."""
    if top_p is None or top_p >= 1.0:
        return probs
    cum = probs.cumsum(dim=-1)
    keep = (cum - probs) < top_p
    keep[0] = True  # always keep the most likely token
    kept = probs * keep
    return kept / kept.sum()


@torch.no_grad()
def draft_block_guided(
    drafter, verifier, input_ids,
    block_size=3,
    guide_topk=20,            # only the verifier's top-K tokens are allowed
    sample_topk=40, top_p=0.95, temperature=0.7,
    pad_token_id=None, repetition_penalty=1.05, no_repeat_ngram_size=3
):
    """Draft ``block_size`` tokens, sampling from the drafter only within the verifier's top-K.

    At each step:
    1. The verifier's top-``guide_topk`` next tokens form the allowed set.
    2. The drafter's next-token logits get ``repetition_penalty`` applied.
    3. Tokens outside the allowed set, and tokens that would repeat an existing
       ``no_repeat_ngram_size``-gram, are masked.
    4. One token is sampled with top-``sample_topk``, ``temperature`` and
       ``top_p`` filtering.

    If nothing survives the masking, the drafter's raw argmax is used instead.
    Restricting the drafter to the verifier's candidates raises the chance that
    drafts are accepted.

    Args:
        drafter, verifier: callables returning an object with ``.logits`` of
            shape [batch, seq, vocab]; batch size 1 is assumed, and both must
            share a vocabulary.
        input_ids: context token ids for a single unpadded sequence.
        pad_token_id: unused; accepted for signature compatibility with the other
            drafting helpers. No padding is involved.

    Returns:
        The ``block_size`` drafted token ids (prompt removed), on ``input_ids``'s device.
    """
    ctx = input_ids.clone()

    for _ in range(block_size):
        # Unpadded single sequence: attend everywhere. A mask derived from the pad
        # id would also hide genuine EOS tokens, since pad == eos in this notebook.
        attn = torch.ones_like(ctx)

        # 1) the verifier's top-K candidate set
        v_logits = verifier(input_ids=ctx, attention_mask=attn).logits[0, -1, :]
        k_guide = max(0, min(guide_topk, v_logits.shape[-1]))
        allowed = torch.topk(v_logits, k=k_guide).indices

        # 2) drafter distribution with HF-style repetition penalty
        d_logits = drafter(input_ids=ctx, attention_mask=attn).logits[0, -1, :]
        shaped = _guided_repetition_penalty(d_logits, ctx[0], repetition_penalty)

        # 3) mask tokens outside the verifier set, plus banned n-gram completions
        mask = torch.full_like(shaped, float("-inf"))
        mask[allowed.to(mask.device)] = 0.0
        banned = _guided_banned_ngram_tokens(ctx[0], no_repeat_ngram_size)
        if banned:
            mask[banned] = float("-inf")
        guided = shaped + mask

        # 4) top-k + temperature + top-p sampling within the guided set
        k = min(int(sample_topk), int(torch.isfinite(guided).sum()))
        if k <= 0:
            # nothing survived the guide/ban: fall back to the drafter's raw greedy choice
            next_id = torch.argmax(d_logits).view(1, 1)
        else:
            vals, idxs = torch.topk(guided, k=k)
            probs = torch.softmax(vals / max(temperature, 1e-6), dim=-1)
            probs = _guided_top_p(probs, top_p)
            pick = torch.multinomial(probs, num_samples=1).item()
            next_id = idxs[pick].view(1, 1)

        ctx = torch.cat([ctx, next_id.to(ctx.device)], dim=1)

    # return only the drafted block (prompt removed)
    return ctx[:, input_ids.shape[1]:]


In [35]:
#현실 모드 한 스텝만 시험

In [38]:
# NOTE(review): the execution counts show this cell (In [38]) last ran before
# draft_block_guided was defined (In [42]); re-run to refresh the saved output.
ctx_real, acc_real = pva_step_real(ctx.clone(), block_size=3)
print("accepted (real step):", acc_real)
print(tok.decode(ctx_real[0], skip_special_tokens=True))

accepted (real step): 1
In a distant future, the world


In [39]:
# 짧은 루프 실행 (현실 모드, 8스텝)

In [40]:
def run_short_real(ctx0, steps=8, block_size=3):
    """Run ``steps`` real-mode PVA steps starting from ``ctx0`` and report acceptance.

    After each step, prints the number of accepted draft tokens and the cumulative
    acceptance rate: accepted draft tokens divided by proposed draft tokens. The
    verifier's replacement token on a mismatch is not counted as accepted. Then
    prints the decoded final text.

    Args:
        ctx0: starting context token ids (not modified; a clone is used).
        steps: number of PVA steps to run.
        block_size: draft tokens proposed per step.

    Returns:
        The final context tensor.
    """
    ctx = ctx0.clone()
    proposed = accepted = 0
    for step_no in range(1, steps + 1):
        ctx, acc = pva_step_real(ctx, block_size=block_size)
        proposed += block_size
        accepted += acc
        # guard: block_size <= 0 would otherwise divide by zero
        rate = round(100 * accepted / proposed, 1) if proposed > 0 else 0.0
        print(f"[{step_no}] accepted={acc}, acc_rate_so_far={rate}%")
    print("text:")
    print(tok.decode(ctx[0], skip_special_tokens=True))
    return ctx

_ = run_short_real(ctx, steps=8, block_size=2)


[1] accepted=0, acc_rate_so_far=0.0%
[2] accepted=0, acc_rate_so_far=0.0%
[3] accepted=0, acc_rate_so_far=0.0%
[4] accepted=0, acc_rate_so_far=0.0%
[5] accepted=0, acc_rate_so_far=0.0%
[6] accepted=0, acc_rate_so_far=0.0%
[7] accepted=2, acc_rate_so_far=14.3%
[8] accepted=1, acc_rate_so_far=18.8%
text:
In a distant future, the world is in a state of flux. The
