In [1]:
!uv add transformers accelerate

[2mResolved [1m75 packages[0m [2min 13ms[0m[0m
[2mAudited [1m55 packages[0m [2min 19ms[0m[0m


In [1]:
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

  from .autonotebook import tqdm as notebook_tqdm


In [30]:
draft_model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-360M-Instruct", device_map="auto")
draft_tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-360M-Instruct")
prompts = [
    "The 50 states of the USA in alphabetical order are: ", 
    "The countries of South America in alphabetical order are: ",
]

In [31]:
import torch
import platform
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-1.7B-Instruct", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-1.7B-Instruct")
model.eval()
num_draft_tokens = 8
greedy_gen = GenerationConfig(num_beams=1, do_sample=False, max_new_tokens=8)
if torch.cuda.is_available():
  device_type = "cuda"
elif platform.system() == "Darwin" and getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
  device_type = "mps"
else:
  device_type = "cpu"

Some parameters are on the meta device because they were offloaded to the disk.


In [32]:
import random
prompt = random.sample(prompts, 1)[0]
inputs = tokenizer.encode(prompt, return_tensors="pt").to(device_type)
draft_inputs = draft_tokenizer.encode(prompt, return_tensors="pt").to(device_type)
assert torch.equal(inputs, draft_inputs)

In [33]:
output_tokens = draft_model.generate(inputs=inputs, generation_config=GenerationConfig(num_beams=1, do_sample=False, max_new_tokens=256))[0]
draft_tokenizer.decode(output_tokens)

'The 50 states of the USA in alphabetical order are: \nAlabama, Alaska, Arizona, Arkansas, California, Colorado, Connecticut, Delaware, Florida, Georgia, Hawaii, Idaho, Illinois, Indiana, Iowa, Kansas, Kentucky, Louisiana, Maine, Maryland, Massachusetts, Michigan, Minnesota, Mississippi, Missouri, Montana, Nebraska, Nevada, New Hampshire, New Jersey, New Mexico, Ohio, Oklahoma, Oregon, Pennsylvania, Rhode Island, South Carolina, Tennessee, Texas, Utah, Vermont, Virginia, Washington, West Virginia, Wisconsin, and Wyoming.<|im_end|>'

In [34]:
output_tokens = model.generate(inputs=inputs, generation_config=GenerationConfig(num_beams=1, do_sample=False, max_new_tokens=256))[0]
tokenizer.decode(output_tokens)

'The 50 states of the USA in alphabetical order are: \nAlabama, Alaska, Arizona, Arkansas, California, Colorado, Connecticut, Delaware, Florida, Georgia, Hawaii, Idaho, Illinois, Indiana, Iowa, Kansas, Kentucky, Louisiana, Maine, Maryland, Massachusetts, Michigan, Minnesota, Mississippi, Missouri, Montana, Nebraska, Nevada, New Hampshire, New Jersey, New Mexico, New York, North Carolina, North Dakota, Ohio, Oklahoma, Oregon, Pennsylvania, Rhode Island, South Carolina, South Dakota, Tennessee, Texas, Utah, Vermont, Virginia, Washington, West Virginia, Wisconsin, Wyoming.<|im_end|>'

In [17]:
%%timeit
output_tokens = model.generate(inputs=inputs, generation_config=GenerationConfig(num_beams=1, do_sample=False, max_new_tokens=256))[0]

21.9 s ± 730 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [22]:
%%timeit
draft_output_tokens = draft_model.generate(inputs=inputs, generation_config=GenerationConfig(num_beams=1, do_sample=False, max_new_tokens=256))[0]


2.47 s ± 104 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [35]:
def speculative_decoding(prompt, max_new_tokens) -> torch.LongTensor:
    stop_token = tokenizer.eos_token_id
    do_stop = False
    inputs = tokenizer.encode(prompt, return_tensors="pt").to(device_type)
    original_prompt_len = inputs.shape[1]
    while not do_stop:
        prompt_len = inputs.shape[1]
        # print(f"Original prompt length is {prompt_len}")
        draft_tokens = draft_model.generate(inputs=inputs, generation_config=greedy_gen)[:, prompt_len: ]
        validation_inputs = torch.cat([inputs, draft_tokens], dim=1)
        with torch.no_grad():
            logits = model(validation_inputs).logits[:, prompt_len-1:]
            probs = torch.nn.functional.softmax(logits, dim=-1)
            model_predicted_tokens = torch.argmax(probs, dim=-1)
            
            # Get probs of last non-pad token only
        # draft_token_probs = torch.gather(
        #     probs[:, :-1],
        #     dim=-1,
        #     index=draft_tokens.view(draft_tokens.shape[1], -1),
        # )
        if not torch.all(model_predicted_tokens[:, :-1] == draft_tokens):
            #print(f"Generated draft tokens {draft_tokenizer.batch_decode(draft_tokens)}")
            mismatch = torch.argwhere(model_predicted_tokens[:, :-1] != draft_tokens)[0][1]
            mismatched_token = model_predicted_tokens[0, mismatch]
            matched_draft_tokens = draft_tokens[:, :mismatch]
            #print(f"Draft model predicted: {(draft_tokens[0, mismatch] ,tokenizer.decode([draft_tokens[0, mismatch]], skip_special_tokens=False))}")
            #print(f"Model predicted : {(mismatched_token, tokenizer.decode([mismatched_token], skip_special_tokens=False))}")
            inputs =  torch.cat([inputs, matched_draft_tokens, mismatched_token.unsqueeze(0).unsqueeze(0)], dim=1)
        else:
            inputs =  torch.cat([inputs, draft_tokens, model_predicted_tokens[0, -1].unsqueeze(0).unsqueeze(0)], dim=1)

        if tokenizer.eos_token_id in inputs or (max_new_tokens - (inputs.shape[1] - original_prompt_len)) < 0:
            do_stop=True
        #print(f"Next inputs are : {tokenizer.batch_decode(inputs)}")
    final_answer = inputs[0, : min(torch.argwhere(inputs[0]==tokenizer.eos_token_id)+1, max_new_tokens)]
    return final_answer

In [36]:
speculative_decoding_output_tokens = speculative_decoding(prompt, 256)

In [37]:
assert torch.equal(output_tokens, speculative_decoding_output_tokens)