### Speculative Decoding from Scratch (`gpt2` drafts, `gpt2-large` verifies)

In [1]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Load one GPT-2 tokenizer for both models (the GPT-2 family shares a single
# BPE vocabulary), the small draft model (gpt2) and the big verifier model
# (gpt2-large).
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

small_model = GPT2LMHeadModel.from_pretrained("gpt2")        # proposes draft tokens
big_model = GPT2LMHeadModel.from_pretrained("gpt2-large")    # verifies the drafts

# Inference mode (disables dropout). The trailing ';' keeps Jupyter from
# printing the full module repr that .eval() returns.
small_model.eval()
big_model.eval();

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/666 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.25G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1280)
    (wpe): Embedding(1024, 1280)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-35): 36 x GPT2Block(
        (ln_1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=3840, nx=1280)
          (c_proj): Conv1D(nf=1280, nx=1280)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=5120, nx=1280)
          (c_proj): Conv1D(nf=1280, nx=5120)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1280, out_features=50257, bias=False)
)

In [2]:
# Move both models to the GPU if one is available, otherwise keep them on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
small_model.to(device)
big_model.to(device);  # ';' suppresses the module repr Jupyter would otherwise print

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1280)
    (wpe): Embedding(1024, 1280)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-35): 36 x GPT2Block(
        (ln_1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=3840, nx=1280)
          (c_proj): Conv1D(nf=1280, nx=1280)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=5120, nx=1280)
          (c_proj): Conv1D(nf=1280, nx=5120)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1280, out_features=50257, bias=False)
)

In [10]:
# Tokenize the prompt into a (1, prompt_len) id tensor and place it beside the models
prompt = "Once upon a time"
encoded_prompt = tokenizer.encode(prompt, return_tensors="pt")
input_ids = encoded_prompt.to(device)

In [13]:
# Parameters for speculative decoding
max_generated_tokens = 50     # Maximum number of new tokens to generate
accept_top_k = 5              # Draft token is accepted if it is in the big model's top-k
draft_length = 5              # Number of draft tokens the small model proposes per round


def speculative_decode(small_model, big_model, prompt_ids,
                       max_new_tokens=50, top_k=5, n_draft=5):
    """Greedy speculative decoding with a lenient top-k acceptance rule.

    Each round:
      1. the small model greedily drafts ``n_draft`` tokens;
      2. the big model scores prompt + draft in a single forward pass;
      3. draft tokens are accepted left to right while each lies in the big
         model's top-``top_k`` at its position; the first miss ends the round;
      4. the big model's own greedy token for the next position is appended
         (the correction after a rejection, or a "bonus" token when every
         draft token was accepted), reusing the logits from step 2.
    Every round therefore adds between 1 and ``n_draft + 1`` tokens.

    The top-k test is a relaxation of exact speculative sampling, so the
    result is not guaranteed to equal plain greedy decoding with
    ``big_model``. Both models re-encode the full sequence on every call
    (no KV cache), which keeps the code simple but is slow.

    Args:
        small_model: draft causal LM; ``small_model(ids).logits`` has shape
            (batch, seq_len, vocab). Must share ``big_model``'s vocabulary.
        big_model: verifier causal LM with the same call convention.
        prompt_ids: (1, prompt_len) token-id tensor on the models' device,
            with prompt_len >= 1. Only batch size 1 is supported.
        max_new_tokens: number of tokens to append to the prompt.
        top_k: acceptance window on the big model's ranking.
        n_draft: number of draft tokens proposed per round.

    Returns:
        (1, prompt_len + max_new_tokens) tensor: the prompt followed by the
        generated tokens.

    Raises:
        ValueError: if ``prompt_ids`` is empty (draft token 0 is scored by the
            logits of the last prompt token, so at least one is required).
    """
    prompt_len = prompt_ids.shape[1]
    if prompt_len == 0:
        raise ValueError("prompt_ids must contain at least one token")
    target_len = prompt_len + max(max_new_tokens, 0)
    generated = prompt_ids.clone()  # never alias the caller's tensor

    with torch.no_grad():  # inference only
        while generated.shape[1] < target_len:
            prefix_len = generated.shape[1]

            # 1) Draft n_draft tokens greedily with the small model.
            draft_ids = generated
            for _ in range(n_draft):
                small_logits = small_model(draft_ids).logits[:, -1, :]
                small_next = small_logits.argmax(dim=-1, keepdim=True)
                draft_ids = torch.cat([draft_ids, small_next], dim=1)

            # 2) Score prompt + whole draft with ONE big-model pass.
            # Logits at position p predict token p + 1, so draft token i (which
            # sits at position prefix_len + i) is judged by position prefix_len + i - 1.
            big_logits = big_model(draft_ids).logits
            k = min(top_k, big_logits.shape[-1])

            # 3) Accept the longest draft prefix the big model agrees with.
            accepted = 0
            for i in range(n_draft):
                draft_token = draft_ids[0, prefix_len + i]
                top_ids = big_logits[0, prefix_len + i - 1].topk(k).indices
                if not (top_ids == draft_token).any():
                    break  # stop on first rejection
                accepted += 1

            # 4) Append the big model's greedy token at the first position not
            # covered by accepted drafts; its logits were computed in step 2.
            big_next = big_logits[:, prefix_len + accepted - 1, :].argmax(dim=-1, keepdim=True)
            generated = torch.cat([draft_ids[:, :prefix_len + accepted], big_next], dim=1)

    # A round can add up to n_draft + 1 tokens, so trim any overshoot.
    return generated[:, :target_len]


