In [24]:
# Setup: imports & seed
import random
from collections import Counter, defaultdict
from typing import List, Tuple, Dict

random.seed(42)  # reproducibility: seeds the global `random` RNG that random.choices uses when sampling below

In [25]:
# Define a tiny toy corpus (one sentence; "the" occurs twice, so the bigram
# table gets one context with two possible continuations)
corpus = "the wolf ran into the forest"

In [26]:
# Tokenization (very simple: lowercase + split on whitespace)
def tokenize(text: str) -> List[str]:
    """Lowercase *text* and split it on runs of whitespace.

    Leading/trailing whitespace is ignored, so blank input yields an empty list.
    """
    lowered = text.lower()
    words = lowered.split()
    return words

tokens = tokenize(corpus)
tokens  # bare expression: the notebook displays the resulting token list as the cell output

['the', 'wolf', 'ran', 'into', 'the', 'forest']

In [27]:
# Utility: guarantee a safe (at least two-token) context

In [28]:
# Utility: ensure we have >= 2 tokens for trigram-based backoff to work
def ensure_two_token_context(seq: List[str]) -> List[str]:
    """Return a context with at least two tokens so callers can read [-2] and [-1].

    A single-token sequence is padded by repeating its only token. Sequences
    with two or more tokens are returned unchanged in content, always as a new
    list so the caller can append to the result without touching *seq*.

    Args:
        seq: the context tokens.

    Returns:
        A new list with ``len >= 2``.

    Raises:
        ValueError: if *seq* is empty, since there is no token to pad with.
    """
    if not seq:
        # Fail with a clear message instead of an opaque IndexError from seq[-1].
        raise ValueError("ensure_two_token_context needs at least one token")
    if len(seq) < 2:
        return [seq[-1], seq[-1]]
    return list(seq)

In [29]:
#n-gram build

In [30]:
# Build unigram table
def build_unigram(tokens: List[str]) -> Counter:
    """Return a Counter mapping each token to its number of occurrences."""
    counts = Counter()
    counts.update(tokens)
    return counts

# Initialize empty bigram/trigram structures
def init_bigram() -> Dict[str, Counter]:
    """Create an empty bigram table: previous token -> Counter of next tokens.

    Rows for unseen previous tokens are created (empty) on first access.
    """
    table: Dict[str, Counter] = defaultdict(Counter)
    return table

def init_trigram() -> Dict[Tuple[str, str], Counter]:
    """Create an empty trigram table: (w1, w2) context -> Counter of next tokens.

    Rows for unseen contexts are created (empty) on first access.
    """
    table: Dict[Tuple[str, str], Counter] = defaultdict(Counter)
    return table

# Fill bigram counts
def fill_bigram(bi: Dict[str, Counter], tokens: List[str]) -> None:
    """Add one count to bi[prev][next] for every adjacent (prev, next) pair.

    Mutates *bi* in place. *bi* must create missing rows on access, e.g. the
    defaultdict(Counter) returned by init_bigram(). Fewer than two tokens adds
    nothing.
    """
    stream = iter(tokens)
    prev_tok = next(stream, None)
    for cur_tok in stream:
        bi[prev_tok][cur_tok] += 1
        prev_tok = cur_tok

# Fill trigram counts
def fill_trigram(tri: Dict[Tuple[str, str], Counter], tokens: List[str]) -> None:
    """Add one count to tri[(w1, w2)][w3] for every window of three consecutive tokens.

    Mutates *tri* in place. *tri* must create missing rows on access (see
    init_trigram). Fewer than three tokens adds nothing.
    """
    for start in range(len(tokens) - 2):
        w1, w2, w3 = tokens[start:start + 3]
        tri[(w1, w2)][w3] += 1

In [31]:
# n-gram creation and execution

In [32]:
# Populate the module-level n-gram tables that the backoff lookups read.
bi = init_bigram()
fill_bigram(bi, tokens)

tri = init_trigram()
fill_trigram(tri, tokens)

uni = build_unigram(tokens)

In [33]:
# n-gram confirm

In [34]:
def print_unigrams(uni: Counter) -> None:
    """Print a header, then one `'token': count` line per unigram in insertion order."""
    print("=== Unigrams ===")
    rows = [f"{word!r}: {count}" for word, count in uni.items()]
    for row in rows:
        print(row)

def print_bigrams(bi: Dict[str, Counter]) -> None:
    """Print every bigram as `(prev -> next): count`, grouped by previous token."""
    print("\n=== Bigrams ===")
    pairs = (
        (prev_tok, next_tok, n)
        for prev_tok, followers in bi.items()
        for next_tok, n in followers.items()
    )
    for prev_tok, next_tok, n in pairs:
        print(f"({prev_tok!r} -> {next_tok!r}): {n}")

