In [1]:
from traceback import print_exc
from inference import LLaMA
from tqdm import tqdm
import torch

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = LLaMA.build_model(
    checkpoints_dir="Llama-2-7b/",
    tokenizer_path="Llama-2-7b/tokenizer.model",
    load_model=True,
    max_seq_len=1024,
    max_batch_size=4,
    device=device
)

Loading checkpoint from Llama-2-7b\consolidated.00.pth
Loaded checkpoint in 64.50s


  _C._set_default_tensor_type(t)


Loaded model in 233.16s


In [3]:
(model.model.layers[0].attention.cache_k == 0).all(dim=-1).all(dim=-1).all(dim=-1)

tensor([True, True, True, True])

In [4]:
max_new_tokens = 4
prompts = [
        "Roses are red, violets are",
        "7 + 5 =",
        """Complete the blank:

Q: The capital of France is ____
A: Paris

Q: The capital of Germany is ____
A: Berlin

Q: The capital of Italy is ____
A: Rome

Q: The capital of Spain is ____
A:"""
    ]

prompts = [
        "Roses are red, violets are"
    ]

In [5]:
strategy = "top_p"
try:
    print(f"{strategy} sampling")
    out_tokens, out_texts = model.generate(prompts, max_gen_len=max_new_tokens, strategy=strategy, p=0.9)
    assert len(out_texts) == len(prompts), f"Expected {len(prompts)} outputs, got {len(out_tokens)}"
    print(f"\n{'-'*50}\n".join(out_texts))
    print("="*100)
except Exception as e:
    print(f"Error during {strategy} sampling: {e}")
    print_exc()

strategy = "greedy"
try:
    print(f"{strategy} sampling")
    out_tokens, out_texts = model.generate(prompts, max_gen_len=max_new_tokens, strategy=strategy)
    assert len(out_texts) == len(prompts), f"Expected {len(prompts)} outputs, got {len(out_tokens)}"
    print(f"\n{'-'*50}\n".join(out_texts))
    print("="*100)
except Exception as e:
    print(f"Error during {strategy} sampling: {e}")
    print_exc()

strategy = "random"
try:
    print(f"{strategy} sampling")
    out_tokens, out_texts = model.generate(prompts, max_gen_len=max_new_tokens, strategy=strategy)
    assert len(out_texts) == len(prompts), f"Expected {len(prompts)} outputs, got {len(out_tokens)}"
    print(f"\n{'-'*50}\n".join(out_texts))
    print("="*100)
except Exception as e:
    print(f"Error during {strategy} sampling: {e}")
    print_exc()

strategy = "top_k"
try:
    print(f"{strategy} sampling")
    out_tokens, out_texts = model.generate(prompts, max_gen_len=max_new_tokens, strategy=strategy, k=50)
    assert len(out_texts) == len(prompts), f"Expected {len(prompts)} outputs, got {len(out_tokens)}"
    print(f"\n{'-'*50}\n".join(out_texts))
    print("="*100)
except Exception as e:
    print(f"Error during {strategy} sampling: {e}")
    print_exc()

top_p sampling


Generating tokens: 100%|██████████| 12/12 [03:19<00:00, 16.59s/it]


Roses are red, violets are blue, take one
greedy sampling


Generating tokens: 100%|██████████| 12/12 [03:17<00:00, 16.48s/it]


Roses are red, violets are blue, I'
random sampling


Generating tokens: 100%|██████████| 12/12 [03:15<00:00, 16.27s/it]


Roses are red, violets are blue, come be
top_k sampling


Generating tokens: 100%|██████████| 12/12 [03:15<00:00, 16.28s/it]

Roses are red, violets are ded Ad amgra





