In [None]:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from torch.cuda.amp import autocast
from typing import List, Literal, Optional, Tuple, TypedDict
import torch
import torch.nn.functional as F
from model import pretrainLlama
from argparse import Namespace

In [None]:

# Allowed values for the ``role`` field of a ``Message``.
Role = Literal["system", "user", "assistant"]


class Message(TypedDict):
    """One entry of a dialog: the speaker's role and the text content."""

    role: Role  # one of "system", "user", "assistant"
    content: str  # the message text


class CompletionPrediction(TypedDict, total=False):
    """Result record produced by ``Llama.text_completion`` for one prompt.

    Because the class is declared with ``total=False``, every key is optional
    for a type checker, ``generation`` included.
    """

    generation: str  # decoded text of the returned tokens
    tokens: List[str]  # not required; per-token decoded strings, only set when logprobs are requested
    logprobs: List[float]  # not required; per-token log-probabilities, only set when logprobs are requested


class ChatPrediction(TypedDict, total=False):
    """Prediction record whose ``generation`` is a ``Message`` instead of a string.

    Declared with ``total=False``, so every key is optional for a type checker.
    NOTE(review): nothing in the visible part of this file produces this type.
    """

    generation: Message  # generated message
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


# A dialog is a list of messages.
Dialog = List[Message]

# Instruction and system-prompt delimiter strings.
# NOTE(review): presumably the Llama 2 chat-format markers; nothing in the
# visible part of this file uses them -- confirm before removing.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

# The tag strings (without the surrounding newlines) and the error text that
# says such tags are not allowed inside a prompt.
SPECIAL_TAGS = [B_INST, E_INST, "<<SYS>>", "<</SYS>>"]
UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt."