def print_trigrams(tri: Dict[Tuple[str, str], Counter]) -> None:
    """Print every trigram as `((w1, w2) -> next): count`, grouped by two-token context."""
    print("\n=== Trigrams ===")
    for context, followers in tri.items():
        w1, w2 = context
        lines = [f"(({w1!r}, {w2!r}) -> {nxt!r}): {c}" for nxt, c in followers.items()]
        for line in lines:
            print(line)

# Dump all three tables to stdout: unigrams first, then bigrams, then trigrams.
print_unigrams(uni)
print_bigrams(bi)
print_trigrams(tri)

=== Unigrams ===
'the': 2
'wolf': 1
'ran': 1
'into': 1
'forest': 1

=== Bigrams ===
('the' -> 'wolf'): 1
('the' -> 'forest'): 1
('wolf' -> 'ran'): 1
('ran' -> 'into'): 1
('into' -> 'the'): 1

=== Trigrams ===
(('the', 'wolf') -> 'ran'): 1
(('wolf', 'ran') -> 'into'): 1
(('ran', 'into') -> 'the'): 1
(('into', 'the') -> 'forest'): 1


In [35]:
#Backoff: try trigram -> else bigram -> else unigram

In [36]:
def get_counts(prev2: str, prev1: str) -> Counter:
    """Return next-token counts for the context (prev2, prev1), with backoff.

    The lookup order is:

    1. The trigram row for (prev2, prev1).
    2. The bigram row for prev1.
    3. The whole unigram table.

    A row is used only when it is non-empty. `.get` is used instead of indexing
    so that looking up an unseen context does not insert an empty row into the
    module-level defaultdict tables.

    The returned Counter is the table's own object, not a copy.
    """
    candidates = ((tri, (prev2, prev1)), (bi, prev1))
    for table, key in candidates:
        row = table.get(key)
        if row:
            return row
    return uni

In [37]:
# Sampling & most-frequent (argmax) selection, kept as separate functions

In [38]:
# Probabilistic sampling with temperature
def sample_from_counts(dist: Counter, T: float = 1.0) -> str:
    """Sample one token with probability proportional to ``count ** (1 / T)``.

    Args:
        dist: token -> count table. Non-positive counts are given a tiny base
            weight (1e-9), so they are almost never chosen.
        T: temperature.
            - T > 1 flattens the distribution.
            - 0 < T < 1 sharpens it.
            - T == 0 is greedy: it returns the highest-count token, with ties
              going to the first in iteration order, instead of dividing by zero.

    Returns:
        The sampled token. The global ``random`` RNG is used only when T > 0.

    Raises:
        ValueError: if *dist* is empty or *T* is negative.
    """
    if not dist:
        # random.choices would otherwise fail with an opaque IndexError.
        raise ValueError("cannot sample from an empty distribution")
    if T < 0:
        # A negative exponent would silently invert the distribution.
        raise ValueError(f"temperature must be >= 0, got {T}")
    if T == 0:
        # T -> 0 limit of count ** (1 / T): all mass goes to the largest count.
        return max(dist.items(), key=lambda kv: kv[1])[0]
    toks = list(dist)
    exponent = 1.0 / T
    # NOTE(review): a very small T with large counts can overflow the float
    # power (OverflowError); this is acceptable for the toy corpus here.
    weights = [(c if c > 0 else 1e-9) ** exponent for c in dist.values()]
    return random.choices(toks, weights=weights, k=1)[0]

# Deterministic (argmax)
def argmax_from_counts(dist: Counter) -> str:
    """Return the token with the highest count.

    Ties go to the token that comes first in *dist*'s iteration order. An empty
    *dist* raises ValueError (from max()).
    """
    return max(dist, key=dist.__getitem__)

In [39]:
# 10 — Baseline (for comparison): generate one token at a time

In [40]:
def generate_baseline(prompt_tokens: List[str], steps: int = 5, T: float = 0.7) -> List[str]:
    """
    Baseline: use backoff n-gram to sample next token step by step.

    Each step looks up get_counts(prev2, prev1), which backs off from trigram
    to bigram to unigram, and samples one token from it at temperature T.

    Args:
        prompt_tokens: starting tokens.
        steps: number of tokens to append.
        T: sampling temperature passed to sample_from_counts.

    Returns:
        A new list: the prompt tokens followed by `steps` sampled tokens.
    """
    out = list(prompt_tokens)
    # Pad only the lookup context. The padding copy of a one-token prompt must
    # not leak into the returned sequence (speculative_step likewise returns
    # prompt_tokens + accepted, without padding).
    context = ensure_two_token_context(out)
    prev2, prev1 = context[-2], context[-1]

    for _ in range(steps):
        dist = get_counts(prev2, prev1)
        nxt = sample_from_counts(dist, T)
        out.append(nxt)
        prev2, prev1 = prev1, nxt
    return out


