In [None]:
!pip install trl[peft] --quiet
!pip install bitsandbytes loralib --quiet

In [None]:
import numpy as np
import json
import torch
from tqdm import tqdm
from torch.utils.data import Dataset
import datasets
from transformers import AutoTokenizer, logging
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
from peft import LoraConfig
import bitsandbytes as bnb
logging.set_verbosity(logging.CRITICAL)

In [None]:
# Read the keyword catalogue used by Environment._select_keyword: a JSON list of
# categories, each carrying a 'category' name and a list of 'words'.
with open("keywords.json", 'r') as f:
    keywords = json.load(f)

In [None]:
model_id = "NousResearch/Meta-Llama-3-8B-Instruct"

# --- Tokenizer -------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    padding=True,
    padding_side='left',
)
# Reuse the end-of-sequence token for padding.
tokenizer.pad_token = tokenizer.eos_token
# Id of the end-of-turn marker; used to cut generations at the first turn end.
id_eot = tokenizer.convert_tokens_to_ids(["<|eot_id|>"])[0]

# --- Policy model: 4-bit base weights + LoRA adapters + PPO value head ------
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=['q_proj', 'v_proj'],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model_id,
    peft_config=lora_config,
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
)

def generate_policy_response(template):
    """Generate one assistant turn for ``template`` with the policy model.

    The assistant header is appended to ``template``, up to 20 new tokens are
    generated, and only the newly generated part is decoded. If the end-of-turn
    token (``id_eot``) appears in the continuation, the text is cut right
    before its first occurrence.

    Args:
        template: Prompt text in chat format, without the trailing assistant header.

    Returns:
        The decoded continuation as a string.
    """
    prompt = template + "<|start_header_id|>assistant<|end_header_id|>\n\n"
    encoded = tokenizer(prompt, return_tensors="pt").to("cuda")
    prompt_len = encoded.input_ids.shape[1]

    # generate() returns prompt + continuation; keep the continuation only.
    continuation = model.generate(**encoded, max_new_tokens=20).squeeze()[prompt_len:]

    token_list = continuation.tolist()
    if id_eot in token_list:
        continuation = continuation[:token_list.index(id_eot)]
    return tokenizer.decode(continuation)