In [55]:
def _beam1(
    model,
    tokens: torch.Tensor,                 # (B, L_total)
    prompt_tokens_mask: torch.Tensor,     # (B, L_total)  – True on prompt positions
    temperature: float,
    k: int,
) -> tuple[list[list[int]], list[str]]:

    device          = tokens.device
    B, L_total      = tokens.shape
    vocab_size      = model.tokenizer.vocab_size()
    eos_id          = model.tokenizer.eos_id

    # ------------------------------------------------------------------
    # Duplicate everything k times ------------------------------------------------
    # ------------------------------------------------------------------
    tokens          = tokens.unsqueeze(1).repeat(1, k, 1)                 # (B, k, L)
    prompt_mask     = prompt_tokens_mask.unsqueeze(1).repeat(1, k, 1)     # (B, k, L)

    tokens_flat     = tokens.view(B * k, L_total)                         # (B·k, L)
    prompt_flat     = prompt_mask.view(B * k, L_total)

    # running log‑probs for every beam (B, k)
    beam_scores     = torch.zeros(B, k, device=device)                    # log‑probs
    finished        = torch.zeros(B, k, dtype=torch.bool, device=device)  # True if <eos> seen

    for cur_pos in range(1, L_total):

        # ------------------------------------------------------------------
        # 1. Forward last token for every live beam --------------------------
        # ------------------------------------------------------------------
        with torch.no_grad():
            logits = model.model(tokens_flat[:, cur_pos-1:cur_pos], cur_pos)  # (B·k, 1, V)

        log_probs = torch.log_softmax(logits[:, -1] / temperature, dim=-1)    # (B·k, V)
        log_probs = log_probs.view(B, k, vocab_size)                          # (B, k, V)

        # ------------------------------------------------------------------
        # 2. Handle (a) prompt positions, (b) finished beams -----------------
        # ------------------------------------------------------------------
        # a) Padded prompt positions → force given token (prob=0 → log=0)
        forced_mask         = prompt_mask[:, :, cur_pos]                      # (B, k)
        forced_token_ids    = tokens[:, :, cur_pos]                           # (B, k)

        log_probs[forced_mask]  = float("-inf")
        log_probs[forced_mask, torch.arange(vocab_size, device=device)]  # type: ignore
        log_probs[forced_mask, forced_token_ids] = 0.0                       # keep only the true token

        # b) Finished beams → ONLY allow <eos>
        log_probs[finished] = float("-inf")
        log_probs[finished, eos_id] = 0.0

        # ------------------------------------------------------------------
        # 3. Expand & select top‑k per original prompt -----------------------
        # ------------------------------------------------------------------
        # current beam_scores (B, k)  +  log_probs (B, k, V)  →  (B, k, V)
        candidate_scores = beam_scores.unsqueeze(-1) + log_probs            # (B, k, V)
        candidate_scores = candidate_scores.view(B, k * vocab_size)         # (B, k·V)

        topk_scores, topk_indices = candidate_scores.topk(k, dim=-1)        # (B, k)

        topk_beam_idx  = topk_indices // vocab_size                         # (B, k)
        topk_token_idx = topk_indices %  vocab_size                         # (B, k)

        # ------------------------------------------------------------------
        # 4. Gather tokens from the parent beams and write the new token -----
        # ------------------------------------------------------------------
        batch_idx = torch.arange(B, device=device).unsqueeze(1)             # (B, 1)

        # gather the chosen parent beams
        tokens = tokens[batch_idx, topk_beam_idx]                           # (B, k, L)
        prompt_mask = prompt_mask[batch_idx, topk_beam_idx]                 # (B, k, L)
        finished = finished[batch_idx, topk_beam_idx]                       # (B, k)

        # write the new token
        tokens[:, :, cur_pos] = topk_token_idx
        prompt_mask[:, :, cur_pos] = False

        # update score & finished flags
        beam_scores = topk_scores
        finished |= (topk_token_idx == eos_id)

        # flatten again for next forward pass
        tokens_flat = tokens.view(B * k, L_total)
        prompt_flat = prompt_mask.view(B * k, L_total)

        # early stop if every prompt has at least one finished beam
        if finished.all():  
            break

    # ----------------------------------------------------------------------
    # 5. Pick the best beam per prompt -------------------------------------
    # ----------------------------------------------------------------------
    best_scores, best_idx = beam_scores.max(dim=-1)                         # (B,)
    best_tokens = tokens[batch_idx.squeeze(), best_idx]                     # (B, L)

    out_tokens: list[list[int]] = []
    out_text:   list[str]       = []

    for seq in best_tokens:                                                 # seq: (L,)
        seq = seq.tolist()
        if eos_id in seq:
            seq = seq[:seq.index(eos_id)]
        out_tokens.append(seq)
        out_text.append(model.tokenizer.decode(seq))

    return out_tokens, out_text

