In [1]:
from traceback import print_exc
from inference import LLaMA
from tqdm import tqdm
import torch

from IPython import embed

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Build the LLaMA wrapper from the local "Llama-2-7b/" directory (weights and
# tokenizer model), with room for 1024 positions and a single batch row.
# NOTE(review): generate_with_beam() below feeds batch_size * beam_k rows to the
# model; max_batch_size=1 may be too small for that — confirm against the model.
model = LLaMA.build_model(
    device=device,
    max_batch_size=1,
    max_seq_len=1024,
    load_model=True,
    tokenizer_path="Llama-2-7b/tokenizer.model",
    checkpoints_dir="Llama-2-7b/",
)

Loading checkpoint from Llama-2-7b\consolidated.00.pth
Loaded checkpoint in 63.11s


  _C._set_default_tensor_type(t)


Loaded model in 210.47s


In [18]:
# For each row of the first layer's key cache, report whether every entry is
# still zero (all dimensions except the first are reduced).
torch.all(model.model.layers[0].attention.cache_k.flatten(start_dim=1) == 0, dim=1)

tensor([True])

In [19]:
# Number of tokens to generate after the prompt.
max_new_tokens = 1

# Batch of prompts passed to the generation calls below.
prompts = ["7 + 5 ="]

In [20]:
# Smoke-test the model's built-in top-k sampling on `prompts`.
strategy = "top_k"
try:
    print(f"{strategy} sampling")
    out_tokens, out_texts = model.generate(prompts, max_gen_len=max_new_tokens, strategy=strategy, k=50)
    # The message reports the length of the same sequence the condition checks
    # (it previously printed len(out_tokens) while testing len(out_texts)).
    assert len(out_texts) == len(prompts), f"Expected {len(prompts)} outputs, got {len(out_texts)}"
    print(f"\n{'-'*50}\n".join(out_texts))
    print("="*100)
except Exception as e:
    # Deliberate best-effort: report the failure with its traceback instead of
    # aborting the notebook run.
    print(f"Error during {strategy} sampling: {e}")
    print_exc()

top_k sampling


Generating tokens using top_k with k=50:   0%|          | 0/7 [00:00<?, ?it/s]

Generating tokens using top_k with k=50: 100%|██████████| 7/7 [02:03<00:00, 17.64s/it]

7 + 5 =?





In [None]:
def _beam(model, tokens: torch.Tensor, prompt_tokens_mask: torch.Tensor, temperature: float, k: int) -> tuple[list[list[int]], list[str]]:
    assert k > 1, "Beam size must be greater than 1 for beam search."
    batch_size, total_len = tokens.shape

    embed()
    return

    # Expand the tokens and masks to k beams per batch element
    expanded_tokens = tokens.repeat_interleave(k, dim=0)  # (batch_size * k, total_len)
    expanded_prompt_mask = prompt_tokens_mask.repeat_interleave(k, dim=0)  # (batch_size * k, total_len)

    # Initialize beam scores: (batch_size, k)
    beam_scores = torch.zeros((batch_size, k), device=model.model_args.device)
    beam_scores[:, 1:] = -float("inf")  # Only the first beam is active initially

    # Track whether each beam has reached EOS
    eos_reached = torch.zeros((batch_size, k), dtype=torch.bool, device=model.model_args.device)

    for cur_pos in tqdm(range(1, total_len), desc="Generating tokens (beam search)"):
        # Get the logits for the next token
        with torch.no_grad():
            # Input is the tokens up to cur_pos-1
            logits = model.model(expanded_tokens[:, cur_pos-1:cur_pos], cur_pos)

        # Apply temperature to logits
        logits = logits[:, -1] / temperature
        # use log probs instead of probs for numerical stability and to avoid underflow
        log_probs = torch.log_softmax(logits, dim=-1)  # (batch_size * k, vocab_size)

        # For each beam, select top k tokens and their log probs
        topk_log_probs, topk_indices = log_probs.topk(k, dim=-1) # (batch_size * k, k)

        # Reshape to (batch_size, k, k)
        topk_log_probs = topk_log_probs.view(batch_size, k, k)

        # Compute new scores: beam_scores (batch, k, 1) + topk_log_probs (batch, k, k) => (batch, k, k)
        new_scores = beam_scores.unsqueeze(-1) + topk_log_probs

        # Flatten to (batch_size, k * k)
        new_scores_flat = new_scores.view(batch_size, -1)

        # Select top k candidates for each batch element
        topk_new_scores, topk_indices_flat = new_scores_flat.topk(k, dim=-1)  # (batch_size, k)

        # Determine parent beams and token ranks
        parent_beams = topk_indices_flat // k  # (batch_size, k)
        token_ranks = topk_indices_flat % k    # (batch_size, k)

        # Update beam scores
        beam_scores = topk_new_scores

        # Gather the token indices for the selected candidates
        beam_indices = (torch.arange(batch_size, device=model.model_args.device).unsqueeze(1) * k + parent_beams).view(-1)
        token_ranks = token_ranks.view(-1)

        # Gather the token IDs from topk_indices
        gathered_token_ids = topk_indices[beam_indices, token_ranks]

        # Apply prompt mask: if current position is part of the prompt, use the existing token
        current_prompt_mask = expanded_prompt_mask[:, cur_pos]
        existing_tokens = expanded_tokens[:, cur_pos]
        gathered_token_ids = torch.where(current_prompt_mask, existing_tokens, gathered_token_ids)

        # Update the expanded_tokens by copying the parent sequences and adding the new token
        expanded_tokens[:, :cur_pos] = expanded_tokens[beam_indices, :cur_pos]
        expanded_tokens[:, cur_pos] = gathered_token_ids

        # Check for EOS tokens in non-prompt positions
        eos_reached_new = (~current_prompt_mask) & (gathered_token_ids == model.tokenizer.eos_id)
        eos_reached_new = eos_reached_new.view(batch_size, k)

        # Update eos_reached and mask scores for completed beams
        eos_reached |= eos_reached_new
        beam_scores[eos_reached] = -float("inf")

        # Early stopping if all beams reached EOS
        if eos_reached.all():
            break

    # Select the best beam for each batch element
    best_beam_indices = beam_scores.argmax(dim=-1)  # (batch_size)
    best_beam_indices_expanded = (torch.arange(batch_size, device=model.model_args.device) * k) + best_beam_indices
    best_tokens = expanded_tokens[best_beam_indices_expanded]

    # Process the final tokens to remove padding and cut at EOS
    out_tokens = []
    out_text = []
    for i in range(batch_size):
        tokens_list = best_tokens[i].tolist()
        if model.tokenizer.eos_id in tokens_list:
            eos_idx = tokens_list.index(model.tokenizer.eos_id)
            tokens_list = tokens_list[:eos_idx]
        out_tokens.append(tokens_list)
        out_text.append(model.tokenizer.decode(tokens_list))

    return out_tokens, out_text

def generate_with_beam(model, prompts: list[str], temperature: float = 1.0, max_gen_len = None, beam_k: int = 2, func: int = 1) -> tuple[list[list[int]], list[str]]:
    """Tokenize `prompts`, build the padded token batch and decode it with `_beam`.

    Args:
        model: Object exposing `model_args` (`max_seq_len`, `max_batch_size`,
            `device`) and a `tokenizer` with `encode(...)` and `pad_id()`.
        prompts: Non-empty list of prompt strings (at most `max_batch_size`).
        temperature: Passed through to `_beam`; must be greater than 0.
        max_gen_len: Maximum number of tokens to generate; defaults to
            `max_seq_len - 1` when None.
        beam_k: Beam width passed to `_beam`.
        func: Accepted for backward compatibility; not used by this function.

    Returns:
        Whatever `_beam` returns for the prepared batch.

    Raises:
        ValueError: If `prompts` is empty.
        AssertionError: If the temperature, batch size or prompt length is out of range.
    """
    assert temperature > 0, f"Temperature must be greater than 0, got {temperature}"
    if len(prompts) == 0:
        # Previously this surfaced as an opaque ValueError from max() further down.
        raise ValueError("prompts must contain at least one prompt")

    model_args = model.model_args
    if max_gen_len is None:
        max_gen_len = model_args.max_seq_len - 1

    # Convert each prompt into tokens
    prompt_tokens = [model.tokenizer.encode(prompt, out_type=int, add_bos=True, add_eos=False) for prompt in prompts]

    # Make sure the batch size is not too large
    # NOTE(review): _beam feeds batch_size * beam_k rows to the model; this check
    # only covers batch_size — confirm the model's cache can hold the beams.
    batch_size = len(prompt_tokens)
    assert batch_size <= model_args.max_batch_size, f"batch size must be less than or equal to {model_args.max_batch_size}"

    # Make sure the prompt length is not larger than the maximum sequence length
    max_prompt_len = max(len(prompt) for prompt in prompt_tokens)
    assert max_prompt_len <= model_args.max_seq_len, f"prompt length must be less than or equal to {model_args.max_seq_len}"
    total_len = min(model_args.max_seq_len, max_gen_len + max_prompt_len)

    # Start from an all-padding batch and copy each prompt into the front of its row
    pad_id = model.tokenizer.pad_id()
    tokens = torch.full((batch_size, total_len), pad_id, dtype=torch.long, device=model_args.device)
    for row, prompt_ids in enumerate(prompt_tokens):
        tokens[row, : len(prompt_ids)] = torch.tensor(prompt_ids, dtype=torch.long, device=model_args.device)

    # True if the token is a prompt token, False otherwise
    prompt_tokens_mask = tokens != pad_id

    return _beam(model, tokens, prompt_tokens_mask, temperature, beam_k)

In [92]:
# Beam-search decode of `prompts` with two beams, generating `max_new_tokens`
# tokens. `func` is accepted by generate_with_beam but not used by it.
generate_with_beam(
    model,
    prompts,
    beam_k=2,
    max_gen_len=max_new_tokens,
    func=2,
)

Python 3.12.9 | packaged by conda-forge | (main, Mar  4 2025, 22:37:18) [MSC v.1943 64 bit (AMD64)]
Type 'copyright', 'credits' or 'license' for more information
IPython 9.1.0 -- An enhanced Interactive Python. Type '?' for help.
Tip: You can use `files = !ls *.png`



Out[1]: 
tensor([[    1,  5678,   267,   526,  2654, 29892, 28008, 10376,   526,    -1,
            -1,    -1,    -1],
        [    1, 29871, 29955,   718, 29871, 29945,   353,    -1,    -1,    -1,
            -1,    -1,    -1]])


Out[3]: 
tensor([[    1,  5678,   267,   526,  2654, 29892, 28008, 10376,   526,    -1,
            -1,    -1,    -1],
        [    1,  5678,   267,   526,  2654, 29892, 28008, 10376,   526,    -1,
            -1,    -1,    -1],
        [    1, 29871, 29955,   718, 29871, 29945,   353,    -1,    -1,    -1,
            -1,    -1,    -1],
        [    1, 29871, 29955,   718, 29871, 29945,   353,    -1,    -1,    -1,
            -1,    -1,    -1]])


Out[5]: 
tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True, False,
         False, False, False],
        [ True,  True,  True,  True,  True,  True,  True, False, False, False,
         False, False, False]])

Out[6]: 
tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True, Fals