In this example, we use vicuna-7b as the target model and vicuna-68m as the draft model.

In [None]:
# Load models
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face checkpoint ids: the large target model and the small draft model.
target_checkpoint = "lmsys/vicuna-7b-v1.3"
draft_checkpoint = "double7/vicuna-68m"

# Use the first CUDA device when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Only one tokenizer is loaded, from the target checkpoint.
tokenizer = AutoTokenizer.from_pretrained(target_checkpoint)

# Place both models on `device` and put them in eval mode.
target_model = AutoModelForCausalLM.from_pretrained(target_checkpoint, device_map=device)
target_model = target_model.eval()
draft_model = AutoModelForCausalLM.from_pretrained(draft_checkpoint, device_map=device)
draft_model = draft_model.eval()

# Tokenize the prompt into PyTorch tensors and move them to the same device as the models.
prompt = "Long long ago"
inputs = tokenizer(prompt, return_tensors="pt")
inputs = inputs.to(device)

It is recommended to set generation parameters in `model.generation_config` instead of passing them directly into the function `beam_search_by_SD`.

In [None]:
# Generation settings shared by the cells that follow.
max_new_tokens = 16   # new-token budget for the target model
beam_size = 10        # beams kept (and sequences returned) by the target model
draft_beam_size = 10  # beams kept (and sequences returned) by the draft model
gamma = 2             # used as the draft model's max_new_tokens;
                      # presumably the number of tokens drafted per round - confirm

# Store the settings on each model's generation_config so that the later
# generate() calls pick them up without extra arguments.
target_model.generation_config.update(
    max_new_tokens=max_new_tokens,
    num_beams=beam_size,
    num_return_sequences=beam_size,
    return_dict_in_generate=True,
)
draft_model.generation_config.update(
    max_new_tokens=gamma,
    num_beams=draft_beam_size,
    num_return_sequences=draft_beam_size,
    return_dict_in_generate=True,
)

In [3]:
from atspeed.beamsd4timing import Timer
# First call to generate(): timed separately because, per the message printed
# below, it also includes resource preparation.
with Timer() as timer_first:
    outputs = target_model.generate(**inputs)

print(f"First generation with resource preparation: {timer_first.time_cost:.2f} s")

# Decode every returned beam to text, dropping special tokens.
sequences = outputs["sequences"]
texts = tokenizer.batch_decode(sequences, skip_special_tokens=True)
print(texts)

First generation with resource preparation: 1.71 s
['Long long ago, in a land far, far away, there was a beautiful princess named', 'Long long ago, in a land far, far away, there was a small village nestled', 'Long long ago, in a land far, far away, there was a kingdom that was ruled', 'Long long ago, in a galaxy far, far away, there was a time when the', 'Long long ago, in a galaxy far, far away, there was a small planet called', 'Long long ago, in a land far, far away, there was a beautiful princess who', 'Long long ago, in a galaxy far, far away, there was a time when Star', 'Long long ago, in a galaxy far, far away, there was a young man named', 'Long long ago, in a galaxy far far away, there was a planet called Earth.', 'Long long ago, in a galaxy far, far away, there was a planet called Earth']


In [4]:
# Second call to the same generate(): this timing is the one reported as the
# plain transformers beam-search baseline.
with Timer() as timer_TF:
    outputs = target_model.generate(**inputs)

print(f"transformers (beam search by batch): {timer_TF.time_cost:.2f} s")

# Decode every returned beam to text, dropping special tokens.
sequences = outputs["sequences"]
texts = tokenizer.batch_decode(sequences, skip_special_tokens=True)
print(texts)

transformers (beam search by batch): 1.15 s
['Long long ago, in a land far, far away, there was a beautiful princess named', 'Long long ago, in a land far, far away, there was a small village nestled', 'Long long ago, in a land far, far away, there was a kingdom that was ruled', 'Long long ago, in a galaxy far, far away, there was a time when the', 'Long long ago, in a galaxy far, far away, there was a small planet called', 'Long long ago, in a land far, far away, there was a beautiful princess who', 'Long long ago, in a galaxy far, far away, there was a time when Star', 'Long long ago, in a galaxy far, far away, there was a young man named', 'Long long ago, in a galaxy far far away, there was a planet called Earth.', 'Long long ago, in a galaxy far, far away, there was a planet called Earth']