def _beam2(model, tokens: torch.Tensor, prompt_tokens_mask: torch.Tensor, temperature: float, k: int) -> tuple[list[list[int]], list[str]]:
    device = model.model_args.device
    batch_size, total_len = tokens.shape
    vocab_size = model.model_args.vocab_size

    # Expand the tokens and masks to k beams per batch element
    expanded_tokens = tokens.repeat_interleave(k, dim=0)  # (batch_size * k, total_len)
    expanded_prompt_mask = prompt_tokens_mask.repeat_interleave(k, dim=0)  # (batch_size * k, total_len)

    # Initialize beam scores: (batch_size, k)
    beam_scores = torch.zeros((batch_size, k), device=device)
    beam_scores[:, 1:] = -float('inf')  # Only the first beam is active initially

    # Track whether each beam has reached EOS
    eos_reached = torch.zeros((batch_size, k), dtype=torch.bool, device=device)

    for cur_pos in tqdm(range(1, total_len), desc="Generating tokens (beam search)"):
        # Get the logits for the next token
        with torch.no_grad():
            # Input is the tokens up to cur_pos-1
            logits = model.model(expanded_tokens[:, cur_pos-1:cur_pos], cur_pos)

        # Apply temperature and get log probabilities
        logits = logits[:, -1] / temperature
        log_probs = torch.log_softmax(logits, dim=-1)  # (batch_size * k, vocab_size)

        # For each beam, select top k tokens and their log probs
        topk_log_probs, topk_indices = log_probs.topk(k, dim=-1)  # (batch_size * k, k)

        # Reshape to (batch_size, k, k)
        topk_log_probs = topk_log_probs.view(batch_size, k, k)

        # Compute new scores: beam_scores (batch, k, 1) + topk_log_probs (batch, k, k) => (batch, k, k)
        new_scores = beam_scores.unsqueeze(-1) + topk_log_probs

        # Flatten to (batch_size, k * k)
        new_scores_flat = new_scores.view(batch_size, -1)

        # Select top k candidates for each batch element
        topk_new_scores, topk_indices_flat = new_scores_flat.topk(k, dim=-1)  # (batch_size, k)

        # Determine parent beams and token ranks
        parent_beams = topk_indices_flat // k  # (batch_size, k)
        token_ranks = topk_indices_flat % k    # (batch_size, k)

        # Update beam scores
        beam_scores = topk_new_scores

        # Gather the token indices for the selected candidates
        beam_indices = (torch.arange(batch_size, device=device).unsqueeze(1) * k + parent_beams).view(-1)
        token_ranks = token_ranks.view(-1)

        # Gather the token IDs from topk_indices
        gathered_token_ids = topk_indices[beam_indices, token_ranks]

        # Apply prompt mask: if current position is part of the prompt, use the existing token
        current_prompt_mask = expanded_prompt_mask[:, cur_pos]
        existing_tokens = expanded_tokens[:, cur_pos]
        gathered_token_ids = torch.where(current_prompt_mask, existing_tokens, gathered_token_ids)

        # Update the expanded_tokens by copying the parent sequences and adding the new token
        expanded_tokens[:, :cur_pos] = expanded_tokens[beam_indices, :cur_pos]
        expanded_tokens[:, cur_pos] = gathered_token_ids

        # Check for EOS tokens in non-prompt positions
        eos_reached_new = (~current_prompt_mask) & (gathered_token_ids == model.tokenizer.eos_id)
        eos_reached_new = eos_reached_new.view(batch_size, k)

        # Update eos_reached and mask scores for completed beams
        eos_reached = eos_reached | eos_reached_new
        beam_scores[eos_reached] = -float('inf')

        # Early stopping if all beams are EOS
        if eos_reached.all():
            break

    # Select the best beam for each batch element
    best_beam_indices = beam_scores.argmax(dim=-1)  # (batch_size)
    best_beam_indices_expanded = (torch.arange(batch_size, device=device) * k) + best_beam_indices
    best_tokens = expanded_tokens[best_beam_indices_expanded]

    # Process the final tokens to remove padding and cut at EOS
    out_tokens = []
    out_text = []
    for i in range(batch_size):
        tokens_list = best_tokens[i].tolist()
        if model.tokenizer.eos_id in tokens_list:
            eos_idx = tokens_list.index(model.tokenizer.eos_id)
            tokens_list = tokens_list[:eos_idx]
        out_tokens.append(tokens_list)
        out_text.append(model.tokenizer.decode(tokens_list))

    return out_tokens, out_text

