In [32]:
from typing import Tuple

import numpy as np
import pandas as pd

import torch
from torch.nn.functional import softmax
from torch import FloatTensor, LongTensor

from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

import logging
import sys
import warnings

# Keep the demo output readable: ignore every Python warning ...
warnings.filterwarnings(action='ignore')
# ... and switch off logging at every severity (sys.maxsize exceeds any level).
logging.disable(level=sys.maxsize)

In [33]:
# The sampler indexes target-model probabilities with draft-model token ids, so the
# two checkpoints must share a vocabulary (both are GPT-2 checkpoints).
# `return_dict_in_generate=True` is stored on each model's config so that
# `generate()` returns an output object exposing `.sequences` and `.scores`.

# Draft model: the smaller, cheaper model that proposes candidate tokens.
draft_tokenizer = AutoTokenizer.from_pretrained("gpt2")
draft_model = AutoModelForCausalLM.from_pretrained("gpt2", return_dict_in_generate=True)

# Target model: the larger model whose distribution the output should follow.
target_tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
target_model = AutoModelForCausalLM.from_pretrained("gpt2-medium", return_dict_in_generate=True)

In [34]:
def forward_pass(prompt: str, model: "AutoModelForCausalLM", tokenizer: "AutoTokenizer") -> FloatTensor:
    """Run ``model`` once over ``prompt`` and return next-token probabilities at every position.

    Args:
        prompt: text to encode with ``tokenizer``.
        model: causal LM whose output exposes ``.logits`` of shape (batch, seq_len, vocab).
        tokenizer: tokenizer matching ``model``. As a side effect its ``pad_token_id``
            is set to ``eos_token_id``.

    Returns:
        Tensor of shape (seq_len, vocab); row ``i`` is the model's distribution over
        the token that follows position ``i``. The tensor does not track gradients.
    """
    tokenizer.pad_token_id = tokenizer.eos_token_id
    inputs = tokenizer(prompt, return_tensors="pt")
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    # Select the single batch element explicitly rather than squeeze(): squeeze would
    # also drop the sequence axis for a one-token prompt and return a 1-D tensor.
    # Normalise over the vocabulary axis explicitly (implicit `dim` is deprecated).
    return softmax(outputs.logits[0], dim=-1)

In [35]:
## reference: https://discuss.huggingface.co/t/generation-probabilities-how-to-compute-probabilities-of-output-scores-for-gpt2/3175
def generate_next_token(prompt: str, model: "AutoModelForCausalLM", tokenizer: "AutoTokenizer") -> Tuple[LongTensor, FloatTensor]:
    """Sample one token after ``prompt`` and return it with the distribution it came from.

    Args:
        prompt: text to continue.
        model: causal LM supporting ``generate``.
        tokenizer: tokenizer matching ``model``. As a side effect its ``pad_token_id``
            is set to ``eos_token_id``.

    Returns:
        ``(generated_outputs, probs)``. Despite the ``LongTensor`` annotation,
        ``generated_outputs`` is the object returned by ``model.generate``;
        ``.sequences`` holds the prompt ids followed by the sampled token. ``probs`` is
        the softmax of the processed scores ``generate`` reports for the sampled step,
        squeezed to shape (vocab,).
    """
    tokenizer.pad_token_id = tokenizer.eos_token_id
    inputs = tokenizer(prompt, return_tensors="pt")
    generated_outputs = model.generate(
        **inputs,
        do_sample=True,
        num_return_sequences=1,
        output_scores=True,
        max_new_tokens=1,
        # `.scores` / `.sequences` below require the dict-style output; request it here
        # instead of relying on how the model was loaded.
        return_dict_in_generate=True,
        # generate() takes pad_token_id from its own arguments/config, not from the
        # tokenizer attribute set above, so pass it through explicitly.
        pad_token_id=tokenizer.eos_token_id,
    )
    # `scores` holds one (batch, vocab) tensor per generated step -> (batch, steps, vocab).
    probs = torch.stack(generated_outputs.scores, dim=1).softmax(-1)
    return generated_outputs, probs.squeeze()

In [36]:
# Text the models continue; passed to speculative_sampling as its `prompt` argument.
prompt: str = "Today is a nice day"

In [39]:
def max_fn(x):
    """Normalise the positive part of ``x`` into a probability distribution.

    This is the ``(f)_+`` operator from the speculative sampling paper: negative
    entries are clipped to zero and the result is divided by its sum. If no entry is
    positive the sum is zero and the result is NaN, so callers must guard that case.
    """
    x_max = np.where(x > 0, x, 0)
    return x_max / np.sum(x_max)