In [41]:
# Cell 11 — #4 Speculative decoding (split into very small functions)

In [42]:
# ---------- Drafter (small model) ultra-split ----------

def get_drafter_dist(prev1: str) -> Counter:
    """Return the small-model distribution for the token after *prev1*.

    Uses the bigram row for *prev1* when it has at least one entry, otherwise
    the unigram table. The emptiness check matches get_counts. `bi` is a
    defaultdict, so any stray `bi[x]` read leaves an empty Counter behind, and
    that row must not be handed to the sampler.
    """
    return bi.get(prev1) or uni

def sample_drafter_token_from_dist(dist: Counter, T_draft: float) -> str:
    """Draw one drafter token from *dist* at temperature *T_draft*."""
    token = sample_from_counts(dist, T=T_draft)
    return token

def draft_one(prev1: str, T_draft: float) -> str:
    """Draft a single token: look up the small-model distribution after *prev1*, then sample it."""
    return sample_drafter_token_from_dist(get_drafter_dist(prev1), T_draft)

def advance_context(prev2: str, prev1: str, new_token: str) -> tuple:
    """Slide the two-token window by one: drop *prev2*, append *new_token*.

    Returns the new ``(prev2, prev1)`` pair, i.e. ``(prev1, new_token)``.
    """
    window = (prev2, prev1, new_token)
    return window[1:]

def make_draft_step(prev2: str, prev1: str, T_draft: float) -> tuple:
    """Draft one token and slide the window.

    Returns ``(token, new_prev2, new_prev1)``.
    """
    token = draft_one(prev1, T_draft)
    return (token, *advance_context(prev2, prev1, token))

def build_draft(context: list, k: int, T_draft: float) -> tuple:
    """
    Draft k tokens with the small model, starting from the last two context tokens.

    Args:
        context: token list with at least two entries (context[-2] is read).
        k: number of draft tokens to produce.
        T_draft: drafter sampling temperature.

    Returns:
      draft_tokens: [t1, ...]
      draft_trace : [(prev2, prev1, t), ...], the window each token was
                    drafted from, captured before the window advances
    """
    prev2, prev1 = context[-2], context[-1]
    draft, trace = [], []
    for _ in range(k):
        # Snapshot the window before make_draft_step advances it, so every
        # trace entry shows the exact context its token was drafted from.
        window = (prev2, prev1)
        t, prev2, prev1 = make_draft_step(prev2, prev1, T_draft)
        trace.append((*window, t))
        draft.append(t)
    return draft, trace

In [43]:
# ---------- Verifier (large model) ultra-split ----------

def get_verifier_dist(prev2: str, prev1: str) -> Counter:
    """Return the large-model distribution for the context (prev2, prev1).

    Delegates to the trigram -> bigram -> unigram backoff in get_counts.
    """
    backoff_dist = get_counts(prev2=prev2, prev1=prev1)
    return backoff_dist

def predict_verifier_token(prev2: str, prev1: str) -> str:
    """Return the large model's greedy (highest-count) next token for (prev2, prev1)."""
    return argmax_from_counts(get_verifier_dist(prev2, prev1))

def decide_accept(draft_token: str, verify_token: str) -> bool:
    """Accept the drafted token only when it equals the verifier's choice."""
    is_match = draft_token == verify_token
    return is_match

def make_verify_log_entry(prev2: str, prev1: str, draft_t: str, verify_t: str, ok: bool) -> tuple:
    """Bundle one verification step as ``(prev2, prev1, draft_t, verify_t, ok)``."""
    fields = [prev2, prev1, draft_t, verify_t, ok]
    return tuple(fields)

def apply_prefix_accept(accepted: list, prev2: str, prev1: str, chosen_token: str, ok: bool) -> tuple:
    """
    Append *chosen_token* to *accepted* in place and return the updated state.

    Returns ``(accepted, prev2, prev1)``, where *accepted* is the same list
    object. The window advances only when *ok* is true. On a mismatch it stays
    put, because the caller stops after this step.
    """
    accepted.append(chosen_token)
    if not ok:
        return accepted, prev2, prev1
    new_prev2, new_prev1 = advance_context(prev2, prev1, chosen_token)
    return accepted, new_prev2, new_prev1

