In [1]:
!pip install accelerate

[0m

In [2]:
import gc
import numpy as np
import time
from typing import Union

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast

In [3]:
def generate_candidate_tokens(
    input_ids: torch.Tensor, n_grams: torch.Tensor, ngrams_size: int, K: int
):
    # unfold the tensor into windows of `pattern_len + following_elements_count`
    window = input_ids.unfold(dimension=1, size=ngrams_size, step=1)
    # compare each window with the pattern (only the parts corresponding to the pattern)
    matching_window_indices = (window == n_grams).all(dim=2)
    # extract the indices where there are matches
    matching_indices = matching_window_indices.nonzero(as_tuple=True)[1]

    # find candidates with the longest length
    # based on: https://arxiv.org/pdf/2304.04487
    # we choose the candidate with the longest length at random if there are multiple candidates
    candidates = []
    max_length = K
    for idx in matching_indices:
        start_idx = idx + ngrams_size
        end_idx = start_idx + K
        candidate = input_ids[0, start_idx : min(end_idx, input_ids.size(1))]
        length = len(candidate)

        if length == max_length:
            candidates.append(candidate)
        else:
            # we do not consider prefix with no candidates
            if length > max_length:
                max_length = length
                candidates = [candidate]

    if candidates:
        chosen_candidate = candidates[np.random.randint(len(candidates))]
    else:
        chosen_candidate = torch.tensor([], dtype=torch.long, device=input_ids.device)

    return chosen_candidate.unsqueeze(dim=0)

In [4]:
@torch.no_grad()
def greedy_decoding(
    input_ids: torch.Tensor,
    model: torch.nn.Module,
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    n: int = 400,
):
    eos_token_id = tokenizer.eos_token_id

    seq_len = input_ids.shape[1]
    T = seq_len + n

    while input_ids.shape[1] < T:
        logits = model(input_ids).logits
        next_token_id = torch.argmax(logits[:, -1, :], dim=-1)
        input_ids = torch.cat([input_ids, next_token_id.unsqueeze(dim=1)], dim=1)
        yield next_token_id.item()
        if next_token_id == eos_token_id:
            break

    return input_ids

In [5]:
@torch.no_grad()
def ngram_decoding(
    input_ids: torch.Tensor,
    model: torch.nn.Module,
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    ngrams_size: int,
    K: int,
    n: int,
):
    eos_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0
    eos_token_id_tensor = torch.tensor(
        [eos_token_id], dtype=torch.long, device=input_ids.device
    )
    seq_len = input_ids.shape[1]
    T = seq_len + n

    while input_ids.shape[1] < T:
        prefix = input_ids
        cur_len = input_ids.shape[1]

        # -----------------------------------------
        # Step 1: Generate N-grams
        # -----------------------------------------

        n_grams = input_ids[0, -ngrams_size:]

        # -----------------------------------------
        # Step 2: Generate K candidates tokens using the N-grams
        # -----------------------------------------

        candidate_tokens = generate_candidate_tokens(input_ids, n_grams, ngrams_size, K)

        # -----------------------------------------
        # Step 3: Validate the candidates using the LLM
        # -----------------------------------------

        # based on: https://arxiv.org/pdf/2304.04487
        # if we did not find any candidates tokens, we default to single-step decoding
        if candidate_tokens.shape[1] == 0:
            logits = model(input_ids).logits[:, -1, :]
            next_token = logits.argmax(dim=-1)
            input_ids = torch.cat([input_ids, next_token.unsqueeze(dim=0)], dim=1)
            yield (next_token.item(), False)
            if next_token.item() == eos_token_id:
                break
            continue

        prefix = torch.cat([input_ids, candidate_tokens], dim=1)
        # include the ngram_size + K + 1 in the logits
        logits = model(prefix).logits[:, cur_len - 1 : cur_len + ngrams_size + K, :]

        assert (
            logits.shape[1] == candidate_tokens.shape[1] + 1
        ), f"Expected logits shape: {ngrams_size + K + 1}, got: {logits.shape[1]}"

        selected_tokens = logits.argmax(dim=-1)
        # calculate the number of consecutive matching tokens between candidate_tokens and selected_tokens:
        # 1. Compare candidate_tokens with selected_tokens
        # 2. Invert the comparison result
        # 3. Calculate cumulative sum of mismatches
        # 4. Create a mask for positions before the first mismatch
        # 5. Sum up the mask to get the count of consecutive matches
        n_matches = (
            (~(candidate_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1
        ).sum()
        n_matches = min(n_matches, T - cur_len - 1)

        valid_tokens = selected_tokens[:, : n_matches + 1]
        # print("selected from prompt: ", tokenizer.decode(valid_tokens[0]))
        for token_id in valid_tokens[0]:
            yield (token_id.item(), True)
        input_ids = torch.cat([input_ids, valid_tokens], dim=1)

        if input_ids.shape[1] >= T:  # Check if we've reached the desired length
            break
        # we fulfill the condition of ngrams_size + K
        elif n_matches == ngrams_size + K:
            # we can take the last token from the logits and append it to the input_ids
            # we generated K+1 from the previous forward pass
            next_token = selected_tokens[-1]
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
            yield (next_token.item(), True)
            if next_token == eos_token_id:
                break

        if (valid_tokens == eos_token_id_tensor.item()).any():
            break

    return input_ids


In [6]:
DEVICE = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

print(DEVICE)

cuda


In [7]:
model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        torch_dtype=torch.float16,
        device_map=DEVICE,
        use_cache=False,
    ).eval()

tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    torch_dtype=torch.float16,
    device_map=DEVICE,
)

tokenizer.eos_token_id = 128009

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [8]:
input_str = """
def generate_candidate_tokens(
    input_ids: torch.Tensor, n_grams: torch.Tensor, ngrams_size: int, K: int
):
    # unfold the tensor into windows of `pattern_len + following_elements_count`
    window = input_ids.unfold(dimension=1, size=ngrams_size, step=1)
    # compare each window with the pattern (only the parts corresponding to the pattern)
    matching_window_indices = (window == n_grams).all(dim=2)
    # extract the indices where there are matches
    matching_indices = matching_window_indices.nonzero(as_tuple=True)[1]

    # find candidates with the longest length
    # based on: https://arxiv.org/pdf/2304.04487
    # we choose the candidate with the longest length at random if there are multiple candidates
    candidates = []
    max_length = K
    for idx in matching_indices:
        start_idx = idx + ngrams_size
        end_idx = start_idx + K
        candidate = input_ids[0, start_idx : min(end_idx, input_ids.size(1))]
        length = len(candidate)

        if length == max_length:
            candidates.append(candidate)
        else:
            # we do not consider prefix with no candidates
            if length > max_length:
                max_length = length
                candidates = [candidate]

    if candidates:
        chosen_candidate = candidates[np.random.randint(len(candidates))]
    else:
        chosen_candidate = torch.tensor([], dtype=torch.long, device=input_ids.device)

    return chosen_candidate.unsqueeze(dim=0)
"""
question = "Can you the variable name 'candidates' to 'candidates_tokens'?"
prompt = "<|start_header_id|>user<|end_header_id|>\nCode:```python\n{code_text}``` \n\n Question: {question} \n\n Modified code:\n<|start_header_id|>assistant<|end_header_id|>".format(
    code_text=input_str, question=question
)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)

