In [5]:
!uv add transformers

[2mResolved [1m74 packages[0m [2min 4ms[0m[0m
[2mAudited [1m54 packages[0m [2min 5ms[0m[0m


In [11]:
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

In [9]:
draft_model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-360M-Instruct")
draft_tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-360M-Instruct")

In [30]:
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-1.7B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-1.7B-Instruct")
greedy_gen = GenerationConfig(num_beams=1, do_sample=False, max_new_tokens=8)

In [31]:

prompt = "The 50 states of the USA in alphabetical order are: "
inputs = tokenizer.encode(prompt, return_tensors="pt")

In [32]:
output_tokens = draft_model.generate(inputs=inputs, generation_config=GenerationConfig(num_beams=1, do_sample=False, max_new_tokens=256))[0]
tokenizer.decode(output_tokens)

'The 50 states of the USA in alphabetical order are: \nAlabama, Alaska, Arizona, Arkansas, California, Colorado, Connecticut, Delaware, Florida, Georgia, Hawaii, Idaho, Illinois, Indiana, Iowa, Kansas, Kentucky, Louisiana, Maine, Maryland, Massachusetts, Michigan, Minnesota, Mississippi, Missouri, Montana, Nebraska, Nevada, New Hampshire, New Jersey, New Mexico, Ohio, Oklahoma, Oregon, Pennsylvania, Rhode Island, South Carolina, Tennessee, Texas, Utah, Vermont, Virginia, Washington, West Virginia, Wisconsin, and Wyoming.<|im_end|>'

In [20]:
%%timeit
output_tokens = model.generate(inputs=inputs, generation_config=GenerationConfig(num_beams=1, do_sample=False, max_new_tokens=256))[0]


12.3 s ± 264 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [22]:
%%timeit
draft_output_tokens = draft_model.generate(inputs=inputs, generation_config=GenerationConfig(num_beams=1, do_sample=False, max_new_tokens=256))[0]


2.78 s ± 80.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [246]:
import torch
num_draft_tokens = 8
threshold_prob = 0.7
stop_token = tokenizer.eos_token_id
do_stop = False
prompt = "The 50 states of the USA in alphabetical order are: "
inputs = tokenizer.encode(prompt, return_tensors="pt")
while not do_stop:
    original_prompt_len = inputs.shape[1]
    #print(f"Original prompt length is {original_prompt_len}")
    draft_tokens = draft_model.generate(inputs=inputs, generation_config=greedy_gen)[:, original_prompt_len: ]
    #print(f"Generated draft tokens {draft_tokenizer.batch_decode(draft_tokens)}")
    rep_draft_tokens = draft_tokens.repeat_interleave(draft_tokens.shape[1], 0).tril()
    validation_inputs = torch.cat([inputs.repeat_interleave(draft_tokens.shape[1], 0), rep_draft_tokens], dim=1)
    validation_inputs[validation_inputs==0] = tokenizer.pad_token_id
    with torch.no_grad():
        logits = model(validation_inputs).logits[:, original_prompt_len-1:-1]
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # Get probs of last non-pad token only
        probs = probs[torch.arange(0, draft_tokens.shape[1]), torch.arange(0, draft_tokens.shape[1])]
    draft_token_probs = torch.gather(
        probs,
        dim=-1,
        index=draft_tokens.view(draft_tokens.shape[1], -1),
    )
    if not torch.all(draft_token_probs >= threshold_prob):
        potential_mismatches = torch.argwhere((draft_token_probs<threshold_prob).int())
        # print(f"Potential mismatches : {potential_mismatches}")
        is_mismatched_token = False
        for mismatch in potential_mismatches:
            model_predicted_token = torch.argmax(probs[mismatch[0]])
            # print(f"Model predicted token {model_predicted_token}")
            if model_predicted_token.item() != draft_tokens[0, mismatch[0]].item():
                is_mismatched_token = True
                print(f"Generated draft tokens {draft_tokenizer.batch_decode(draft_tokens)}")
                print(f"Correctly predicted draft tokens: {tokenizer.batch_decode(draft_tokens[:, original_prompt_len: original_prompt_len+mismatch[1]])}")
                accepted_draft_tokens = draft_tokens[:,  :mismatch[0]]
                print(f"Incorrect draft token: {tokenizer.convert_ids_to_tokens(draft_tokens[0, mismatch[0]].item())}")
                print(f"Correct token by parent is: {tokenizer.convert_ids_to_tokens(model_predicted_token.item())}")
                inputs = torch.cat([inputs, accepted_draft_tokens, model_predicted_token.unsqueeze(0).unsqueeze(0)], dim=1)
                print(f"Next inputs are : {tokenizer.batch_decode(inputs)}")
        if not is_mismatched_token:
            inputs =  torch.cat([inputs, draft_tokens], dim=1)
    else:
        inputs = torch.cat([inputs, draft_tokens], dim=1)
    
    if tokenizer.eos_token_id in inputs:
        do_stop=True
    
    #print(f"Next inputs are : {tokenizer.batch_decode(inputs)}")
tokenizer.batch_decode(inputs)

Generated draft tokens [', New Mexico, Ohio, Oklahoma,']
Correctly predicted draft tokens: ['']
Incorrect draft token: ĠOhio
Correct token by parent is: ĠNew
Next inputs are : ['The 50 states of the USA in alphabetical order are: \nAlabama, Alaska, Arizona, Arkansas, California, Colorado, Connecticut, Delaware, Florida, Georgia, Hawaii, Idaho, Illinois, Indiana, Iowa, Kansas, Kentucky, Louisiana, Maine, Maryland, Massachusetts, Michigan, Minnesota, Mississippi, Missouri, Montana, Nebraska, Nevada, New Hampshire, New Jersey, New Mexico, New']
Generated draft tokens [' York, North Carolina, Ohio, Oklahoma']
Correctly predicted draft tokens: ['']
Incorrect draft token: ĠOhio
Correct token by parent is: ĠNorth
Next inputs are : ['The 50 states of the USA in alphabetical order are: \nAlabama, Alaska, Arizona, Arkansas, California, Colorado, Connecticut, Delaware, Florida, Georgia, Hawaii, Idaho, Illinois, Indiana, Iowa, Kansas, Kentucky, Louisiana, Maine, Maryland, Massachusetts, Michigan, 

KeyboardInterrupt: 

In [243]:
accepted_draft_tokens

tensor([[  28, 1315, 5647,   28]])

In [192]:
probs[torch.arange(0, num_draft_tokens), torch.arange(0, num_draft_tokens)].shape

torch.Size([8, 49152])

In [209]:
draft_tokens

tensor([[  198,    49,  1901,  4232,    28, 11431,    28, 11947]])

In [197]:
probs = torch.tensor([
    [[0.3, 0.7], [0.5, 0.5]],
    [[0.4, 0.6], [0.9, 0.1]]
])

In [198]:
probs.shape

torch.Size([2, 2, 2])

In [199]:
probs[torch.LongTensor([0, 1]), torch.LongTensor([0, 1])]

tensor([[0.3000, 0.7000],
        [0.9000, 0.1000]])