def verify_one_step(prev2: str, prev1: str, draft_t: str) -> tuple:
    """
    Check one drafted token against the verifier's greedy prediction.

    Returns ``(chosen_token, ok_flag, verify_token)``. The chosen token is the
    draft when it matches, otherwise the verifier's token.
    """
    expected = predict_verifier_token(prev2, prev1)
    matched = decide_accept(draft_t, expected)
    chosen = expected
    if matched:
        chosen = draft_t
    return chosen, matched, expected

def prefix_accept_verify(context: list, draft: list) -> tuple:
    """
    Verify *draft* left to right under the prefix-accept rule.

    Matching tokens are accepted. The first mismatch is replaced by the
    verifier's token, and verification stops there.

    Returns:
      accepted_tokens
      verify_log [(p2, p1, draft_t, verify_t, ok), ...]
    """
    kept: list = []
    log: list = []
    p2, p1 = context[-2], context[-1]

    for draft_t in draft:
        chosen, ok, verify_t = verify_one_step(p2, p1, draft_t)
        log.append(make_verify_log_entry(p2, p1, draft_t, verify_t, ok))
        kept, p2, p1 = apply_prefix_accept(kept, p2, p1, chosen, ok)
        if not ok:
            break
    return kept, log

In [44]:
# ---------- Pretty-print (trace) ultra-split ----------

def print_draft_trace(draft: list, draft_trace: list) -> None:
    """Print each drafted token with its (prev2, prev1) context, numbered from D1.

    When the trace is empty but a draft exists, the raw draft list is printed instead.
    """
    print("=== Draft stage (small model) ===")
    for step, entry in enumerate(draft_trace, start=1):
        p2, p1, t = entry
        print(f"[D{step}] prev2='{p2}' prev1='{p1}' -> draft='{t}'")
    if draft and not draft_trace:
        print("(no trace recorded, draft =", draft, ")")

def print_verify_log(verify_log: list) -> None:
    """Print each verification step, numbered from V1, with its ACCEPT or REPLACE+STOP outcome."""
    print("\n=== Verify stage (large model, prefix-accept) ===")
    for idx, entry in enumerate(verify_log, start=1):
        p2, p1, t, v, ok = entry
        if ok:
            status = "ACCEPT"
        else:
            status = "REPLACE+STOP"
        print(f"[V{idx}] prev2='{p2}' prev1='{p1}'  draft='{t}'  verify='{v}'  ->  {status}")

def print_final(prompt_tokens: list, draft: list, accepted: list) -> None:
    """Print the draft, the accepted tokens, and the prompt+accepted text joined by spaces."""
    labeled = (("\nDraft   :", draft), ("Accepted:", accepted))
    for label, value in labeled:
        print(label, value)
    print("Final   :", " ".join(prompt_tokens + accepted))



In [45]:
# ---------- Orchestrator (speculative step) ----------

def speculative_step(prompt_tokens: list, k: int = 5, T_draft: float = 0.9, trace: bool = True):
    """
    Run one speculative-decoding round: draft k tokens, then verify them.

      - Drafter: bigram (else unigram), sampled at temperature T_draft
      - Verifier: trigram->bigram->unigram backoff + argmax
      - Prefix-accept rule: the first mismatch is replaced by the verifier's
        token, and verification stops

    When *trace* is true, the draft trace, the verify log and the final
    sequence are printed.

    Returns: (draft, accepted, final_sequence), where final_sequence is
    prompt_tokens + accepted (the two-token context padding is not included).
    """
    ctx = ensure_two_token_context(list(prompt_tokens))

    proposed, proposal_trace = build_draft(ctx, k=k, T_draft=T_draft)
    kept, checks = prefix_accept_verify(ctx, proposed)

    if trace:
        print_draft_trace(proposed, proposal_trace)
        print_verify_log(checks)
        print_final(prompt_tokens, proposed, kept)

    final_sequence = prompt_tokens + kept
    return proposed, kept, final_sequence

In [1]:
# Demo: compare plain step-by-step sampling with one speculative draft/verify
# round on the same prompt. Both paths draw from the shared global `random` RNG,
# so the speculative draft depends on how many draws the baseline made first.
prompt = ["the", "wolf", "ran"]

print("---- Baseline (T=0.7), next 5 tokens ----")
print("Baseline:", " ".join(generate_baseline(prompt, steps=5, T=0.7)))

print("\n---- Speculative (k=5, T_draft=0.9) ----")
# trace=True prints the draft trace, verify log and final sequence; the returned tuple is discarded.
_ = speculative_step(prompt, k=5, T_draft=0.9, trace=True)