In [None]:
class Environment():
    """One game of 20 Questions played against the policy model itself.

    An instance picks a secret keyword and keeps two prompt strings in sync
    with the conversation so far:

    * ``question_state`` - chat-formatted prompt from which the next question
      is generated (system turn followed by assistant-question / user-answer turns).
    * ``guess_state`` - a single system turn listing every ``Question:`` /
      ``Answer:`` pair, from which a final guess is generated.
    """
    def __init__(self) -> None:
        """Start a new game: empty history, a random keyword, and the initial prompts."""
        # Questions asked so far and the answers received, in order.
        self.questions = []
        self.answers = []
        # Sets self.keyword (list of lower-cased accepted spellings) and self.category.
        self._select_keyword()

        # Prompt used to ask for the next question. The triple-quoted text is sent
        # to the model verbatim, so its wording and whitespace are part of the behaviour.
        self.question_state = """<|start_header_id|>system<|end_header_id|>

You are a helpful AI assistant, and your are very smart in playing 20 questions game,
the user is going to think of a word, it can be only one of the following 3 categories:
1. a place
2. a person
3. a thing
So focus your area of search on these options. and give smart questions that narrows down the search space\n
your role is to find the word by asking him up to 20 questions, your questions to be valid must have only a 'yes' or 'no' answer.
the user has chosen the word, start by asking your question!
please be short and not verbose, output only your question, no extra word!<|eot_id|>
"""

        # Prompt used to ask for a guess. It must keep ending with "<|eot_id|>\n":
        # step() strips exactly that suffix before appending a new Q/A pair.
        self.guess_state = """<|start_header_id|>system<|end_header_id|>

You are a helpful AI assistant, and your are very smart in playing 20 questions game,
the user is going to think of a word, it can be only one of the following 3 categories:
1. a place
2. a person
3. a thing
So focus your area of search on these options. \n
based on the following conversation, can you guess the word, please give only the word, no verbosity around\n<|eot_id|>
"""


    def _select_keyword(self) -> None:
        """Pick a random category and a random word from the global ``keywords`` list.

        Stores the lower-cased keyword followed by its lower-cased alternatives in
        ``self.keyword`` and the category name in ``self.category``.
        """
        category = np.random.choice(keywords)
        chosen_key = np.random.choice(category['words'])
        # Index 0 is the canonical keyword; the remaining entries are accepted alternatives.
        self.keyword = [chosen_key['keyword'].lower()]+[x.lower() for x in chosen_key['alts']]
        # self.key_text = f"""{self.keyword[0]} {', also known as '+self.keyword[1] if len(self.keyword)>1 else''} {''.join([' and '+self.keyword[i] for i in self.keyword])}"""
        self.category = category['category']


    def step(self, question: str) -> None:
        """Answer ``question`` about the secret keyword and advance both prompt states.

        The answer is produced by ``generate_policy_response``, i.e. by the same
        model that is being trained as the policy (done to save memory).

        Args:
            question: The question (action) asked by the policy.

        Side effects:
            Appends to ``self.questions`` and ``self.answers`` and extends
            ``self.question_state`` and ``self.guess_state`` with the new pair.
        """

        self.questions.append(question)

        # Step 1: build the answerer prompt (it reveals the keyword and category)
        # and let the model answer the question.
        sys_prompt = f"""you are a helpful AI assistant, and your are very smart in playing 20 questions game,
        the role of the user is to guess the word by asking you up to 20 questions, your answers to be valid must be a 'yes' or 'no', any other answer is invalid and you lose the game.
        Know that the user will always guess a word belonging to one of the following 3 categories:
        1. a place
        2. a person
        3. a thing
        so make sure you understand the user's question and you understand the keyword you're playig on.
        for now the word that the user should guess is: "{self.keyword[0]}" {', also known as '+self.keyword[1] if len(self.keyword)>1 else''} {''.join([' and '+x for x in self.keyword[2:]]) if len(self.keyword)>2 else ''},
        it is of category "{self.category}",
        to help you, here's an example of how it should work assuming that the keyword is Morocco in the category "place":
        example:
        <user: is it a place?
        you: yes
        user: is it in europe?
        you: no
        user: is it in africa?
        you: yes
        user: do most people living there have dark skin?
        you: no
        user: is it a country name starting by m ?
        you: yes
        user: is it Morocco?
        you: yes.>"""

        chat_template = f"""<|start_header_id|>system<|end_header_id|>\n\n{sys_prompt}<|eot_id|>"""
        chat_template += "<|start_header_id|>user<|end_header_id|>\n\n"
        chat_template += f"{question}<|eot_id|>"

        output = generate_policy_response(chat_template)
        self.answers.append(output)

        # Step 2: extend the question prompt with the question as an assistant turn
        # and the answer as a user turn.
        self.question_state += f"<|start_header_id|>assistant<|end_header_id|>\n\n{question}<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\n"
        self.question_state += f"{output}<|eot_id|>\n"
        # This state is fed to the policy to generate the next question.

        # Step 3: extend the guess prompt.
        ## Experiment: should we keep generating questions even though we could guess correctly?
        # [:-11] removes the trailing "<|eot_id|>\n" (10 + 1 characters) so the new
        # Q/A pair is inserted inside the system turn; the suffix is then re-appended.
        self.guess_state = self.guess_state[:-11] + f"""Question: {question}\nAnswer: {output}\n""" +"""<|eot_id|>\n"""

        return