In [5]:
from atspeed.beamsd_replace import replace_beam_search_with_TreeAttn

# `model` is a second name for the same object as `target_model`, not a copy.
# NOTE(review): replace_beam_search_with_TreeAttn presumably patches the model
# in place, so later uses of `target_model` would also run the replaced beam
# search - confirm against the atspeed implementation.
model = target_model
replace_beam_search_with_TreeAttn(model)

# Time generation after the replacement has been applied.
with Timer() as timer_TreeAttn:
    outputs = model.generate(**inputs)

print(f"atspeed (beam search by tree attention): {timer_TreeAttn.time_cost:.2f} s")

# Decode every returned beam to text, dropping special tokens.
sequences = outputs["sequences"]
texts = tokenizer.batch_decode(sequences, skip_special_tokens=True)
print(texts)

atspeed (beam search by tree attention): 1.06 s
['Long long ago, in a land far, far away, there was a beautiful princess named', 'Long long ago, in a land far, far away, there was a small village nestled', 'Long long ago, in a land far, far away, there was a kingdom that was ruled', 'Long long ago, in a galaxy far, far away, there was a time when the', 'Long long ago, in a galaxy far, far away, there was a small planet called', 'Long long ago, in a land far, far away, there was a beautiful princess who', 'Long long ago, in a galaxy far, far away, there was a time when Star', 'Long long ago, in a galaxy far, far away, there was a young man named', 'Long long ago, in a galaxy far far away, there was a planet called Earth.', 'Long long ago, in a galaxy far, far away, there was a planet called Earth']


In [6]:
from atspeed.beamsd4timing import beam_search_by_SD_4timing

# Run beam search by speculative decoding with the target/draft pair. The
# result is indexed like a mapping and carries its own timing fields, so no
# Timer is used here.
outputs = beam_search_by_SD_4timing(target_model, draft_model, inputs)

# Overall time and accepted steps, then the per-stage time breakdown.
print(f"atspeed (beam search by Speculative Decoding): {outputs['time_cost']:.2f} s, accepted_steps: {outputs['total_accept_steps']}")
print(f"target_time_cost: {outputs['target_time_cost']:.2f} s, draft_time_cost: {outputs['draft_time_cost']:.2f} s, verify_time_cost: {outputs['verify_time_cost']:.2f} s")

# Decode the returned beams to text, dropping special tokens.
beam_sequences = outputs["beam_sequences"]
texts = tokenizer.batch_decode(beam_sequences, skip_special_tokens=True)
print(texts)

atspeed (beam search by Speculative Decoding): 1.27 s, accepted_steps: 0
target_time_cost: 1.06 s, draft_time_cost: 0.10 s, verify_time_cost: 0.03 s
['Long long ago, in a land far, far away, there was a beautiful princess named', 'Long long ago, in a land far, far away, there was a small village nestled', 'Long long ago, in a land far, far away, there was a kingdom that was ruled', 'Long long ago, in a galaxy far, far away, there was a time when the', 'Long long ago, in a galaxy far, far away, there was a small planet called', 'Long long ago, in a land far, far away, there was a beautiful princess who', 'Long long ago, in a galaxy far, far away, there was a time when Star', 'Long long ago, in a galaxy far, far away, there was a young man named', 'Long long ago, in a galaxy far far away, there was a planet called Earth.', 'Long long ago, in a galaxy far, far away, there was a planet called Earth']


Note that when `gamma` and `max_new_tokens` are small, the time spent on the target model's forward passes in _atspeed (beam search by Speculative Decoding)_ is almost the same as in _atspeed (beam search by tree attention)_ (1.06 s in both runs above). This means that even when the number of accepted steps is 0, there is still potential for acceleration.