In [9]:
# warm-up run
print("Starting warm-up run")
ngram_decoding(input_ids, model, tokenizer, ngrams_size=3, K=10, n=50)
print("Warm-up complete.")

# actual run
print("\nNgram Decoding:")
torch.cuda.synchronize()
nd_start = time.perf_counter()
nd_output_ids = []
for token_id, speculated in ngram_decoding(
    input_ids, model, tokenizer, ngrams_size=3, K=10, n=400
):
    nd_output_ids.append(token_id)
    if speculated:
        print(
            f"\033[92m{tokenizer.decode(token_id)}\033[0m", end="", flush=True
        )
    else:
        print(
            tokenizer.decode(token_id, skip_special_tokens=True),
            end="",
            flush=True,
        )
torch.cuda.synchronize()
nd_end = time.perf_counter()
nd_time = nd_end - nd_start
print(
    f"\nTime taken: {nd_end - nd_start} seconds, {len(nd_output_ids) / nd_time} tokens/s"
)

Starting warm-up run
Warm-up complete.

Ngram Decoding:


Here is the modified code with the variable name[92m '[0m[92mcandidates[0m[92m'[0m[92m changed[0m to 'candidates[92m_tokens[0m[92m':

[0m```python
def generate_candidate[92m_tokens[0m[92m(
[0m[92m   [0m[92m input[0m[92m_ids[0m[92m:[0m[92m torch[0m[92m.Tensor[0m[92m,[0m[92m n[0m[92m_[0m[92mgrams[0m[92m:[0m[92m torch[0m[92m.Tensor[0m[92m,[0m[92m n[0m[92mgrams[0m[92m_size[0m[92m:[0m[92m int[0m[92m,[0m[92m K[0m[92m:[0m[92m int[0m[92m
[0m[92m):
[0m[92m   [0m[92m #[0m[92m unfold[0m[92m the[0m[92m tensor[0m[92m into[0m[92m windows[0m[92m of[0m[92m `[0m[92mpattern[0m[92m_len[0m[92m +[0m[92m following[0m[92m_elements[0m[92m_count[0m[92m`
[0m[92m   [0m[92m window[0m[92m =[0m[92m input[0m[92m_ids[0m[92m.un[0m[92mfold[0m[92m(d[0m[92mimension[0m[92m=[0m[92m1[0m[92m,[0m[92m size[0m[92m=n[0m[92mgrams[0m[92m_size[0m

In [10]:
# warm-up run
print("Starting warm-up run")
greedy_decoding(input_ids, model, tokenizer, n=50)
print("Warm-up complete.")

print("\nGreedy Decoding:")
torch.cuda.synchronize()
gd_start = time.perf_counter()
gd_output_ids = []
for token_id in greedy_decoding(input_ids, model, tokenizer, n=400):
    gd_output_ids.append(token_id)
    print(
        tokenizer.decode(token_id, skip_special_tokens=True), end="", flush=True
    )
torch.cuda.synchronize()
gd_end = time.perf_counter()
gd_time = gd_end - gd_start
print(
    f"\nTime taken: {gd_end - gd_start} seconds, {len(gd_output_ids) / gd_time} tokens/s"
)

Starting warm-up run
Warm-up complete.

Greedy Decoding:


Here is the modified code with the variable name 'candidates' changed to 'candidates_tokens':

```python
def generate_candidate_tokens(
    input_ids: torch.Tensor, n_grams: torch.Tensor, ngrams_size: int, K: int
):
    # unfold the tensor into windows of `pattern_len + following_elements_count`
    window = input_ids.unfold(dimension=1, size=ngrams_size, step=1)
    # compare each window with the pattern (only the parts corresponding to the pattern)
    matching_window_indices = (window == n_grams).all(dim=2)
    # extract the indices where there are matches
    matching_indices = matching_window_indices.nonzero(as_tuple=True)[1]

    # find candidates with the longest length
    # based on: https://arxiv.org/pdf/2304.04487
    # we choose the candidate with the longest length at random if there are multiple candidates
    candidates_tokens = []
    max_length = K
    for idx in matching_indices:
        start_idx = idx + ngra

In [11]:
gc.collect()
torch.cuda.empty_cache()