In [8]:
from pathlib import Path

# Generation settings. These are not consumed in this cell; presumably they
# feed generate(...) and model loading in later cells — TODO confirm.
NUM_SAMPLES = 1  # presumably the number of completions to draw
MAX_NEW_TOKENS = 50  # presumably generate(max_new_tokens=...)
TOP_K = 200  # presumably generate(top_k=...): sample among the TOP_K best tokens
TEMPERATURE = 0.8  # presumably generate(temperature=...): logits are divided by this
ACCELERATOR = "AUTO"  # NOTE(review): Lightning examples use lowercase "auto"; confirm this is accepted
MODEL_SIZE = "7B"  # selects the checkpoint sub-directory below
QUANTIZE = True  # presumably toggles as_8_bit_quantized — confirm where it is used

# Paths. Both files live under one checkpoint directory; building them with
# Path joins (rather than string templates) avoids a missing-`f`-prefix bug
# that left a literal "{ROOT_DIR}" in TOKENIZER_PATH.
ROOT_DIR = "/app"
CHECKPOINT_DIR = Path(ROOT_DIR) / "lit-llama" / "checkpoints" / "lit-llama"
CHECKPOINT_PATH = CHECKPOINT_DIR / MODEL_SIZE / "state_dict.pth"
TOKENIZER_PATH = CHECKPOINT_DIR / "tokenizer.model"


# Import the libraries used below (PyTorch, Lightning, lit_llama)

In [4]:
import sys
import time
from pathlib import Path
from typing import Optional

import lightning as L
import torch
from lit_llama import LLaMA, Tokenizer, as_8_bit_quantized

In [5]:
@torch.no_grad()
def generate(
    model: torch.nn.Module,
    idx: torch.Tensor,
    max_new_tokens: int,
    max_seq_length: int,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
) -> torch.Tensor:
    """Autoregressively extend a prompt by ``max_new_tokens`` tokens.

    The implementation of this function is modified from A. Karpathy's nanoGPT.

    Args:
        model: Module called as ``model(idx_cond)``; its output is indexed as
            ``logits[:, -1]``, so it is expected to return per-position logits
            of shape (B, T, vocab_size).
        idx: Tensor of shape (B, T) with indices of the prompt sequence.
        max_new_tokens: The number of new tokens to generate.
        max_seq_length: The maximum context length fed to the model; once the
            running sequence is longer, only its last ``max_seq_length`` tokens
            are passed in.
        temperature: Scales the predicted logits by 1 / temperature before
            sampling. ``0`` selects greedy (argmax) decoding instead.
        top_k: If specified, only sample among the tokens with the k highest
            probabilities.

    Returns:
        Tensor of shape (B, T + max_new_tokens) holding the prompt followed by
        the generated tokens, with the same dtype and device as ``idx``.
    """
    B, T = idx.shape
    T_new = T + max_new_tokens
    # Preallocate the final buffer once instead of concatenating every step.
    out = torch.empty(B, T_new, dtype=idx.dtype, device=idx.device)
    out[:, :T] = idx

    for t in range(T, T_new):
        # Positions [0, t) are filled. Keep at most the last max_seq_length of
        # them: the crop must depend on the current length t, not the prompt
        # length T, or the context grows past max_seq_length.
        idx_cond = out[:, max(0, t - max_seq_length):t]

        # Logits for the next position only, shape (B, vocab_size).
        logits = model(idx_cond)[:, -1]

        if temperature == 0:
            # Greedy decoding; dividing by zero would produce NaN probabilities.
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            logits = logits / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")
            probs = torch.nn.functional.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

        # Write just position t; idx_next has shape (B, 1).
        out[:, t] = idx_next.squeeze(-1)

    return out