def generate_with_beam(model, prompts: list[str], temperature: float = 1.0, max_gen_len = None, k: int = 2, func: int = 1) -> tuple[list[list[int]], list[str]]:
    assert temperature > 0, f"Temperature must be greater than 0, got {temperature}"
    if max_gen_len is None:
        max_gen_len = model.model_args.max_seq_len - 1
    # Convert each prompt into tokens
    prompt_tokens = [model.tokenizer.encode(prompt, out_type=int, add_bos=True, add_eos=False) for prompt in prompts]
    # Make sure the batch size is not too large
    batch_size = len(prompt_tokens)
    assert batch_size <= model.model_args.max_batch_size, f"batch size must be less than or equal to {model.model_args.max_batch_size}"
    max_prompt_len = max(len(prompt) for prompt in prompt_tokens)
    # Make sure the prompt length is not larger than the maximum sequence length
    assert max_prompt_len <= model.model_args.max_seq_len, f"prompt length must be less than or equal to {model.model_args.max_seq_len}"
    total_len = min(model.model_args.max_seq_len, max_gen_len + max_prompt_len)

    # Create the list that will contain the generated tokens, along with the initial prompt tokens
    pad_id = model.tokenizer.pad_id()
    tokens = torch.full((batch_size, total_len), pad_id, dtype=torch.long, device=model.model_args.device)
    for k, t in enumerate(prompt_tokens):
        # Populate the initial tokens with the prompt tokens
        tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=model.model_args.device)
    
    eos_reached = torch.tensor([False] * batch_size, device=model.model_args.device)
    prompt_tokens_mask = tokens != pad_id # True if the token is a prompt token, False otherwise

    return _beam1(model, tokens, prompt_tokens_mask, temperature, k) if func == 1 else _beam2(model, tokens, prompt_tokens_mask, temperature, k)

In [56]:
generate_with_beam(model, prompts, max_gen_len=max_new_tokens, k=2, func=1)

RuntimeError: cannot reshape tensor of 0 elements into shape [0, 1, 32, -1, 2] because the unspecified dimension size -1 can be any value and is ambiguous

In [57]:
generate_with_beam(model, prompts, max_gen_len=max_new_tokens, k=2, func=2)

Generating tokens (beam search):   0%|          | 0/12 [00:00<?, ?it/s]


RuntimeError: cannot reshape tensor of 0 elements into shape [0, 1, 32, -1, 2] because the unspecified dimension size -1 can be any value and is ambiguous