In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, MistralForCausalLM
import time
import torch

device = "cuda"  # the device to load the model onto

# Single source of truth for the checkpoint, so the model and its tokenizer cannot drift apart.
MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"

# Load the weights in float16, then move the model to `device`.
model: MistralForCausalLM = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
model.to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:

# Multi-turn chat history; the final user turn is the one the generation cells below continue from.
messages = [
    dict(role="user", content="What is your favourite condiment?"),
    dict(
        role="assistant",
        content="Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!",
    ),
    dict(role="user", content="Do you have mayonnaise recipes?"),
]

# Render the conversation through the tokenizer's chat template into a PyTorch tensor of token ids
# (presumably shape (1, seq_len) for a single conversation — the decoders below require batch size 1).
encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")

In [25]:
@torch.no_grad()
def sample_autoreg(
    token_ids: torch.Tensor,
    model: torch.nn.Module,
    device="cuda",
    num_tokens=128,
):
    """Greedy autoregressive decoding: one forward pass per generated token.

    Args:
        token_ids: prompt token ids with shape (1, prefix_len); only batch size 1 is supported.
            The tensor may live on any device.
        model: callable whose output exposes ``.logits`` indexed as [batch, position, vocab].
        device: device the working buffer (and therefore every model input) is placed on.
        num_tokens: exact number of tokens to generate (there is no early stop on EOS).

    Returns:
        Tensor of shape (1, num_tokens) holding only the generated ids (prompt excluded),
        located on ``device``.

    Raises:
        AssertionError: if ``token_ids`` has a batch dimension other than 1.

    Note:
        No KV cache is used. Every step re-runs the model over the whole prefix, so the number of
        token positions processed grows quadratically with ``num_tokens``.
    """
    assert token_ids.shape[0] == 1, "only batch size 1 is supported"

    len_prefix = token_ids.shape[1]
    # Allocate the padding on token_ids' own device so torch.cat also works when the caller has
    # already moved the prompt to the GPU; the combined buffer is then moved to `device`.
    padding = torch.zeros((1, num_tokens), dtype=token_ids.dtype, device=token_ids.device)
    t = torch.cat((token_ids, padding), dim=-1).to(device)
    for i in range(num_tokens):
        # The logits at the last fed position score the next token; take the argmax (greedy).
        new_token_id = model(t[:, :len_prefix + i]).logits[0, -1, :].argmax()
        # .item() copies to the host each step (a device sync on CUDA), which keeps wall-clock
        # timing of this function meaningful.
        t[0, len_prefix + i] = new_token_id.item()

    return t[0, len_prefix:].reshape((1, -1))

# Time plain greedy decoding. sample_autoreg copies each new token to the host with .item(), so the
# wall-clock interval below covers (almost all of) the GPU work.
# NOTE(review): this run generates 16 tokens while the Jacobi run generates 128, so the two
# timings are not directly comparable — consider using the same num_tokens for both.
start = time.perf_counter()
generated_ids = sample_autoreg(
    encodeds,
    model,
    device,
    num_tokens=16,
)
print(f"{time.perf_counter() - start:.2f}s")

decoded = tokenizer.batch_decode(generated_ids)
generated_ids

0.57s


tensor([[ 5592, 28725,   315,   541,  5785,  1316,   368,   395,   264, 11495,
           993,  7136,   864, 13405, 28723,  4003]], device='cuda:0')

In [27]:
@torch.no_grad()
def sample_jacobi_decode(
    token_ids: torch.Tensor,
    model: torch.nn.Module,
    device="cuda",
    num_tokens=128,
    num_extra=3,
):
    """Greedy decoding that verifies guessed tokens in parallel (Jacobi-style).

    The working buffer holds the prompt, the tokens accepted so far, and guesses for the
    positions after them (initial guesses are token id 0). Each round:

    1. Runs one forward pass over the accepted tokens plus up to ``num_extra`` guesses.
    2. Takes the argmax at the last n+1 positions.
    3. Accepts the prediction for the first unverified position, plus one extra token for every
       leading guess that matched its prediction.
    4. Keeps the remaining predictions as the guesses for the next round.

    Assuming ``model`` is causal (position k only sees positions <= k), each accepted token is what
    greedy autoregressive decoding would produce. NOTE(review): in float16, logits from forward
    passes of different lengths can differ slightly, so outputs may occasionally diverge from
    ``sample_autoreg``.

    Args:
        token_ids: prompt token ids with shape (1, prefix_len); only batch size 1 is supported.
            The tensor may live on any device.
        model: callable whose output exposes ``.logits`` indexed as [batch, position, vocab].
        device: device the working buffer (and therefore every model input) is placed on.
        num_tokens: exact number of tokens to generate (there is no early stop on EOS).
        num_extra: maximum number of guessed tokens verified per forward pass; 0 degenerates to
            plain one-token-per-pass decoding.

    Returns:
        Tensor of shape (1, num_tokens) holding only the generated ids (prompt excluded),
        located on ``device``.

    Raises:
        AssertionError: if ``token_ids`` has a batch dimension other than 1.
        ValueError: if ``num_extra`` is negative. A negative value would otherwise write nothing
            while still advancing, silently returning zeros.
    """
    assert token_ids.shape[0] == 1
    if num_extra < 0:
        raise ValueError(f"num_extra must be >= 0, got {num_extra}")

    len_prefix = token_ids.shape[1]
    end = len_prefix + num_tokens
    # Allocate the padding on token_ids' own device so torch.cat also works when the caller has
    # already moved the prompt to the GPU; the combined buffer is then moved to `device`.
    padding = torch.zeros((1, num_tokens), dtype=token_ids.dtype, device=token_ids.device)
    t = torch.cat((token_ids, padding), dim=-1).to(device)

    i = len_prefix  # write index: t[0, :i] holds the prompt plus accepted tokens
    while i < end:
        # Guesses verified this round. The cap leaves room to write the n+1 predictions below
        # without running past the end of the buffer.
        n = min(num_extra, end - i - 1)

        # Forward over accepted tokens + n guesses. The last n+1 logits predict positions i..i+n.
        indices = model(t[:, :i + n]).logits[0, -(n + 1):, :].argmax(dim=-1)

        # indices[0] is the greedy token for position i. indices[k] is also trustworthy when the
        # guesses t[i..i+k-1] all matched their predictions. Count that leading run of matches
        # on-device (cumprod zeroes everything after the first mismatch), using a single host
        # sync instead of one per element.
        matches = (indices[:n] == t[0, i:i + n]).long()
        nhits = int(matches.cumprod(dim=0).sum())

        # Positions i..i+nhits become final; the rest of the predictions become the next guesses.
        t[0, i:i + n + 1] = indices

        i += nhits + 1

    return t[0, len_prefix:].reshape((1, -1))

# Time Jacobi decoding of 128 tokens, verifying at most one guessed token per forward pass.
start = time.perf_counter()
generated_ids = sample_jacobi_decode(
    encodeds,
    model,
    device,
    num_tokens=128,
    num_extra=1,
)
print(f"{time.perf_counter() - start:.2f}s")

decoded = tokenizer.batch_decode(generated_ids)
# Quick sanity check: show only the first 12 characters of the decoded continuation.
print(decoded[0][:12])

118.14s
Yes, I can c
