<a href="https://colab.research.google.com/github/tttequila/Kaggle_20Q/blob/main/7B_mul_CoT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Configure your Kaggle token first; see the [**Configure your API key**](https://ai.google.dev/gemma/docs/setup) section for details.

### set up env

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
%%bash
mkdir -p ~/.kaggle
# change the source path below to wherever your kaggle.json is stored
cp /content/drive/MyDrive/kaggle.json ~/.kaggle/
chmod 600 ~/.kaggle/kaggle.json

In [3]:
%%bash
pip install -q -U torch immutabledict sentencepiece



In [4]:
!git clone https://github.com/google/gemma_pytorch.git

Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 231, done.
remote: Counting objects: 100% (115/115), done.
remote: Compressing objects: 100% (63/63), done.
remote: Total 231 (delta 83), reused 52 (delta 52), pack-reused 116
Receiving objects: 100% (231/231), 2.17 MiB | 21.59 MiB/s, done.
Resolving deltas: 100% (132/132), done.


### set up Gemma lib

In [5]:
# log in to Kaggle (requires kaggle.json to be stored in your Google Drive)
import kagglehub
kagglehub.login()

import sys
sys.path.append("gemma_pytorch/gemma")
sys.path.append("gemma_pytorch")
import contextlib, torch



In [6]:
from gemma.config import get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM, GemmaModel

In [7]:
# Choose variant and machine type
VARIANT = '7b-it-quant'
MACHINE_TYPE = 'cuda'

In [8]:
import torch
import gemma
import itertools
from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union
import os

In [9]:
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
  """Temporarily sets the default torch dtype to the given dtype."""
  torch.set_default_dtype(dtype)
  try:
    yield
  finally:
    # restore the default even if the body raises
    torch.set_default_dtype(torch.float)


def interleave_unequal(x, y):
    '''
        Interleave two lists of unequal length.
    '''
    return [
        item for pair in itertools.zip_longest(x, y) for item in pair if item is not None
    ]

### Building Agents

<details>
  <summary> model </summary>
  
  - `self.forward()`: getting next token and corresponding logits
  - `self.generate()`: see the summary below. It may need to be rewritten if we want the cumulative logits for the whole sentence  


</details>
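
The note above about cumulative logits can be illustrated with a minimal sketch. It assumes we already have the raw logits of every decoding step and the token chosen at each step; it does not call the Gemma library, and `log_softmax`, `sequence_log_prob` and the toy values are hypothetical names and data used only for illustration.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over the logits of one decoding step."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def sequence_log_prob(step_logits: Sequence[Sequence[float]],
                      token_ids: Sequence[int]) -> float:
    """Sum the log-probability of the chosen token over all decoding steps."""
    if len(step_logits) != len(token_ids):
        raise ValueError("need exactly one chosen token per decoding step")
    return sum(log_softmax(logits)[tok]
               for logits, tok in zip(step_logits, token_ids))


# Toy data (hypothetical): a vocabulary of 3 tokens and 2 decoding steps.
toy_logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
toy_tokens = [0, 2]
score = sequence_log_prob(toy_logits, toy_tokens)
print(f"cumulative log-probability: {score:.4f}")
```

Summing log-probabilities rather than multiplying probabilities keeps the score numerically stable for long sentences.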



<details>
  <summary> model.generate() </summary>
  
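  - **prompts** | `Union[str, Sequence[str]]`: the prompt, or a batch of prompts
  - **device** | `Any`: the device to run generation on
  - **output_len** | `int`: maximum number of tokens to generate
  - **temperature** | `Union[float, None]`: sampling temperature; higher values make the response more varied
  - **top_p** | `float`: nucleus sampling threshold; sample only from the smallest set of most likely tokens whose cumulative probability reaches `top_p`
  - **top_k** | `int`: sample only from the `top_k` most likely tokens

  For more on temperature, top_p and top_k, see this [link](https://blog.csdn.net/REfusing/article/details/137866583).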

</details>
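
To make the three sampling parameters concrete, here is a minimal pure-Python sketch of how temperature, top-k and top-p are commonly combined. It is not the Gemma library's sampler: `filter_distribution` is a hypothetical helper, treating `temperature=None` as greedy decoding is an assumption of this sketch, and real implementations differ in how they handle the top-p boundary.

```python
import math
from typing import Dict, List, Optional


def filter_distribution(logits: List[float],
                        temperature: Optional[float] = 1.0,
                        top_k: int = 0,
                        top_p: float = 1.0) -> Dict[int, float]:
    """Return the renormalised probabilities of the tokens that survive filtering."""
    if temperature is None or temperature == 0:
        # Assumption: no temperature means greedy decoding (all mass on the argmax).
        best = max(range(len(logits)), key=lambda i: logits[i])
        return {best: 1.0}

    # Temperature scaling followed by a numerically stable softmax.
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Rank tokens from most to least likely and apply top-k (0 disables it).
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k > 0:
        ranked = ranked[:top_k]

    # Top-p: keep tokens until their cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for i in ranked:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}


toy_logits = [2.0, 1.0, 0.0, -1.0]  # hypothetical logits for a 4-token vocabulary
print(filter_distribution(toy_logits, top_k=2))
print(filter_distribution(toy_logits, top_p=0.5))
```

Lower `top_k` or `top_p` values shrink the candidate set, while a higher temperature flattens the distribution before the filters are applied.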

#### Define Formatter


In [10]:
from typing import Iterable
import itertools

class PromptFormatter:

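    '''
        Formatter that builds the prompt text for the model.
        The general idea is to accumulate alternating user/model turns,
        each wrapped in Gemma's <start_of_turn>/<end_of_turn> tokens.
    '''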

    _start_token = '<start_of_turn>'
    _end_token = '<end_of_turn>'

    def __init__(self, system_prompt: str = None, few_shot_examples: Iterable = None):
        self._system_prompt = system_prompt
        self._few_shot_examples = few_shot_examples
        self._template_user = f"{self._start_token}user\n{{}}{self._end_token}\n"
        self._template_model = f"{self._start_token}model\n{{}}{self._end_token}\n"
        # self._all_prompt = ''
        self.reset()

    def __repr__(self):
        return self._all_prompt

    def reset(self):
        self._all_prompt = ''
        # if system prompt is provided, add it to the prompt
        if self._system_prompt:
            self._all_prompt += self._template_user.format(self._system_prompt)
        # same for few shot examples
        if self._few_shot_examples:
            self.add_rounds(self._few_shot_examples, start_agent='user')

    def add_user_round(self, user_prompt: str):
        # add user round to the prompt
        self._all_prompt += self._template_user.format(user_prompt)

    def add_agent_round(self, model_response: str):
        self._all_prompt += self._template_model.format(model_response)

    def add_rounds(self, rounds: Iterable, start_agent: str):
        '''
            Apply a sequence of rounds to the formatter, starting with the specified agent.
        '''
        # self.add_agent_round and self.add_user_round are the methods defined above
        formatters = [self.add_agent_round, self.add_user_round] if start_agent == 'model' else [self.add_user_round, self.add_agent_round]
        formatters = itertools.cycle(formatters)
        for fmt, round_text in zip(formatters, rounds):
            fmt(round_text)
        return self

    # def add_end_token(self):
    #     self._all_prompt += f"{self._end_token}\n"

    def add_new_round(self, player: str, prompt:str = None, end_token: bool = False):
        self._all_prompt += f"{self._start_token}{player}\n"
        if prompt:
            self._all_prompt += f'{prompt}'
        if end_token:
            # self.add_end_token()
            self._all_prompt += f"{self._end_token}\n"

    def formate_MCQA(self):
        raise NotImplementedError

sys_prompt = 'You are a highly knowledgeable naturalist with extensive knowledge about objects, places, and people around the world, almost like a search engine. You also possess strong reasoning abilities, allowing you to deduce the answer to your query using existing knowledge and information. \
You need to utilize your knowledge and reasoning skills to play a game of 20 questions. \
In each round of the game, choose the binary attribute most likely to aid your deduction, such as "a cross on the flag." \
Using this attribute, ask a question in English. Generally, the most helpful attribute provides the maximum information gain, meaning it significantly reduces the entropy of your guess, regardless of whether the answer is **yes** or **no**. For example, if you believe asking about the color is most helpful for your deduction, you should ask, "Is it a thing with a red color?" \
And after your thorough thinking, you need to denote the key binary attribute with double stars. Here are few examples of how to use Chain of thought:'

few_shot_examples = [

# 'Now give me an example of thinking with the Chain of Thought',

# 'Given Information: ["is a country", "is in Europe", "its flag is with red color"] \
# Considerations: \
# 1. Geographical Location: European countries can be divided into regions such as Northern Europe, Southern Europe, Western Europe, and Eastern Europe. This division can effectively narrow down the options. \
# 2. Population and Area: Considering population and area can also eliminate some countries. \
# 3. Other Colors on the Flag: We know the flag has red; we can further narrow down by asking about other colors. \
# 4. Language: European countries have various languages; asking about the official language can narrow it further. \
# Reasoning: Southern Europe includes a few countries such as Spain, Italy, Portugal, Greece, etc. Determining if it is in Southern Europe can immediately eliminate about three-quarters of European countries. \
# Compared to questions about flag colors or language, geographical location questions are more evenly distributed, and a **yes** or **no** answer will significantly narrow down the range. \
# Possible Next Steps: \
# If the answer is **yes**, we can further inquire about other notable characteristics of these countries, such as language or history. \
# If the answer is **no**, we can turn our attention to Western, Eastern, or Northern Europe and continue narrowing down the options.\
# **Question:** "Is this country in Southern Europe?"',

'Now let us play a game of 20 questions. You will be the questioner, and I will be the answerer.',

'It is the first round, so I will randomly choose an attribute from ["country", "city", "landmark"] as the first attribute. Here I choose country as the first attribute.\
**Question:** "Is it a country?"',

'**no**',

'Given information: ["It is not a country", "It is either a landmark or a city"] \
Considerations: \
1. Type of Entity: Determining whether it is a city or a landmark can greatly narrow down our options. \
2. Geographical Location: If we can determine the continent, we can significantly reduce the number of possible cities or landmarks. \
3. Cultural/Historical Significance: Asking about cultural or historical significance can also be helpful, but it might be too broad at this stage. \
Reasoning: To effectively narrow it down, the most efficient next step is to determine if it is a city or a landmark. By determining if it is a city, we can immediately focus on known cities globally, which are fewer and more well-documented compared to all possible landmarks. \
If it is not a city, we can then focus on landmarks, which might require more specific questions but will also significantly narrow our scope. \
Possible Next Steps: \
If the answer is **yes**, we can then ask about the continent or specific features of the city. \
If the answer is **no**, we can ask about the type of landmark or its location. \
**Question:** "Is it a city?"',

'**no**',

# 'Given information: ["It is not a country", "It is either a landmark or a city", "It is not a city"] \
# Considerations: \
# Alright, so it is not a city. Now we know it must be a landmark. Given this, our goal is to identify the type or location of the landmark. \
# 1. Type of Landmark: There are various types such as natural landmarks (mountains, rivers), man-made landmarks (statues, buildings), or historical sites. \
# 2. Geographical Location: Identifying the continent or region can greatly narrow down the possibilities. \
# To get a high information gain, let us narrow down the type of landmark first. \
# Reasoning: Landmarks can be broadly categorized into man-made or natural. This binary distinction will help us immediately focus on a more specific set of possibilities. \
# Man-made landmarks include famous buildings, statues, monuments, etc., while natural landmarks include mountains, rivers, and natural formations. \
# Possible Next Steps: \
# If the answer is **yes**, we can then ask about the purpose or era of the man-made structure. \
# If the answer is **no**, we can focus on identifying the type of natural landmark or its location. \
# **Question:** "Is it a man-made landmark?"',

# '**no**',

'Given information: ["It is not a country", "It is either a landmark or a city", "It is not a city", "It is not a man-made landmark"] \
Considerations: \
1. Type of Natural Landmark: Common natural landmarks include mountains, rivers, lakes, forests, deserts, etc. \
2. Geographical Location: Identifying the continent or region can still be very helpful. \
To further narrow down, let us focus on the type of natural landmark. \
Reasoning: Rivers are one of the most well-known types of natural landmarks. \
Possible Next Steps: \
If the answer is **yes**, we can then focus on identifying which river by asking about its specific features or location. \
If the answer is **no**, we can eliminate a significant category and focus on other types of natural landmarks like mountains, lakes, deserts, or forests. \
**Question:** "Is it a river?"',

'**yes**',

# 'Given information: ["It is not a country", "It is either a landmark or a city", "It is not a city", "It is a river"] \
# Considerations: \
# Great, so it is a river. Now, we need to narrow down which river it could be. To do this effectively, identifying the continent or region will significantly help. \
# 1. Geographical Location: Identifying the continent or region can narrow down the list of possible rivers considerably. \
# 2. Length or Size: Asking about whether it is among the longest rivers might also help. \
# 3. Famous Rivers: There are some rivers that are globally well-known due to their significance. \
# To achieve a high information gain, let us focus on the geographical location. \
# Reasoning: Asia has some of the world\'s most significant and longest rivers, such as the Yangtze, Ganges, and Mekong. \
# Possible Next Steps: \
# If the answer is **yes**, we can focus on Asian rivers specifically. \
# If the answer is **no**, we can eliminate Asia and consider other continents like Africa, Europe, South America, or North America. \
# **Question:** "Is this river located in Asia?"',

# '**yes**',

'Given information: ["It is not a country", "It is either a landmark or a city", "It is not a city", "It is a river", "It is located in Asia"] \
Excellent, the river is located in Asia. This narrows down our options significantly. Now, we need to focus on identifying which river it could be. \
Considerations: \
1. Famous Asian Rivers: Asia has several famous rivers, such as the Yangtze, Ganges, Mekong, Indus, and Yellow River. \
2. Geographical Features: We could ask about the river\'s length, the countries it flows through, or any significant historical or cultural associations. \
3. To narrow it down further, let us focus on whether it is one of the most well-known rivers. \
Reasoning: The Yangtze, Yellow River, and Mekong are among the longest and most significant rivers in Asia. \
Possible Next Steps: \
If the answer is **yes**, we can focus on distinguishing between these three. \
If the answer is **no**, we can focus on other significant rivers like the Ganges, Indus, or Brahmaputra. \
**Question:** "Is this river one of the three longest rivers in Asia (Yangtze, Yellow, or Mekong)?"',

'**yes**',

# 'Given information: ["It is not a country", "It is either a landmark or a city", "It is not a city", "It is a river", "It is located in Asia", "It is one of the three longest rivers in Asia"] \
# Great, we now know it is one of the three longest rivers in Asia: the Yangtze, Yellow, or Mekong. Let us narrow it down further by focusing on specific characteristics of these rivers. \
# Considerations: \
# 1. Geographical Location: We can ask about the countries the river flows through. \
# 2. Cultural/Historical Significance: Each of these rivers has distinct cultural and historical significance. \
# To narrow it down further, let us ask about the country. \
# Reasoning: Both the Yangtze and Yellow rivers flow through China, while the Mekong flows through several countries including China, but predominantly in Southeast Asia. \
# Possible Next Steps: \
# If the answer is yes, it will most likely be the Yangtze or Yellow River. \
# If the answer is no, it must be the Mekong River. \
# **Question:** "Does this river flow through China?"',

# '**yes**',

'Given information: ["It is not a country", "It is either a landmark or a city", "It is not a city", "It is a river", "It is located in Asia", "It is one of the three longest rivers in Asia", "It flows through China"] \
Great, so the river flows through China. This means it is either the Yangtze River or the Yellow River. To determine which one it is, we should focus on distinguishing features. \
Considerations: \
1. Length: The Yangtze is the longest river in Asia and the third-longest in the world. \
2. Color/Name: The Yellow River is named for its yellowish color due to loess sediments. \
Reasoning: The Yangtze River is the longest river in Asia. \
Possible Next Steps: \
If the answer is **yes**, it is the Yangtze River. \
If the answer is **no**, it must be the Yellow River. \
**Question:** "Is this river the longest river in Asia?"',

'**yes**',

'It is the Yangtze River',

'Correct!']

print(str(PromptFormatter(sys_prompt, few_shot_examples)))

<start_of_turn>user
You are a highly knowledgeable naturalist with extensive knowledge about objects, places, and people around the world, almost like a search engine. You also possess strong reasoning abilities, allowing you to deduce the answer to your query using existing knowledge and information. You need to utilize your knowledge and reasoning skills to play a game of 20 questions. In each round of the game, choose the binary attribute most likely to aid your deduction, such as "a cross on the flag." Using this attribute, ask a question in English. Generally, the most helpful attribute provides the maximum information gain, meaning it significantly reduces the entropy of your guess, regardless of whether the answer is **yes** or **no**. For example, if you believe asking about the color is most helpful for your deduction, you should ask, "Is it a thing with a red color?" And after your thorough thinking, you need to denote the key binary attribute with double stars. Here are few 
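
The system prompt asks the model to pick the attribute with the maximum information gain. As a rough illustration of what that means, the sketch below assumes all remaining candidates are equally likely and that every question has a definite yes/no answer; under those assumptions the expected information gain of a question equals the binary entropy of its yes-fraction. `binary_entropy` and `expected_information_gain` are hypothetical helpers, not part of the notebook's agent.

```python
import math


def binary_entropy(p: float) -> float:
    """Entropy in bits of a yes/no answer that is 'yes' with probability p."""
    if p <= 0.0 or p >= 1.0:
        return 0.0
    return -(p * math.log2(p) + (1.0 - p) * math.log2(1.0 - p))


def expected_information_gain(n_candidates: int, n_yes: int) -> float:
    """Expected entropy reduction (bits) of a question over uniform candidates."""
    if n_candidates <= 0 or not 0 <= n_yes <= n_candidates:
        raise ValueError("n_yes must lie between 0 and n_candidates")
    return binary_entropy(n_yes / n_candidates)


# Hypothetical numbers: 64 remaining candidates.
even_split = expected_information_gain(64, 32)   # question splits candidates in half
narrow_split = expected_information_gain(64, 4)  # only 4 candidates would answer yes
print(f"even split: {even_split:.3f} bits, narrow split: {narrow_split:.3f} bits")
```

A question that splits the candidates evenly is worth the most, which is why the few-shot examples favour broad category questions before specific ones.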

#### Define Agent

In [11]:

class GemmaAgent:

    def __init__(self, model_variant, device='cuda:0', env="kaggle", system_prompt=None, few_shot_examples=None):
        # model initialization
        self.device = device
        self.model_variant = model_variant

        WEIGHTS_PATH = self._set_up_env(env)

        # Ensure that the tokenizer is present
        tokenizer_path = os.path.join(WEIGHTS_PATH, 'tokenizer.model')
        assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

        # Ensure that the checkpoint is present
        ckpt_path = os.path.join(WEIGHTS_PATH, f'gemma-{model_variant}.ckpt')
        assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

        # loading model configuration
        model_config = get_config_for_2b() if "2b" in model_variant else get_config_for_7b()
        model_config.quant = "quant" in model_variant
        model_config.tokenizer = os.path.join(WEIGHTS_PATH, "tokenizer.model")

        with _set_default_tensor_type(model_config.get_dtype()):
            self.model = GemmaForCausalLM(model_config)
            self.model.load_weights(ckpt_path)
            self.model = self.model.to(self.device).eval()

        # agent args
        self.formatter = PromptFormatter(system_prompt=system_prompt,
                                         few_shot_examples=few_shot_examples)
        self.round_num = 0

    def _set_up_env(self, env):

        if env == 'kaggle':
            print("Loading model in Kaggle, model weights will be searched within local directories.")

            # kaggle configuration
            KAGGLE_AGENT_PATH = "/kaggle_simulations/agent/"
            if os.path.exists(KAGGLE_AGENT_PATH):
                WEIGHTS_PATH = os.path.join(KAGGLE_AGENT_PATH, f"gemma/pytorch/{self.model_variant}/2")
            else:
                WEIGHTS_PATH = f"/kaggle/input/gemma/pytorch/{self.model_variant}/2"

        elif env == 'colab':
            print("Loading model in Colab, starting from downloading the model weights.")

            WEIGHTS_PATH = kagglehub.model_download(f'google/gemma/pyTorch/{self.model_variant}')
        else:
            raise ValueError("Argument 'env' should be in ['kaggle', 'colab']")

        return WEIGHTS_PATH

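    def response(self, obs, *args, **kwargs):
        '''
            The game env will call this function.
        '''

        # formatting prompt: use the conversation accumulated in the formatter
        # (folding `obs` into the formatter is not implemented yet)
        prompt = str(self.formatter)

        # getting response from LLM
        response = self.model.generate(prompt, device=self.device, output_len=200)
        return response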

### format your prompts

In [12]:
# # Chat templates
# USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
# MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'

# prompt = (
#     USER_CHAT_TEMPLATE.format(
#         prompt='What is a good place for travel in the US?'
#     )
# )
# print('Chat prompt:\n', prompt)

### Chat & Response

In [13]:
# del model
# del agent

In [14]:
# Choose variant and machine type
VARIANT = '7b-it-quant'
MACHINE_TYPE = 'cuda'

agent = GemmaAgent(model_variant=VARIANT, device=MACHINE_TYPE, env='colab',
                   system_prompt=sys_prompt, few_shot_examples=few_shot_examples)

Loading model in Colab, starting from downloading the model weights.
Attaching model 'google/gemma/pyTorch/7b-it-quant' to your Colab notebook...


In [15]:
import time
model = agent.model
agent.formatter.reset()
agent.formatter.add_new_round('user', 'Remember how to use CoT as above. Now let us play a new game from the beginning. You are the guesser; remember to mark your question with **Question:** at the end', True)
agent.formatter.add_new_round('model','Is it a country?', True)
agent.formatter.add_new_round('user','**no**', True)
agent.formatter.add_new_round('model',None, False)
print(str(agent.formatter))
start = time.time()
response = model.generate(str(agent.formatter), device=agent.device, output_len=200)
print(f"[response in {(time.time()-start):.2f}s]\n{response}")

<start_of_turn>user
You are a highly knowledgeable naturalist with extensive knowledge about objects, places, and people around the world, almost like a search engine. You also possess strong reasoning abilities, allowing you to deduce the answer to your query using existing knowledge and information. You need to utilize your knowledge and reasoning skills to play a game of 20 questions. In each round of the game, choose the binary attribute most likely to aid your deduction, such as "a cross on the flag." Using this attribute, ask a question in English. Generally, the most helpful attribute provides the maximum information gain, meaning it significantly reduces the entropy of your guess, regardless of whether the answer is **yes** or **no**. For example, if you believe asking about the color is most helpful for your deduction, you should ask, "Is it a thing with a red color?" And after your thorough thinking, you need to denote the key binary attribute with double stars. Here are few 
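
Because the prompt asks the model to end its reasoning with a `**Question:**` marker, the final question can be pulled out of the chain-of-thought text with simple string handling. The helper below is a minimal sketch: `extract_question` is a hypothetical name, and it assumes the question sits on the same line as the last marker, optionally wrapped in double quotes.

```python
from typing import Optional


def extract_question(response: str) -> Optional[str]:
    """Return the text after the last '**Question:**' marker, or None if absent."""
    marker = "**Question:**"
    idx = response.rfind(marker)
    if idx == -1:
        return None
    # Drop anything after the end-of-turn token, then keep only the first line.
    tail = response[idx + len(marker):].split("<end_of_turn>")[0]
    lines = tail.strip().splitlines()
    if not lines:
        return None
    question = lines[0].strip().strip('"').strip()
    return question or None


# Hypothetical model output in the style of the few-shot examples.
sample = (
    'Given information: ["It is not a country"]\n'
    'Reasoning: Determining whether it is a city narrows the options.\n'
    '**Question:** "Is it a city?"<end_of_turn>'
)
print(extract_question(sample))
```

Returning `None` when the marker is missing lets the caller decide whether to retry generation or fall back to a default question.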

### applying formatter

In [16]:
# dummy_gemma = lambda x: "A4"

# rounds = ['Q1', 'A1', 'Q2', 'A2', 'Q3', 'A3']

# system_prompt = "You are an AI assistant designed to play the 20 Questions game. In this game, the Answerer thinks of a keyword and responds to yes-or-no questions by the Questioner. The keyword is a specific person, place, or thing."

# few_shot_examples = [
#     "Let's play 20 Questions. You are playing the role of the Questioner. Please ask your first question.",
#     "Is it a person?", "**no**",
#     "Is it a place?", "**yes**",
#     "Is it a country?", "**yes** Now guess the keyword.",
#     "**France**", "Correct!",
# ]

# dummy_formatter = PromptFormatter(system_prompt=system_prompt, few_shot_examples=few_shot_examples)

# print(str(dummy_formatter), end='\n\n')
# dummy_formatter.add_rounds(rounds, start_agent='user')
# print(str(dummy_formatter), end='\n\n')
# add_end_token() is commented out in PromptFormatter, so close the turn via end_token=True
# dummy_formatter.add_new_round('user', prompt='Q4', end_token=True)
# print(str(dummy_formatter), end='\n\n')
# dummy_formatter.add_new_round('model', end_token=False)
# print(str(dummy_formatter), end='\n\n')
# dummy_formatter.add_new_round('model', prompt=dummy_gemma(0), end_token=True)
# print(str(dummy_formatter), end='\n\n')