class Llama:
    """Text-generation wrapper around a ``pretrainLlama`` model and its tokenizer.

    Instances are normally created with :meth:`build`, which restores the model
    from a checkpoint. :meth:`text_completion` is the string-level entry point
    and :meth:`generate` the token-level one.
    """

    # Paths used by ``build`` when the caller does not supply its own.
    DEFAULT_CKPT_PATH = '/data/rozen/home/e0833634/lama/protllama/pl_model_cache/epoch=23-train_perplexity=1.161-val_perplexity=255.593-ppi_10_26_10k_2048.ckpt'
    DEFAULT_TOKENIZER_PATH = '/data/rozen/home/e0833634/lama/protllama/batch_script/'

    # Hyper-parameters copied unchanged from the checkpoint's ``hparam`` object.
    _CHECKPOINT_HPARAM_FIELDS = (
        "accumulate_grad_batches",
        "attempts",
        "batch_size",
        "date",
        "devices",
        "epoch",
        "flash_attention",
        "hidden_size",
        "input_dataset_path",
        "intermediate_size",
        "learning_rate",
        "max_position_embeddings",
        "num_attention_heads",
        "num_hidden_layers",
        "num_key_value_heads",
        "num_workers",
        "output_dataset_path",
        "save_top_k",
        "scheduler",
        "strategy",
        "target",
        "train_dataloader_length",
        "vocab_size",
    )

    @staticmethod
    def build(
        max_seq_len: int,
        max_batch_size: int,
        model_parallel_size: Optional[int] = None,
        ckpt_path: Optional[str] = None,
        tokenizer_path: Optional[str] = None,
    ) -> "Llama":
        """Restore a ``pretrainLlama`` model from a checkpoint and wrap it.

        Args:
            max_seq_len: Stored on the model's ``hparam``; ``generate`` uses it
                as the upper bound on prompt plus generated length.
            max_batch_size: Stored on the model's ``hparam``; ``generate``
                asserts that the batch does not exceed it.
            model_parallel_size: Accepted for interface compatibility; unused.
            ckpt_path: Checkpoint file to load. Defaults to
                ``Llama.DEFAULT_CKPT_PATH``.
            tokenizer_path: Value written to ``hparam.tokenizer_path`` in place
                of the one saved in the checkpoint. Defaults to
                ``Llama.DEFAULT_TOKENIZER_PATH``.

        Returns:
            A ``Llama`` holding the CUDA-resident model and ``model.tokenizer``.
        """
        if ckpt_path is None:
            ckpt_path = Llama.DEFAULT_CKPT_PATH
        if tokenizer_path is None:
            tokenizer_path = Llama.DEFAULT_TOKENIZER_PATH

        # SECURITY: torch.load unpickles arbitrary objects -- only load trusted
        # checkpoints. Tensors are mapped to CPU first so loading does not
        # depend on the GPU layout the checkpoint was saved with.
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        hyper_parameters = checkpoint["hyper_parameters"]
        original_hparam = hyper_parameters['hparam']

        carried_over = {
            name: getattr(original_hparam, name)
            for name in Llama._CHECKPOINT_HPARAM_FIELDS
        }
        # Rebuild the namespace with the tokenizer path and inference limits overridden.
        hyper_parameters['hparam'] = Namespace(
            **carried_over,
            tokenizer_path=tokenizer_path,
            max_batch_size=max_batch_size,
            max_seq_len=max_seq_len,
        )

        model = pretrainLlama(**hyper_parameters)
        model.configure_model()
        model.load_state_dict(checkpoint['state_dict'])
        model = model.cuda()
        # Inference only: switch off training-time behaviour such as dropout.
        model.eval()

        return Llama(model, model.tokenizer)

    def __init__(self, model, tokenizer):
        """Store the model and the tokenizer used for encoding and decoding."""
        self.model = model
        self.tokenizer = tokenizer

    def _device(self) -> torch.device:
        """Return the device of the model's parameters, or CUDA if it cannot be determined."""
        try:
            return next(self.model.parameters()).device
        except (AttributeError, StopIteration, TypeError):
            return torch.device("cuda")

    @torch.inference_mode()
    def generate(
        self,
        prompt_tokens: List[List[int]],
        max_gen_len: int,
        temperature: float = 0.6,
        top_p: float = 0.9,
        logprobs: bool = False,
        echo: bool = False,
    ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
        """Extend each tokenized prompt, one token per step.

        Args:
            prompt_tokens: One list of token ids per prompt.
            max_gen_len: Maximum number of tokens to add after the longest prompt.
            temperature: Softmax temperature; values ``<= 0`` select the argmax.
            top_p: Nucleus threshold passed to ``sample_top_p`` when sampling.
            logprobs: When true, also return per-token log-probabilities.
            echo: When true, the returned sequences start with the prompt tokens.

        Returns:
            ``(tokens, logprobs)``: one token-id list per prompt and, if
            ``logprobs`` is true, a matching list of log-probabilities
            (otherwise ``None``). Sequences are not cut at EOS.

        Raises:
            AssertionError: If the batch is larger than ``hparam.max_batch_size``
                or a prompt is longer than ``hparam.max_seq_len``.
        """
        params = self.model.hparam
        bsz = len(prompt_tokens)
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
        if bsz == 0:
            # Nothing to generate; min()/max() below would fail on an empty batch.
            return ([], [] if logprobs else None)

        prompt_lens = [len(t) for t in prompt_tokens]
        min_prompt_len = min(prompt_lens)
        max_prompt_len = max(prompt_lens)
        assert max_prompt_len <= params.max_seq_len
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

        device = self._device()
        use_amp = device.type == "cuda"
        # The unk id doubles as the padding id (original pad is -1).
        pad_id = self.tokenizer.unk_id()
        eos_id = self.tokenizer.eos_id()

        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device)
        # Built from prompt lengths rather than `tokens != pad_id`, so an unk
        # token inside a prompt is still recognised as prompt text.
        input_text_mask = torch.zeros((bsz, total_len), dtype=torch.bool, device=device)
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
            input_text_mask[k, : len(t)] = True

        if logprobs:
            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

        if min_prompt_len == total_len and logprobs and total_len > 1:
            # No room to generate: just score the prompts. The logits at
            # position i are scored against the token at position i + 1, the
            # same alignment the generation loop uses.
            with autocast(enabled=use_amp):
                logits = self.model.forward(input_ids=tokens)[0]
            token_logprobs[:, 1:] = -F.cross_entropy(
                input=logits[:, :-1].transpose(1, 2),
                target=tokens[:, 1:],
                reduction="none",
                ignore_index=pad_id,
            )

        prev_pos = 0
        eos_reached = torch.zeros(bsz, dtype=torch.bool, device=device)
        for cur_pos in range(min_prompt_len, total_len):
            # The model is called without any cache, so it must see the whole
            # prefix on every step to condition on earlier tokens.
            with autocast(enabled=use_amp):
                logits = self.model.forward(input_ids=tokens[:, :cur_pos])[0]
            last_logits = logits[:, -1]
            if temperature > 0:
                next_probs = torch.softmax(last_logits / temperature, dim=-1)
                next_token = sample_top_p(next_probs, top_p)
            else:
                next_token = torch.argmax(last_logits, dim=-1)

            next_token = next_token.reshape(-1)
            # Rows whose prompt still covers this position keep the prompt token.
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            if logprobs:
                # Score only the positions not scored by an earlier step.
                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                    input=logits[:, prev_pos:cur_pos].transpose(1, 2),
                    target=tokens[:, prev_pos + 1 : cur_pos + 1],
                    reduction="none",
                    ignore_index=pad_id,
                )
            # Only an EOS that was generated (not part of the prompt) ends a row.
            eos_reached |= (~input_text_mask[:, cur_pos]) & (next_token == eos_id)
            prev_pos = cur_pos
            if bool(eos_reached.all()):
                break

        if logprobs:
            token_logprobs = token_logprobs.tolist()
        out_tokens, out_logprobs = [], []
        for i, toks in enumerate(tokens.tolist()):
            # Cut to max gen len; with echo the prompt is kept at the front.
            start = 0 if echo else prompt_lens[i]
            end = prompt_lens[i] + max_gen_len
            out_tokens.append(toks[start:end])
            out_logprobs.append(token_logprobs[i][start:end] if logprobs else None)
            # Sequences are deliberately not truncated at EOS: text_completion
            # puts an EOS at the end of every prompt.
        return (out_tokens, out_logprobs if logprobs else None)

    def text_completion(
        self,
        prompts: List[str],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
        echo: bool = False,
    ) -> "List[CompletionPrediction]":
        """Generate a continuation for each prompt string.

        Each prompt is encoded and wrapped as ``[bos] + tokens + [eos]`` before
        being passed to :meth:`generate`.

        Args:
            prompts: Prompt strings.
            temperature: Forwarded to :meth:`generate`.
            top_p: Forwarded to :meth:`generate`.
            max_gen_len: Maximum number of generated tokens; defaults to
                ``hparam.max_seq_len - 1``.
            logprobs: When true, each result also carries ``tokens`` and ``logprobs``.
            echo: When true, the returned text includes the prompt tokens.

        Returns:
            One dict per prompt with key ``"generation"`` and, when
            ``logprobs`` is true, ``"tokens"`` and ``"logprobs"``.
        """
        if max_gen_len is None:
            max_gen_len = self.model.hparam.max_seq_len - 1

        bos_id = self.tokenizer.bos_id()
        eos_id = self.tokenizer.eos_id()
        prompt_tokens = [
            [bos_id] + self.tokenizer.encode(x) + [eos_id] for x in prompts
        ]

        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
            echo=echo,
        )
        if logprobs:
            return [
                {
                    "generation": self.tokenizer.decode(t),
                    "tokens": [self.tokenizer.decode(x) for x in t],
                    "logprobs": logprobs_i,
                }
                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
            ]
        return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]