In [None]:
def data_generator(batch_size: int):
    """Roll out games with the current policy and collect PPO prompts.

    Every loop iteration contributes two samples taken from the same game:

    1. a *question* sample - the question prompt before the next question is asked;
    2. a *guess* sample - the guess prompt after that question has been answered.

    A fresh ``Environment`` (new keyword) is started after 20 questions.

    Args:
        batch_size: Number of samples to produce; ``batch_size // 2`` iterations
            are run, so an odd value yields ``batch_size - 1`` samples.

    Returns:
        ``(X, X_key_categories)`` where ``X`` is the list of prompt strings and
        ``X_key_categories[i]`` is the JSON encoding of
        ``[keywords, category, no_of_qs, type]`` for ``X[i]``. ``type`` is ``'q'``
        (question sample, ``no_of_qs`` in 0..19 = questions already in the prompt)
        or ``'g'`` (guess sample, ``no_of_qs`` in 1..20).
    """
    X = []
    X_key_categories = []
    env = Environment()

    for _ in range(batch_size // 2):
        # Question sample: the prompt currently holds 0..19 answered questions.
        # Store that count as-is. The former `len(...) - 1` produced -1 for the
        # first sample, which indexes the reward table from the end (lowest
        # reward) and made the last-question penalty unreachable.
        X.append(env.question_state)
        X_key_categories.append(json.dumps([env.keyword, env.category, len(env.questions), 'q']))

        # Ask the policy for a question; retry up to 10 times if it comes back
        # blank (whitespace only).
        generated_question = generate_policy_response(env.question_state)
        for _attempt in range(10):
            if generated_question.strip():
                break
            generated_question = generate_policy_response(env.question_state)

        # Answer the question and advance both prompt states.
        env.step(generated_question)

        # Guess sample: the prompt now holds 1..20 answered questions.
        X.append(env.guess_state)
        X_key_categories.append(json.dumps([env.keyword, env.category, len(env.questions), 'g']))

        if len(env.questions) == 20:
            env = Environment()

    return X, X_key_categories



In [None]:
def _answerer_prompt(keywords, category, question):
    """Build the chat prompt that makes the model answer ``question`` about the keyword."""
    aka = ', also known as '+keywords[1] if len(keywords)>1 else ''
    others = ''.join([' and '+m for m in keywords[2:]]) if len(keywords)>2 else ''
    # The continuation lines keep their 12-space indentation on purpose: the text
    # is sent to the model verbatim, so the whitespace is part of the prompt.
    sys_prompt = f"""you are a helpful AI assistant, and your are very smart in playing 20 questions game,
            the role of the user is to guess the word by asking you up to 20 questions, your answers to be valid must be a 'yes' or 'no', any other answer is invalid and you lose the game.
            Know that the user will always guess a word belonging to one of the following 3 categories:
            1. a place
            2. a person
            3. a thing
            so make sure you understand the user's question and you understand the keyword you're playig on.
            for now the word that the user should guess is: "{keywords[0]}" {aka} {others},
            it is of category "{category}",
            to help you, here's an example of how it should work assuming that the keyword is Morocco in the category "place":
            example:
            <user: is it a place?
            you: yes
            user: is it in europe?
            you: no
            user: is it in africa?
            you: yes
            user: do most people living there have dark skin?
            you: no
            user: is it a country name starting by m ?
            you: yes
            user: is it Morocco?
            you: yes.>"""

    chat_template = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{sys_prompt}<|eot_id|>"""
    chat_template += "<|start_header_id|>user<|end_header_id|>\n\n"
    chat_template += f"{question}<|eot_id|>"
    return chat_template


def _previous_turns(question_state):
    """Extract every (question, answer) pair already present in a question prompt."""
    assistant_header = '<|start_header_id|>assistant<|end_header_id|>'
    turn_separator = '<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\n'
    turns = []
    # Segment 0 is the system turn. Every later segment is one question/answer
    # pair, including the last one: the stored query has no trailing assistant
    # header, so nothing must be dropped from the end.
    for segment in question_state.split(assistant_header)[1:]:
        q, found, a = segment.partition(turn_separator)
        if not found:
            # Not a complete pair (e.g. a bare trailing assistant header) - skip it.
            continue
        q = q.lstrip().replace('\n', '')
        a = a.replace('\n', '').replace('<|eot_id|>', '')
        turns.append((q, a))
    return turns


def _guess_prompt(turns, question, answer):
    """Build the prompt asking the model to guess the word from the conversation."""
    prompt = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>

            You are a helpful AI assistant, and your are very smart in playing 20 questions game,
            the user is going to think of a word, it can be only one of the following 3 categories:
            1. a place
            2. a person
            3. a thing
            So focus your area of search on these options. \n
            based on the following conversation, can you guess the word, please give only the word, no verbosity around
            """
    for q, a in turns:
        prompt+=f"""Question: {q}\nAnswer: {a}\n"""
    prompt+=f"""Question: {question}\nAnswer: {answer}\n"""
    # Close the system turn before the assistant header is appended by
    # generate_policy_response, matching the format of Environment.guess_state.
    prompt+="<|eot_id|>\n"
    return prompt


def _score_guess(guess, keywords, no_of_qs, out_of_questions, r):
    """Turn one guess into a scalar reward tensor."""
    # Decoded generations often carry surrounding whitespace; ignore it.
    if guess.strip().lower() in keywords:
        # Clamp so an out-of-range count can neither wrap around (negative index)
        # nor raise IndexError.
        return torch.tensor(r[min(max(no_of_qs, 0), len(r) - 1)])
    if out_of_questions:
        return torch.tensor(-7.5)
    return torch.tensor(0.0)


def reward(X, y, X_key_categories, r = [15.5, 15.25, 14.975, 14.672, 14.34, 13.974,
                      13.571, 13.128, 12.641, 12.105, 11.516, 10.867,
                      10.154, 9.369, 8.506, 7.557, 6.513, 5.364, 4.1, 2.71, 1.181]):
    """Compute one reward per sample for the two kinds of PPO samples.

    Args:
        X: Prompt strings (question prompts or guess prompts).
        y: Policy responses, aligned with ``X`` (a question or a guess).
        X_key_categories: Per-sample metadata, each entry the JSON encoding of
            ``[keywords, category, no_of_qs, type]`` (an already decoded list is
            accepted too). The argument is not modified.
        r: Reward table indexed by ``no_of_qs``. It is only read, never mutated.

    Returns:
        A list of scalar ``torch`` tensors, one per sample.

    type = 'q': the generated question ``y[i]`` is answered by the model, the
            answer is added to the conversation extracted from ``X[i]``, and the
            model is asked for a guess. A correct guess earns ``r[no_of_qs]``;
            a wrong one earns -7.5 when ``no_of_qs >= 19`` and 0 otherwise.
    type = 'g': ``y[i]`` is the guess itself. A correct guess earns
            ``r[no_of_qs]``; a wrong one earns -7.5 when ``no_of_qs == 20`` and
            0 otherwise.

    Reward policy:
        r = a(b^x)+c, values used are a = -2.5, b = 1.1, c = 18
    """

    rewards = []

    for i in range(len(X)):
        # Decode into a local variable instead of overwriting the caller's list.
        meta = X_key_categories[i]
        if isinstance(meta, str):
            meta = json.loads(meta)
        keywords, category, no_of_qs, sample_type = meta

        if sample_type=='g':
            rewards.append(_score_guess(y[i], keywords, no_of_qs, no_of_qs==20, r))
            continue

        # 'q' sample: answer the question, then ask for a guess.
        answer = generate_policy_response(_answerer_prompt(keywords, category, y[i]))
        prompt = _guess_prompt(_previous_turns(X[i]), y[i], answer)
        guess_word = generate_policy_response(prompt)
        rewards.append(_score_guess(guess_word, keywords, no_of_qs, no_of_qs>=19, r))

    return rewards


In [None]:
class DataGenerator(Dataset):
    """Torch dataset of freshly generated 20-Questions prompts for PPO.

    Each record carries the prompt text (``query``), its JSON metadata
    (``metadata``) and ``input_ids`` laid out as
    ``[BOS] + padded prompt tokens + assistant-header tokens``.
    """

    def __init__(self, num_records=512):
        """Generate ``num_records`` prompts with ``data_generator`` and tokenize them."""
        super().__init__()
        queries, metadata = data_generator(num_records)

        prompt_ids = tokenizer(queries, return_tensors="pt", padding=True, truncation=True)['input_ids']
        n_rows = prompt_ids.shape[0]

        # One BOS id per row, to be placed in front of every prompt.
        bos_column = torch.full((n_rows, 1), tokenizer.bos_token_id, dtype=torch.int64)

        # Assistant header ids (the first token returned by the tokenizer is dropped),
        # repeated for every row so they can be appended to each prompt.
        header_ids = tokenizer("<|start_header_id|>assistant<|end_header_id|>\n\n", return_tensors="pt")['input_ids'].squeeze()[1:]
        header_block = header_ids.unsqueeze(0).repeat(n_rows, 1)

        records = datasets.Dataset.from_dict({
            'query': queries,
            'metadata': metadata,
            'input_ids': torch.concat((bos_column, prompt_ids, header_block), dim=1),
        })
        records.set_format(type="torch")

        self.dataset = records

    def __len__(self):
        """Number of generated records."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the record at ``index`` from the underlying dataset."""
        return self.dataset[index]

In [None]:
# PPO Trainer
#
# Every epoch rolls out fresh games with the current policy (DataGenerator),
# samples a response for each prompt, scores it with reward() and runs one PPO
# optimisation step per batch.

config = PPOConfig(
    model_name=model_id,
    learning_rate=1.41e-5,
    batch_size=8,
    mini_batch_size=1,
    optimize_device_cache=True,
)

# Pure sampling: no top-k / nucleus truncation, no minimum length.
generation_kwargs = {
    "min_length": -1,
    "max_new_tokens": 30,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True
}
epochs = 10
stats = []

for i in tqdm(range(epochs), "epoch: "):
    epoch_reward = 0
    # New prompts are generated with the policy as it stands at this epoch.
    dataset = DataGenerator(100)

    ppo_trainer = PPOTrainer(
        model=model,
        config=config,
        dataset=dataset,
        tokenizer=tokenizer
    )
    # Kept for later cells; tensors in this cell are moved to "cuda" explicitly.
    device = ppo_trainer.accelerator.device
    if ppo_trainer.accelerator.num_processes == 1:
        device = 0 if torch.cuda.is_available() else 'cpu'

    for batch in tqdm(ppo_trainer.dataloader):

        # One 1-D query tensor per sample, as expected by PPOTrainer.
        query_tensors = [row.to("cuda") for row in batch['input_ids']]
        # All queries in the batch share this (padded) length.
        start_gen = batch['input_ids'].shape[1]

        #### Get response from the policy
        # generate() returns the prompt followed by the continuation. PPO must
        # only see the continuation as the response, otherwise step() would
        # score "query + query + response".
        response_tensors = []
        batch['response'] = []
        for query in query_tensors:
            full_output = ppo_trainer.generate(query, **generation_kwargs).squeeze()
            out_ids = full_output[start_gen:]
            response_tensors.append(out_ids)

            # Decode up to (not including) the first end-of-turn token.
            if id_eot in out_ids:
                stop = out_ids.tolist().index(id_eot)
                out = tokenizer.decode(out_ids[:stop])
            else:
                out = tokenizer.decode(out_ids)
            batch["response"].append(out)

        #### Compute reward score
        rewards = reward(batch['query'], batch['response'], batch['metadata'])

        #### Run PPO step
        st = ppo_trainer.step(query_tensors, response_tensors, rewards)
        stats.append({"epoch": i,
                      "stats": st,
                      "rewards": rewards,
                      "batch_query": batch['query'],
                      "batch_response": batch['response']})
        epoch_reward += sum(rewards)
    print(epoch_reward)

#### Save model
ppo_trainer.save_pretrained("my_ppo_model_v2")
# model.save_pretrained("my_ppo_model_v2_model")