def speculative_sampling(prompt, draft_model, draft_tokenizer, target_model, target_tokenizer, N, K):
    """Generate ``N`` new tokens after ``prompt`` using speculative sampling.

    Follows the speculative sampling algorithm of Chen et al. (2023):
    - Each round, the draft model proposes ``K`` tokens autoregressively.
    - The target model scores all of them in a single forward pass.
    - Drafts are accepted left to right with probability ``min(1, q(x) / p(x))``.
    - The first rejected draft is replaced by a sample from ``max_fn(q - p)``.
    - If every draft is accepted, one extra token is sampled from the target's
      distribution after the last draft.

    The loop works purely on token ids. Decoding to text and re-encoding between steps
    is avoided because a BPE round trip is not guaranteed to reproduce the same ids,
    which would misalign draft tokens and the probability rows used to judge them.
    All random draws use numpy's global RNG.

    Args:
        prompt: text to continue; must encode to at least one token.
        draft_model, target_model: causal LMs called as ``model(input_ids=...)`` that
            return ``.logits`` of shape (1, seq_len, vocab). They must share a vocabulary.
        draft_tokenizer, target_tokenizer: tokenizers for the two models. The prompt is
            encoded with ``target_tokenizer`` and checked against ``draft_tokenizer``.
        N: number of new tokens to produce.
        K: number of draft tokens proposed per target forward pass.

    Returns:
        ``(text, num_llm_passes)``: the decoded prompt plus generated tokens, and the
        number of target-model forward passes that were run.

    Raises:
        ValueError: if the prompt encodes to no tokens, or the two tokenizers encode it
            differently.
    """
    def _next_token_probs(model, token_ids):
        # Row i is the model's distribution over the token following position i.
        with torch.no_grad():
            logits = model(input_ids=torch.tensor([token_ids])).logits[0]
        # float64 keeps the ratio/residual arithmetic and np.random.choice's
        # sum-to-one check stable over a large vocabulary.
        return softmax(logits, dim=-1).double().cpu().numpy()

    def _sample(dist):
        # Draw one token id from a (possibly slightly unnormalised) distribution.
        dist = np.asarray(dist, dtype=np.float64)
        dist = dist / dist.sum()
        return int(np.random.choice(len(dist), p=dist))

    generation = list(target_tokenizer(prompt).input_ids)
    if not generation:
        raise ValueError("prompt must encode to at least one token")
    if list(draft_tokenizer(prompt).input_ids) != generation:
        raise ValueError("draft and target tokenizers must encode the prompt identically")

    n = len(generation)
    T = n + N
    num_llm_passes = 0
    while n < T:
        # Step 1: auto-regressively draft K tokens, remembering the draft
        # distribution p that each one was sampled from.
        draft_ids = list(generation)
        draft_tokens, draft_dists = [], []
        for _ in range(K):
            p_next = _next_token_probs(draft_model, draft_ids)[-1]
            token = _sample(p_next)
            draft_tokens.append(token)
            draft_dists.append(p_next)
            draft_ids.append(token)

        # Step 2: one target pass over prompt + drafts. Rows n-1 .. n+len(drafts)-1
        # are the target's distributions for each draft position plus one bonus position.
        q_rows = _next_token_probs(target_model, draft_ids)[n - 1:]
        num_llm_passes += 1

        # Step 3: accept drafts left to right; on the first rejection, sample a
        # replacement from the residual distribution max(q - p, 0) and end the round.
        all_accepted = True
        for t, token in enumerate(draft_tokens):
            if n >= T:
                break
            q_t, p_t = q_rows[t], draft_dists[t]
            if np.random.random() < min(1.0, q_t[token] / p_t[token]):  # accepted
                generation.append(token)
                n += 1
            else:  # rejected
                # In exact arithmetic a rejection implies the residual has positive
                # mass; fall back to q if rounding leaves it empty.
                residual_mass = np.maximum(q_t - p_t, 0).sum()
                replacement = max_fn(q_t - p_t) if residual_mass > 0 else q_t
                generation.append(_sample(replacement))
                n += 1
                all_accepted = False
                break

        # Step 4: every draft was accepted -> sample one extra token from the
        # target's distribution after the last draft.
        if all_accepted and n < T:
            generation.append(_sample(q_rows[len(draft_tokens)]))
            n += 1

        # Sanity check: n always counts the tokens in `generation`.
        assert n == len(generation), f"{n} {len(generation)}"

    return target_tokenizer.decode(generation), num_llm_passes

In [43]:
# Seed both global RNGs the sampling code may draw from (numpy for accept/reject and
# resampling, torch for model.generate(do_sample=True)) so reruns on the same setup
# reproduce the same generation.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

N = 15  # number of new tokens to generate after the prompt
K = 15  # number of draft tokens proposed per target-model forward pass
generation, num_llm_passes = speculative_sampling(prompt, draft_model, draft_tokenizer, target_model, target_tokenizer, N=N, K=K)

print(f"""
Prompt:

{prompt}

Generation following the prompt of new {N} tokens with {num_llm_passes} LLM forward passes: 

`{generation}`
""")


Prompt:

Today is a nice day

Generation following the prompt of new 15 tokens with 2 LLM forward passes: 

`Today is a nice day, but gives hope of better days to those of us who have been there`



**References:**

1. Accelerating Large Language Model Decoding with Speculative Sampling, https://arxiv.org/abs/2302.01318
2. Jay Mody, "Speculative Sampling" (blog post), https://jaykmody.com/blog/speculative-sampling/
