In [None]:
import torch
import torch.nn as nn
import sys
import json
from dataclasses import dataclass

sys.path.append("../../../open_lm")
from open_lm.model import Transformer
from open_lm.norms import RmsNorm

# Inference device for the model and input tokens.
device = "cuda:0"

# Model hyper-parameters (open_lm JSON config for LLaMA-2 7B). Read through a context
# manager so the file handle is closed deterministically instead of leaking until GC.
with open("../model_configs/llama2_7b.json", encoding="utf-8") as config_file:
    cfg = json.load(config_file)


@dataclass
class Params:
    """Hyper-parameters passed to ``open_lm.model.Transformer``.

    Fields without a default must be supplied by the caller. The defaulted fields
    select LLaMA-style architecture variants (RmsNorm, "llama_rotary" positional
    embeddings, "swiglu" feed-forward, no QK-norm).
    """

    dim: int  # model width; this notebook fills it from cfg["hidden_dim"]
    n_layers: int
    n_heads: int
    vocab_size: int
    norm_eps: float  # epsilon for the normalization layers (this notebook passes 1e-5)
    seq_len: int
    post_embed_norm: bool  # NOTE(review): presumably toggles a norm after the embedding — confirm in open_lm
    weight_tying: bool  # NOTE(review): presumably ties input/output embedding weights — confirm in open_lm
    # NOTE(review): the default is the RmsNorm *class*, not an nn.Module instance, despite the
    # annotation; Transformer presumably instantiates it — confirm in open_lm.
    norm_type: nn.Module = RmsNorm  # Make sure to use RmsNorm for LLaMA
    apply_qk_norm: bool = False
    positional_embedding_type: str = "llama_rotary"  # Make sure to set this for LLaMA
    ffn_type: str = "swiglu"


# Build the Transformer hyper-parameters from the JSON config. Fields not listed here
# (norm_type, apply_qk_norm, positional_embedding_type, ffn_type) keep the LLaMA defaults
# declared on Params.
args = Params(
    dim=cfg["hidden_dim"],
    n_layers=cfg["n_layers"],
    n_heads=cfg["n_heads"],
    seq_len=cfg["seq_len"],
    vocab_size=cfg["vocab_size"],
    post_embed_norm=cfg["post_embed_norm"],
    weight_tying=cfg["weight_tying"],
    norm_eps=1e-5,
)

model = Transformer(args)
# map_location="cpu": the model above is built on the host, so load the checkpoint there too.
# Otherwise a checkpoint saved from GPU tensors would first be materialised on the GPU (a full
# extra copy of the weights) only to be copied back into the CPU-resident model.
# NOTE(review): torch.load unpickles arbitrary objects; if the installed torch supports it and
# the checkpoint contains only tensors, consider weights_only=True.
state_dict = torch.load(
    "./LLAMA2/llama-2-7b/consolidated.00.converted.pth",
    map_location="cpu",
)
model.load_state_dict(state_dict, strict=True)
# load_state_dict copied the tensors into the model's parameters; drop the now-duplicate
# checkpoint copy so it does not pin several GB of host memory for the rest of the session.
del state_dict
model = model.eval().to(device)

: 

In [None]:
# The Tokenizer class comes from the local checkout of the `llama` repo, not from open_lm;
# make that checkout importable first (this must run before the import below).
sys.path.append("./LLAMA2/llama")
from llama.tokenizer import Tokenizer

# Load the tokenizer model file shipped with the local LLaMA-2 download.
tokenizer = Tokenizer("./LLAMA2/tokenizer.model")


def sample_top_p(probs, p):
    """Draw one token id per row using nucleus (top-p) sampling.

    Tokens are ranked by probability. A token is kept when the probability mass of the
    tokens ranked strictly above it is at most ``p``, so for ``p >= 0`` the top token is
    always kept. The kept probabilities are renormalised and one index is drawn per row
    with ``torch.multinomial``. The input tensor is not modified.

    Args:
        probs: Probability tensor whose last dimension indexes the vocabulary.
        p: Nucleus threshold on cumulative probability.

    Returns:
        Integer tensor of sampled vocabulary indices (last dimension of size 1).
    """
    sorted_probs, order = torch.sort(probs, dim=-1, descending=True)
    # Probability mass strictly ahead of each position in the ranking.
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    nucleus = sorted_probs.masked_fill(mass_before > p, 0.0)
    nucleus = nucleus / nucleus.sum(dim=-1, keepdim=True)
    picked_rank = torch.multinomial(nucleus, num_samples=1)
    # Translate the sampled rank back into an index in the original vocabulary order.
    return order.gather(-1, picked_rank)


def generate_top_p_language(
    prefix: str,
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_len: int = 128,
    stop_at_eos: bool = True,
) -> str:
    """Autoregressively extend ``prefix`` with the global ``model`` and return decoded text.

    Each step re-feeds the full token sequence generated so far to the model (no KV cache).
    The next token is chosen from the logits at the last position.

    Args:
        prefix: Prompt text. It is encoded with a leading BOS token and no EOS token.
        temperature: Softmax temperature. ``> 0`` samples via ``sample_top_p``;
            ``<= 0`` takes the arg-max token (greedy decoding).
        top_p: Nucleus threshold forwarded to ``sample_top_p``.
        max_len: Maximum number of new tokens to generate.
        stop_at_eos: If True (default), stop as soon as the tokenizer's EOS token is
            produced; the EOS token itself is not appended. Set False to always run
            ``max_len`` steps (the previous behaviour).

    Returns:
        The decoded prompt followed by the generated continuation.
    """
    input_tokens = tokenizer.encode(prefix, bos=True, eos=False)
    tokens = torch.tensor(input_tokens).unsqueeze(0).to(device)
    # A tokenizer without an `eos_id` attribute simply never triggers early stopping.
    eos_id = getattr(tokenizer, "eos_id", None) if stop_at_eos else None

    # NOTE(review): the sequence is never truncated to the model's seq_len; very long
    # prompts plus max_len may exceed the context the model was configured for.
    with torch.no_grad():  # inference only; entered once instead of once per step
        for _step in range(max_len):
            logits, _, _ = model(tokens)
            last_logits = logits[:, -1]
            if temperature > 0:
                probs = torch.softmax(last_logits / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(last_logits, dim=-1, keepdim=True)
            # Batch size is 1 (a single prompt), so `.item()` is well-defined.
            if eos_id is not None and next_token.item() == eos_id:
                break
            tokens = torch.cat([tokens, next_token], dim=-1)

    return tokenizer.decode(tokens[0].tolist())


prompts = [
    # For these prompts, the expected answer is the natural continuation of the prompt
    "I believe the meaning of life is",
    "Simply put, the theory of relativity states that ",
    """A brief message congratulating the team on the launch:

    Hi everyone,
    
    I just """,
    # Few-shot prompt (providing a few examples before asking the model to complete more)
    """Translate English to French:
    
    sea otter => loutre de mer
    peppermint => menthe poivrée
    plush girafe => girafe peluche
    cheese =>""",
    """He -> Him, She -> Her, They ->""",
    """Who is Donald Trump?""",
]

# Top-p sampling is stochastic; seed torch's RNG (CPU and CUDA) so reruns give the same text.
torch.manual_seed(0)

for prompt in prompts:
    print("====================================")
    generated_text = generate_top_p_language(prompt)
    # generate_top_p_language decodes the whole token sequence, prompt included, so the
    # returned text already starts with the prompt; print it once instead of echoing it.
    print(generated_text)
    print("====================================")

: 

: 