# Generate from the prompt encoded in the previous cell
output_ids = speculative_decode(
    small_model, big_model, input_ids,
    max_new_tokens=max_generated_tokens,
    top_k=accept_top_k,
    n_draft=draft_length,
)

# Decode and print the generated output
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))


Once upon a time, the world was a very different place.

The world of the past was a world of peace and prosperity.

The world of the future was a world of war and strife.

The world of the present was a world of


### Without Speculative Decoding (gpt2-xl only) - Built-in `generate`

In [8]:
import time
from transformers import AutoTokenizer, AutoModelForCausalLM

# Baseline: load the large target model (GPT2-XL) and decode with plain `generate`.
# The model is never moved to `device`, so this cell runs on the CPU.
tokenizer = AutoTokenizer.from_pretrained("gpt2-xl")
model = AutoModelForCausalLM.from_pretrained("gpt2-xl")

# Input text to be tokenized
input_text = "The history of artificial intelligence begins in the 1950s with the early ideas of machines mimicking human reasoning."

# Convert input text to token IDs and attention mask
inputs = tokenizer(input_text, return_tensors="pt")
input_ids = inputs.input_ids
attention_mask = inputs.attention_mask

# Set pad token to EOS (GPT-2 has no explicit pad token)
pad_token_id = tokenizer.eos_token_id

# Time only the generate() call. perf_counter is a monotonic, high-resolution
# clock intended for measuring intervals; time.time() follows the wall clock
# and can jump if the system time is adjusted.
# NOTE(review): a single cold CPU run with 5 new tokens is a noisy benchmark;
# treat the number as a rough comparison only.
start_time = time.perf_counter()

output = model.generate(
    input_ids,
    attention_mask=attention_mask,
    max_new_tokens=5,
    pad_token_id=pad_token_id,
)
end_time = time.perf_counter()

# Decode and print the output
print("Output without speculative decoding:\n")
print(tokenizer.decode(output[0], skip_special_tokens=True))
print(f"\nTime taken: {end_time - start_time:.4f} seconds")


Output without speculative decoding:

The history of artificial intelligence begins in the 1950s with the early ideas of machines mimicking human reasoning. In the 1960s,

Time taken: 162.0421 seconds


### With Speculative Decoding (distilgpt2 assistant drafts, gpt2-xl verifies) - Built-in `generate`


In [9]:
import time
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the tokenizer, a small assistant (draft) model and the large target model.
# distilgpt2 uses the GPT-2 vocabulary, so it can share gpt2-xl's tokenizer
# (the default assisted-generation path expects a shared tokenizer).
tokenizer = AutoTokenizer.from_pretrained("gpt2-xl")
small_model = AutoModelForCausalLM.from_pretrained("distilgpt2")  # Faster, smaller assistant
big_model = AutoModelForCausalLM.from_pretrained("gpt2-xl")       # Main accurate model

# Input text to tokenize
input_text = "The history of artificial intelligence begins in the 1950s with the early ideas of machines mimicking human reasoning."

# Convert input text to token IDs and attention mask
inputs = tokenizer(input_text, return_tensors="pt")
input_ids = inputs.input_ids
attention_mask = inputs.attention_mask
pad_token_id = tokenizer.eos_token_id

# Time only the generate() call with a monotonic, high-resolution clock.
# NOTE(review): both models stay on the CPU here, this is a single cold run,
# and if the previous cell ran, another gpt2-xl copy is still held in `model`,
# so memory pressure may skew the comparison.
start_time = time.perf_counter()

# Passing assistant_model switches generate() to assisted (speculative) decoding.
output = big_model.generate(
    input_ids,
    attention_mask=attention_mask,
    max_new_tokens=5,
    pad_token_id=pad_token_id,
    assistant_model=small_model,
)

end_time = time.perf_counter()

# Decode and print the output
print("Output with speculative decoding:\n")
print(tokenizer.decode(output[0], skip_special_tokens=True))
print(f"\nTime taken: {end_time - start_time:.4f} seconds")


Output with speculative decoding:

The history of artificial intelligence begins in the 1950s with the early ideas of machines mimicking human reasoning. In the 1960s,

Time taken: 99.5746 seconds