def sample_top_p(probs, p):
    """Draw one index per row using nucleus (top-p) sampling.

    Entries along the last dimension are sorted in descending order; an entry
    is kept only while the probability mass ranked strictly before it does not
    exceed ``p``. The kept entries are renormalised and one is sampled.

    Args:
        probs: Tensor of probabilities along its last dimension. ``generate``
            passes a 2-D ``(batch, vocab)`` tensor.
        p: Cumulative-probability threshold.

    Returns:
        Tensor of sampled indices into the last dimension of ``probs``, with a
        trailing dimension of size 1. ``probs`` itself is not modified.
    """
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    # Mass accumulated by the entries ranked strictly before each one.
    mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
    outside_nucleus = mass_before > p
    sorted_probs[outside_nucleus] = 0.0
    sorted_probs.div_(sorted_probs.sum(dim=-1, keepdim=True))
    picked_rank = torch.multinomial(sorted_probs, num_samples=1)
    # Translate the rank within the sorted order back to the original index.
    return sorted_idx.gather(-1, picked_rank)

In [None]:
# Load the checkpointed model once: sequences of up to 2048 tokens, at most
# two prompts per batch.
generator = Llama.build(max_seq_len=2048, max_batch_size=2)

In [None]:
# Prompts to be continued by the model.
# NOTE(review): these look like protein sequences in one-letter amino-acid
# code -- confirm against the training data.
prompts: List[str] = [
    "YAPSALVLTVGKGVSATTAAPERAVTLTCAPGPSGTHPAAGSACADLAAVGGDLNALTRGEDVMCPMVYDPVLLTVDGVWQGKRVSYERVFSNECEMNAHGSSVFAF",
    "DFVLDNEGNPLENGGTYYILSDITAFGGIRAAPTGNERCPLTVVQSRNELDKGIGTIISSPYRIRFIAEGHPLSLKFDSFAVIMLCVGIPTEWSVVEDLPEGPAVKIGENKDAMDGWFRLERVSDDEFNNYKLVFCPQKCGDIGISIDHDDGTRRLVVSKNKPLVVQFQKLD",
]

In [None]:
# Sample up to 64 new tokens per prompt; echo=True keeps the prompt tokens in
# the returned text.
results = generator.text_completion(
    prompts, temperature=0.6, top_p=0.9, max_gen_len=64, echo=True
)

In [None]:
# Show each prompt followed by its generated text and a separator.
for prompt, result in zip(prompts, results):
    print(
        prompt,
        f"> {result['generation']}",
        "\n==================================\n",
        sep="\n",
    )