In [None]:
# import pandas as pd
# import json

# # load the keywords
# from importlib.machinery import SourceFileLoader
# keywords = SourceFileLoader("keywords",'/kaggle/input/llm-20-questions/llm_20_questions/keywords.py').load_module()
# df = json.loads(keywords.KEYWORDS_JSON)

# words = []
# for cat in df:
#     print(cat['category'], len(cat['words']))
#     for w in cat['words']:
#         words.append([cat['category'], w["keyword"], w["alts"], "", "", 0.0, 0.0])

# # convert the keywords into the dataframe
# df = pd.DataFrame(words, columns=['Category','Word','Alternatives', "Cat1", "Cat2", "Lat", "Lon"])
# df.head()

In [None]:
%%bash
# Start from an empty working directory so nothing stale is left from earlier runs.
cd /kaggle/working
rm -rf ./*
# Install the listed packages into submission/lib (pip's -t target directory),
# which is the directory main.py puts on sys.path.
pip install -q -U -t /kaggle/working/submission/lib immutabledict sentencepiece
# Fetch the Gemma PyTorch sources; stdout is discarded.
git clone https://github.com/google/gemma_pytorch.git > /dev/null
# Move only the `gemma` package into submission/lib so it is importable as `gemma`.
mkdir /kaggle/working/submission/lib/gemma/
mv /kaggle/working/gemma_pytorch/gemma/* /kaggle/working/submission/lib/gemma/

In [None]:
%%writefile submission/main.py
# # Setup
import os
import sys

# Directory checked to decide whether the bundled agent files are present.
KAGGLE_AGENT_PATH = "/kaggle_simulations/agent/"
# Put the vendored `lib` directory first on sys.path: the bundled copy when
# KAGGLE_AGENT_PATH exists, otherwise the notebook's working copy.
sys.path.insert(
    0,
    os.path.join(KAGGLE_AGENT_PATH, 'lib')
    if os.path.exists(KAGGLE_AGENT_PATH)
    else "/kaggle/working/submission/lib",
)

import contextlib
import os
import sys
from pathlib import Path

import torch
from gemma.config import get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM

# Model weights directory: inside the agent bundle when KAGGLE_AGENT_PATH exists,
# otherwise the /kaggle/input copy.
WEIGHTS_PATH = (
    os.path.join(KAGGLE_AGENT_PATH, "gemma/pytorch/7b-it-quant/2")
    if os.path.exists(KAGGLE_AGENT_PATH)
    else "/kaggle/input/gemma/pytorch/7b-it-quant/2"
)


import itertools
from typing import Iterable


class GemmaFormatter:
    """Builds a conversation string using Gemma's turn markers.

    A complete turn is rendered as the start token, the role name, a newline,
    the turn text, the end token and a trailing newline. Every method that
    modifies the conversation returns ``self`` so calls can be chained.
    """

    _start_token = '<start_of_turn>'
    _end_token = '<end_of_turn>'

    # Initialization
    def __init__(self, system_prompt: str = None, few_shot_examples: Iterable = None):
        """Store the optional preamble and build the initial conversation state.

        Args:
            system_prompt: Text added as the first user turn on every reset, if given.
            few_shot_examples: Turn texts applied after the system prompt on every
                reset, alternating user/model and starting with user.
        """
        self._system_prompt = system_prompt
        self._few_shot_examples = few_shot_examples
        # Templates with a single `{}` placeholder for the turn text.
        self._turn_user = self._start_token + "user\n{}" + self._end_token + "\n"
        self._turn_model = self._start_token + "model\n{}" + self._end_token + "\n"
        self.reset()

    def __repr__(self):
        """Return the conversation accumulated so far."""
        return self._state

    def user(self, prompt):
        """Append a complete user turn containing ``prompt``."""
        self._state += self._turn_user.format(prompt)
        return self

    def model(self, prompt):
        """Append a complete model turn containing ``prompt``."""
        self._state += self._turn_model.format(prompt)
        return self

    def start_user_turn(self):
        """Open a user turn without closing it."""
        self._state += self._start_token + "user\n"
        return self

    def start_model_turn(self):
        """Open a model turn without closing it."""
        self._state += self._start_token + "model\n"
        return self

    def end_turn(self):
        """Close whichever turn is currently open."""
        self._state += self._end_token + "\n"
        return self

    def reset(self):
        """Clear the conversation, then re-apply the system prompt and few-shot turns."""
        self._state = ""

        # The system prompt, when present, is rendered as an ordinary user turn.
        if self._system_prompt is not None:
            self.user(self._system_prompt)

        # Few-shot examples always begin with a user turn.
        if self._few_shot_examples is not None:
            self.apply_turns(self._few_shot_examples, start_agent='user')
        return self

    def apply_turns(self, turns: Iterable, start_agent: str):
        """Append ``turns`` as alternating complete turns.

        Args:
            turns: Turn texts in conversation order.
            start_agent: ``'model'`` to begin with a model turn; any other value
                begins with a user turn.
        """
        if start_agent == 'model':
            first, second = self.model, self.user
        else:
            first, second = self.user, self.model
        # Even positions go to the starting role, odd positions to the other one.
        for position, text in enumerate(turns):
            add_turn = first if position % 2 == 0 else second
            add_turn(text)
        return self


import re

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Set the default torch dtype to the given dtype."""
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(torch.float)

# `t` is referenced by the annotations below; without this import the class
# body raises NameError as soon as the module is loaded.
import typing as t

from pydantic import BaseModel


class Observation(BaseModel):
    """Typed view of a single 20 Questions observation."""

    step: int
    role: t.Literal["guesser", "answerer"]
    turnType: t.Literal["ask", "answer", "guess"]
    keyword: str
    category: str
    questions: list[str]
    answers: list[str]
    guesses: list[str]

    @property
    def empty(self) -> bool:
        """True when no questions, answers or guesses have been recorded."""
        return all(len(t) == 0 for t in [self.questions, self.answers, self.guesses])

    def get_history(self) -> t.Iterator[tuple[str, str, str]]:
        """Yield (question, answer, guess) triples, padding short lists with "[none]"."""
        return itertools.zip_longest(self.questions, self.answers, self.guesses, fillvalue="[none]")

    def get_history_text(self, *, skip_guesses: bool = False, skip_answer: bool = False, skip_question: bool = False) -> str:
        """Render the game history as text, or "none yet." when it is empty.

        Args:
            skip_guesses: Omit the previous-guess text from each round.
            skip_answer: Omit the answer text from each round.
            skip_question: Omit the question text from each round.

        Returns:
            One formatted entry per round, joined by newlines.
        """
        if not self.empty:
            # The multi-line f-string is kept verbatim so the rendered text
            # (including its leading whitespace) is unchanged.
            history = "\n".join(
            f"""{'**Question:** ' + question if not skip_question else ''} -> {'**Answer:** ' + answer if not skip_answer else ''}
            {'**Previous Guess:** ' + guess if not skip_guesses else ''}
            """
            for question, answer, guess in self.get_history()
            )
            return history
        return "none yet."


class GemmaAgent:
    """Base class for the Gemma-backed 20 Questions agents.

    Subclasses implement ``_start_session`` (build the prompt in
    ``self.formatter`` from the observation) and ``_parse_response`` (turn the
    raw model output into the value returned to the caller).
    """

    def __init__(self, variant='7b-it-quant', device='cuda:0', system_prompt=None, few_shot_examples=None):
        """Create the prompt formatter and load the model weights.

        Args:
            variant: Model variant name. A name containing "2b" selects the 2b
                config (otherwise 7b); a name containing "quant" enables
                ``model_config.quant``. It is also used in the checkpoint file name.
            device: Device string passed to ``torch.device``.
            system_prompt: Forwarded to ``GemmaFormatter``.
            few_shot_examples: Forwarded to ``GemmaFormatter``.
        """
        self._variant = variant
        self._device = torch.device(device)
        self.formatter = GemmaFormatter(system_prompt=system_prompt, few_shot_examples=few_shot_examples)

        print("Initializing model")

        # Model Config.
        model_config = get_config_for_2b() if "2b" in variant else get_config_for_7b()
        model_config.tokenizer = os.path.join(WEIGHTS_PATH, "tokenizer.model")
        model_config.quant = "quant" in variant

        # Model.
        with _set_default_tensor_type(model_config.get_dtype()):
            # TODO: Change to different models to see the performance
            model = GemmaForCausalLM(model_config)
            ckpt_path = os.path.join(WEIGHTS_PATH, f'gemma-{variant}.ckpt')
            model.load_weights(ckpt_path)
            self.model = model.to(self._device).eval()

    def __call__(self, obs, *args):
        """Build the prompt for ``obs``, query the model and return the parsed reply."""
        self._start_session(obs)
        prompt = str(self.formatter)
        response = self._call_llm(prompt)
        response = self._parse_response(response, obs)
        print(f"{response=}")
        return response

    def _start_session(self, obs: dict):
        """Populate ``self.formatter`` for ``obs``; implemented by subclasses."""
        raise NotImplementedError

    def _call_llm(self, prompt, max_new_tokens=32, **sampler_kwargs):
        """Run ``self.model.generate`` on ``prompt`` and return its result.

        Args:
            prompt: Full prompt text.
            max_new_tokens: Passed to ``generate`` as ``output_len``.
            **sampler_kwargs: Extra keyword arguments forwarded to ``generate``.
                When none are given, the defaults below are used.
        """
        # TODO: change the parameters of the LLM
        # `**sampler_kwargs` is always a dict (never None), so test for
        # emptiness; otherwise these defaults would never be applied.
        if not sampler_kwargs:
            sampler_kwargs = {
                'temperature': 0.01,
                'top_p': 0.1,
                'top_k': 1,
            }
        response = self.model.generate(
            prompt,
            device=self._device,
            output_len=max_new_tokens,
            **sampler_kwargs,
        )
        return response

    def _parse_keyword(self, response: str):
        """Return the first ``**...**``-wrapped span of ``response`` in lowercase.

        Returns an empty string when no such span is present.
        """
        match = re.search(r"(?<=\*\*)([^*]+)(?=\*\*)", response)
        if match is None:
            keyword = ''
        else:
            keyword = match.group().lower()
        return keyword

    def _parse_response(self, response: str, obs: dict):
        """Convert raw model output into the agent's reply; implemented by subclasses."""
        raise NotImplementedError


def interleave_unequal(x, y):
    """Interleave ``x`` and ``y`` element by element, tolerating unequal lengths.

    Once the shorter input runs out, the remaining items of the longer one are
    appended in order. ``None`` items are dropped from the result.
    """
    merged = []
    for left, right in itertools.zip_longest(x, y):
        for item in (left, right):
            if item is not None:
                merged.append(item)
    return merged

"""TODO: Adjust the questioner"""
class GemmaQuestionerAgent(GemmaAgent):
    """Agent that asks yes-or-no questions and guesses the keyword."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _start_session(self, obs):
        """Rebuild the prompt from the question/answer history in ``obs``.

        Past questions are replayed as model turns and past answers as user
        turns, followed by an instruction that depends on ``obs.turnType``.
        """
        self.formatter.reset()
        self.formatter.user("Let's play 20 Questions. You are playing the role of the Questioner.")
        turns = interleave_unequal(obs.questions, obs.answers)
        self.formatter.apply_turns(turns, start_agent='model')
        if obs.turnType == 'ask':
            if len(turns) == 0:
                self.formatter.user("Please ask a yes-or-no question.")
            else:
                hints = self.turns_to_hints(turns)
                # Adjacent literals are used instead of a backslash continuation
                # inside the string, which embedded the source indentation in
                # the prompt.
                self.formatter.user(
                    "Please ask a yes-or-no question based on the "
                    f"hints of the keyword in this list: {hints}"
                )
        elif obs.turnType == 'guess':
            self.formatter.user("Now guess the keyword. Surround your guess with double asterisks.")
        self.formatter.start_model_turn()

    def _parse_response(self, response: str, obs: dict):
        """Extract the question or the guess from the raw model output.

        Returns:
            For 'ask': the first text ending in '?' (asterisks removed), or
            "Is it a person?" if none is found. For 'guess': the keyword
            parsed by ``_parse_keyword``.

        Raises:
            ValueError: If ``obs.turnType`` is neither 'ask' nor 'guess'.
        """
        if obs.turnType == 'ask':
            # Raw string: "\?" is not a valid escape in a normal string literal.
            match = re.search(r".+?\?", response.replace('*', ''))
            if match is None:
                question = "Is it a person?"
            else:
                question = match.group()
            return question
        elif obs.turnType == 'guess':
            guess = self._parse_keyword(response)
            return guess
        else:
            raise ValueError("Unknown turn type:", obs.turnType)

    def turns_to_hints(self, turns):
        """Turn alternating question/answer turns into short statements.

        "is it X?" becomes "it is X" / "it is not X" and "does X?" becomes
        "it does X" / "it does not X", depending on whether the answer is
        "yes". Prefix and answer matching ignore case and surrounding
        whitespace. Other question shapes are reported as
        "Unrecognized question format: ...".

        Args:
            turns: Flat list ordered question, answer, question, answer, ...
                A trailing question without an answer is ignored.

        Returns:
            One statement per complete question/answer pair.
        """
        hints = []
        # Stop at len - 1 so an unanswered trailing question cannot index past the end.
        for i in range(0, len(turns) - 1, 2):
            question = turns[i]
            answer = turns[i + 1]

            # Drop surrounding whitespace and a single trailing "?" before slicing.
            text = question.strip()
            if text.endswith('?'):
                text = text[:-1]
            lowered = text.lower()
            is_yes = answer.strip().lower() == 'yes'

            if lowered.startswith('is it '):
                item = text[6:]  # the part after "is it "
                if is_yes:
                    transformed_statement = f"it is {item}"
                else:
                    transformed_statement = f"it is not {item}"
            elif lowered.startswith('does '):
                item = text[5:]  # the part after "does "
                if is_yes:
                    transformed_statement = f"it does {item}"
                else:
                    transformed_statement = f"it does not {item}"
            else:
                transformed_statement = f"Unrecognized question format: {question}"

            hints.append(transformed_statement)

        return hints


class GemmaAnswererAgent(GemmaAgent):
    """Agent that answers the questioner's yes-or-no questions about the keyword."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _start_session(self, obs):
        """Rebuild the prompt: role and keyword, replayed history, answer instruction."""
        fmt = self.formatter
        fmt.reset()
        fmt.user(f"Let's play 20 Questions. You are playing the role of the Answerer. The keyword is {obs.keyword} in the category {obs.category}.")
        # Questions are replayed as user turns and earlier answers as model turns.
        history = interleave_unequal(obs.questions, obs.answers)
        fmt.apply_turns(history, start_agent='user')
        fmt.user(f"The question is about the keyword {obs.keyword} in the category {obs.category}. Give yes-or-no answer and surround your answer with double asterisks, like **yes** or **no**.")
        fmt.start_model_turn()

    def _parse_response(self, response: str, obs: dict):
        """Return 'yes' if the parsed keyword contains "yes", otherwise 'no'."""
        parsed = self._parse_keyword(response)
        if 'yes' in parsed:
            return 'yes'
        return 'no'

# Agent Creation
# Prompt passed to both agents as `system_prompt`.
system_prompt = """You are an AI assistant designed to play the 20 Questions game. 
In this game, the Answerer thinks of a keyword and responds to yes-or-no questions 
by the Questioner. The keyword is a specific person, place, or thing. Your objective 
is to guess the secret keyword in as few turns as possible and win."""

# Example turns passed to both agents as `few_shot_examples`.
few_shot_examples = [
    "Let's play 20 Questions. You are playing the role of the Questioner. Please ask your first question.",
    "Is it a person?", "**no**",
    "Is it a place?", "**yes**",
    "Is it a country?", "**yes** Now guess the keyword.",
    "**France**", "Correct!",
]

# **IMPORTANT:** Define agent as a global so you only have to load
# the agent you need. Loading both will likely lead to OOM.

# Initialize agent variable
# The single agent instance held for the lifetime of the process.
agent = None


def get_agent(name: str):
    """Return the global agent, creating it on the first call.

    Args:
        name: 'questioner' or 'answerer'. Only consulted while no agent exists;
            once one has been created it is returned regardless of ``name``.

    Raises:
        AssertionError: If no agent exists and ``name`` is not a known role.
    """
    global agent

    # Only build a model when nothing has been loaded yet.
    if agent is None:
        if name == 'questioner':
            agent = GemmaQuestionerAgent(
                device='cuda:0',
                system_prompt=system_prompt,
                few_shot_examples=few_shot_examples,
            )
        elif name == 'answerer':
            agent = GemmaAnswererAgent(
                device='cuda:0',
                system_prompt=system_prompt,
                few_shot_examples=few_shot_examples,
            )

    assert agent is not None, "Agent not initialized."
    return agent

# Function to handle interactions based on observations
def agent_fn(obs, cfg):
    """Route an observation to the matching agent and return its reply.

    Args:
        obs: Observation exposing ``turnType``; "ask" and "guess" go to the
            questioner agent, "answer" goes to the answerer agent.
        cfg: Unused.

    Returns:
        The agent's response, or "yes" when the response is None, at most one
        character long, or ``turnType`` is not one of the three known values.
    """
    # Start from None so an unexpected turnType falls through to the default
    # reply instead of raising UnboundLocalError.
    response = None
    if obs.turnType in ("ask", "guess"):
        response = get_agent('questioner')(obs)
    elif obs.turnType == "answer":
        response = get_agent('answerer')(obs)

    # Missing or very short responses fall back to "yes".
    if response is None or len(response) <= 1:
        return "yes"
    return response



In [None]:
# Install pigz and pv, which the tar command below uses for compression and progress display.
!apt install pigz pv > /dev/null
# # !tar --use-compress-program='pigz --fast --recursive | pv' -cf submission.tar.gz -C /kaggle/working/submission . -C /kaggle/input/gemma/pytorch/7b-it-quant/2

# Archive the contents of submission/ plus the weights, stored under the relative
# path gemma/pytorch/7b-it-quant/2 that main.py joins onto KAGGLE_AGENT_PATH.
!tar --use-compress-program='pigz --fast --recursive | pv' -cf submission.tar.gz -C /kaggle/working/submission . -C /kaggle/input/ gemma/pytorch/7b-it-quant/2

In [None]:
# class obs(object):
#     '''
#     This class represents an observation from the game. 
#     Make sure that turnType is from ['ask','answer','guess']
#     Choose whatever keyword you'd like to test.
#     Choose the category of the keyword.
#     Questions can be a list of strings if you want to pass a history of questions to your agent.
#     Answers can be a list of strings.
#     Use response to pass the question to the answerer agent.
#     '''
#     def __init__(self, turnType, keyword, category, questions, answers):
#         self.turnType = turnType
#         self.keyword = keyword
#         self.category = category
#         self.questions = questions
#         self.answers = answers

# #example of obs, to be passed to agent. Pass your own obs here.
# obs1 = obs(turnType='ask', 
#           keyword='paris', 
#           category='place', 
#           questions=['is it a person?', 'is it a place?', 'is it a country?', 'is it a city?'], 
#           answers=['no', 'yes', 'no', 'yes'])

# question = agent_fn(obs1, None)
# print(question)

# #you should get the response from your agent as the output. 
# #this does not help with submission errors but will help you evaluate your agent's responses. 